aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorericmarin <maarin.eric@gmail.com>2026-03-10 11:05:26 +0100
committerericmarin <maarin.eric@gmail.com>2026-03-10 12:05:35 +0100
commit7933d744e06337f1d69b7da83f2cee1611556097 (patch)
tree2a1594ab6f8d2cfeb784dee80e2d830d9999e4a5
parent51cd389b4e322313671dd0e53513ce84b72a1652 (diff)
downloadvein-7933d744e06337f1d69b7da83f2cee1611556097.tar.gz
vein-7933d744e06337f1d69b7da83f2cee1611556097.zip
added prover script
-rw-r--r--nn.in73
-rw-r--r--parser.py53
-rw-r--r--prover.py66
-rw-r--r--test.smt27
4 files changed, 112 insertions, 87 deletions
diff --git a/nn.in b/nn.in
index cfb5de9..3cc9e09 100644
--- a/nn.in
+++ b/nn.in
@@ -1,3 +1,19 @@
+// Agents
+// Built-in
+// Eraser: delete other agents recursively
+// Dup: duplicates other agents recursively
+
+// Implemented
+// Linear(x, int q, int r): represent "q*x + r"
+// Concrete(int k): represent a concrete value k
+// Symbolic(id): represent the variable id
+// Add(out, b): represent the addition (has various steps AddCheckLinear/AddCheckConcrete)
+// Mul(out, b): represent the multiplication (has various steps MulCheckLinear/MulCheckConcrete)
+// ReLU(out): represent "if x > 0 ? x ; 0"
+// Materialize(out): transforms a Linear packet into a final representation of TermAdd/TermMul/TermReLU
+
+// TODO: add range information to enable ReLU elimination
+
// Rules
Linear(x, int q, int r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r);
@@ -6,10 +22,10 @@ Concrete(int k) >< Add(out, b)
| _ => b ~ AddCheckConcrete(out, k);
Linear(y, int s, int t) >< AddCheckLinear(out, x, int q, int r)
- | q == 0 && r == 0 && s == 0 && t == 0 => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
- | s == 0 && t == 0 => Linear(x, q, r) ~ Materialize(out), y ~ Eraser
- | q == 0 && r == 0 => (*L)Linear(y, s, t) ~ Materialize(out), x ~ Eraser
- | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, r) ~ Materialize(out_y), out ~ TermAdd(out_x, out_y);
+ | (q == 0) && (r == 0) && (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
+ | (s == 0) && (t == 0) => Linear(x, q, r) ~ Materialize(out), y ~ Eraser
+ | (q == 0) && (r == 0) => (*L)Linear(y, s, t) ~ Materialize(out), x ~ Eraser
+ | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermAdd(out_x, out_y), 1, 0);
Concrete(int j) >< AddCheckLinear(out, x, int q, int r) => out ~ Linear(x, q, r + j);
@@ -19,20 +35,20 @@ Concrete(int j) >< AddCheckConcrete(out, int k)
| j == 0 => out ~ Concrete(k)
| _ => out ~ Concrete(k + j);
-Linear(x, int q, int r) >< Mul(out, b) => b ~ MulCheck(out, x, q, r);
+Linear(x, int q, int r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r);
Concrete(int k) >< Mul(out, b)
| k == 0 => b ~ Eraser, out ~ (*L)Concrete(0)
| k == 1 => out ~ b
| _ => b ~ MulCheckConcrete(out, k);
-Linear(y, int s, int t) >< MulCheck(out, x, int q, int r)
- | q == 0 && r == 0 && s == 0 && t == 0 => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
- | s == 0 && t == 0 => Linear(x, q, r) ~ Materialize(out), y ~ Eraser
- | q == 0 && r == 0 => (*L)Linear(y, s, t) ~ Materialize(out), x ~ Eraser
- | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, r) ~ Materialize(out_y), out ~ TermMul(out_x, out_y);
+Linear(y, int s, int t) >< MulCheckLinear(out, x, int q, int r)
+ | (q == 0) && (r == 0) && (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
+ | (s == 0) && (t == 0) => Linear(x, q, r) ~ Materialize(out), y ~ Eraser
+ | (q == 0) && (r == 0) => (*L)Linear(y, s, t) ~ Materialize(out), x ~ Eraser
+ | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermMul(out_x, out_y), 1, 0);
-Concrete(int j) >< MulCheck(out, x, int q, int r) => out ~ Linear(x, q * j, r * j);
+Concrete(int j) >< MulCheckLinear(out, x, int q, int r) => out ~ Linear(x, q * j, r * j);
Linear(y, int s, int t) >< MulCheckConcrete(out, int k) => out ~ Linear(y, s * k, t * k);
@@ -41,32 +57,35 @@ Concrete(int j) >< MulCheckConcrete(out, int k)
| j == 1 => out ~ Concrete(k)
| _ => out ~ Concrete(k * j);
-Linear(x, int q, int r) >< ReLU(out) => (*L)Linear(x, q, r) ~ Materialize(out_x), out ~ TermReLU(out_x);
+Linear(x, int q, int r) >< ReLU(out) => (*L)Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0);
+
Concrete(int k) >< ReLU(out)
| k > 0 => out ~ (*L)Concrete(k)
| _ => out ~ Concrete(0);
Linear(x, int q, int r) >< Materialize(out)
- | ? => out ~ Concrete(r), x ~ Eraser
- | ? => out ~ Symbolic(x)
- | ? => out ~ TermAdd(x, Concrete(r))
- | ? => out ~ TermMul(Concrete(q), x)
+ | (q == 0) && (r == 0) => out ~ Concrete(r), x ~ Eraser
+ | (q == 1) && (r == 0) => out ~ x
+ | (q == 1) && (r != 0) => out ~ TermAdd(x, Concrete(r))
+ | (q != 0) && (r == 0) => out ~ TermMul(Concrete(q), x)
| _ => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r));
-// Net
-Linear(X, 1, 0) ~ Add(a, b);
-Concrete(0) ~ Mul(b, Linear(Y, 1, 0));
+
+// Net testing
+Linear(Symbolic(X), 1, 0) ~ Add(a, b);
+Concrete(0) ~ Mul(b, Linear(Symbolic(Y), 1, 0));
a ~ Materialize(out);
out; // Symbolic(X)
free out a b;
-Linear(X, 1, 0) ~ Mul(a, Concrete(0));
+Linear(Symbolic(X), 1, 0) ~ Mul(a, Concrete(0));
a ~ Materialize(out);
out; // Concrete(0)
free out a;
-Linear(X, 1, 0) ~ Mul(a, b);
-Linear(Y, 1, 0) ~ Add(b, Concrete(1));
+Linear(Symbolic(X), 1, 0) ~ Mul(a, b);
+Linear(Symbolic(Y), 1, 0) ~ Add(b, Concrete(1));
+a ~ Materialize(out);
out; // Mul(Symbolic(X),Add(Symbolic(Y),Concrete(1)))
free out a b;
@@ -76,13 +95,13 @@ Concrete(3) ~ Mul(c, Concrete(2));
out; // Concrete(9)
free out a b c;
-Linear(X, 1, 0) ~ Mul(a, Linear(W1, 1, 0));
-Linear(Y, 1, 0) ~ Mul(b, Linear(W2, 1, 0));
-b ~ Add(c, Linear(B, 1, 0));
+Linear(Symbolic(X), 1, 0) ~ Mul(a, Linear(Symbolic(W1), 1, 0));
+Linear(Symbolic(Y), 1, 0) ~ Mul(b, Linear(Symbolic(W2), 1, 0));
+b ~ Add(c, Linear(Symbolic(B), 1, 0));
a ~ Add(d, c);
-d ~ ReLU(out);
+d ~ ReLU(Materialize(out));
out; // ReLU(Add(Mul(Symbolic(X),Symbolic(W1)),Add(Mul(Symbolic(Y),Symbolic(W2)),Symbolic(B))))
free out a b c d;
-Linear(X, 2, 10) ~ Materialize(out);
+Linear(Symbolic(X), 2, 10) ~ Materialize(out);
out;
diff --git a/parser.py b/parser.py
deleted file mode 100644
index e66f9d2..0000000
--- a/parser.py
+++ /dev/null
@@ -1,53 +0,0 @@
-
-class Concrete:
- def __init__(self, val): self.val = val
- def __str__(self): return str(self.val)
-
-class Symbolic:
- def __init__(self, id): self.id = id
- def __str__(self): return f"x_{self.id}"
-
-class TermAdd:
- def __init__(self, a, b): self.a, self.b = a, b
- def __str__(self): return f"(+ {self.a} {self.b})"
-
-class TermMul:
- def __init__(self, a, b): self.a, self.b = a, b
- def __str__(self): return f"(* {self.a} {self.b})"
-
-class TermReLU:
- def __init__(self, a): self.a = a
- def __str__(self): return f"(ite (> {self.a} 0) {self.a} 0)"
-
-def generate_z3(net_a_str, net_b_str, epsilon=1e-5):
- context = {
- 'Concrete': Concrete,
- 'Symbolic': Symbolic,
- 'TermAdd': TermAdd,
- 'TermMul': TermMul,
- 'TermReLU': TermReLU
- }
-
- try:
- tree_a = eval(net_a_str, context)
- tree_b = eval(net_b_str, context)
- except Exception as e:
- print(f"; Error parsing Inpla output: {e}")
- return
-
- print("(declare-const x_0 Real)")
- print("(declare-const x_1 Real)")
-
- print(f"(define-fun net_a () Real {tree_a})")
- print(f"(define-fun net_b () Real {tree_b})")
-
- print(f"(assert (> (abs (- net_a net_b)) {epsilon:.10f}))")
-
- print("(check-sat)")
- print("(get-model)")
-
-if __name__ == "__main__":
- output_net_a = "TermAdd(TermMul(Symbolic(0), Concrete(2)), Concrete(3))" # 2x + 3
- output_net_b = "TermAdd(Concrete(3), TermMul(Concrete(2), Symbolic(0)))" # 3 + 2x
-
- generate_z3(output_net_a, output_net_b)
diff --git a/prover.py b/prover.py
new file mode 100644
index 0000000..7ac8c5d
--- /dev/null
+++ b/prover.py
@@ -0,0 +1,66 @@
+import z3
+
+syms = {}
+def Symbolic(id):
+ id = f"x_{id}"
+ if id not in syms:
+ syms[id] = z3.Real(id)
+ return syms[id]
+
+def Concrete(val): return z3.RealVal(val)
+def TermAdd(a, b): return a + b
+def TermMul(a, b): return a * b
+def TermReLU(x): return z3.If(x > 0, x, 0)
+
+context = {
+ 'Concrete': Concrete,
+ 'Symbolic': Symbolic,
+ 'TermAdd': TermAdd,
+ 'TermMul': TermMul,
+ 'TermReLU': TermReLU
+}
+
+def equivalence(net_a, net_b):
+ solver = z3.Solver()
+
+ solver.add(net_a != net_b)
+
+ result = solver.check()
+
+ if result == z3.unsat:
+ print("VERIFIED: The networks are equivalent.")
+ elif result == z3.sat:
+ print("FAILED: The networks are different.")
+ print("Counter-example input:")
+ print(solver.model())
+ else:
+ print("UNKNOWN: Solver could not decide.")
+
+def epsilon_equivalence(net_a, net_b, epsilon):
+ solver = z3.Solver()
+
+ solver.add(z3.Abs(net_a - net_b) > epsilon)
+
+ result = solver.check()
+
+ if result == z3.unsat:
+ print(f"VERIFIED: The networks are epsilon equivalent, with epsilon={epsilon}.")
+ elif result == z3.sat:
+ print("FAILED: The networks are different.")
+ print("Counter-example input:")
+ print(solver.model())
+ else:
+ print("UNKNOWN: Solver could not decide.")
+
+if __name__ == "__main__":
+ net_a_str = "TermAdd(TermMul(Symbolic(0), Concrete(2)), Concrete(3))" # 2x + 3
+ net_b_str = "TermAdd(Concrete(3), TermMul(Concrete(2), Symbolic(0)))" # 3 + 2x
+
+ try:
+ net_a = eval(net_a_str, context)
+ net_b = eval(net_b_str, context)
+ equivalence(net_a, net_b)
+ epsilon_equivalence(net_a, net_b, 1e-5)
+ except Exception as e:
+ print(f"; Error parsing Inpla output: {e}")
+
diff --git a/test.smt2 b/test.smt2
deleted file mode 100644
index 79d05c6..0000000
--- a/test.smt2
+++ /dev/null
@@ -1,7 +0,0 @@
-(declare-const x_0 Real)
-(declare-const x_1 Real)
-(define-fun net_a () Real (+ (* x_0 2) 3))
-(define-fun net_b () Real (+ 3 (* 2 x_0)))
-(assert (> (abs (- net_a net_b)) 0.0000100000))
-(check-sat)
-(get-model)