diff options
| author | ericmarin <maarin.eric@gmail.com> | 2026-03-10 11:05:26 +0100 |
|---|---|---|
| committer | ericmarin <maarin.eric@gmail.com> | 2026-03-10 12:05:35 +0100 |
| commit | 7933d744e06337f1d69b7da83f2cee1611556097 (patch) | |
| tree | 2a1594ab6f8d2cfeb784dee80e2d830d9999e4a5 | |
| parent | 51cd389b4e322313671dd0e53513ce84b72a1652 (diff) | |
| download | vein-7933d744e06337f1d69b7da83f2cee1611556097.tar.gz vein-7933d744e06337f1d69b7da83f2cee1611556097.zip | |
added prover script
| -rw-r--r-- | nn.in | 73 | ||||
| -rw-r--r-- | parser.py | 53 | ||||
| -rw-r--r-- | prover.py | 66 | ||||
| -rw-r--r-- | test.smt2 | 7 |
4 files changed, 112 insertions, 87 deletions
@@ -1,3 +1,19 @@ +// Agents +// Built-in +// Eraser: delete other agents recursively +// Dup: duplicates other agents recursively + +// Implemented +// Linear(x, int q, int r): represent "q*x + r" +// Concrete(int k): represent a concrete value k +// Symbolic(id): represent the variable id +// Add(out, b): represent the addition (has various steps AddCheckLinear/AddCheckConcrete) +// Mul(out, b): represent the multiplication (has various steps MulCheckLinear/MulCheckConcrete) +// ReLU(out): represent "if x > 0 ? x ; 0" +// Materialize(out): transforms a Linear packet into a final representation of TermAdd/TermMul/TermReLU + +// TODO: add range information to enable ReLU elimination + // Rules Linear(x, int q, int r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r); @@ -6,10 +22,10 @@ Concrete(int k) >< Add(out, b) | _ => b ~ AddCheckConcrete(out, k); Linear(y, int s, int t) >< AddCheckLinear(out, x, int q, int r) - | q == 0 && r == 0 && s == 0 && t == 0 => out ~ Concrete(0), x ~ Eraser, y ~ Eraser - | s == 0 && t == 0 => Linear(x, q, r) ~ Materialize(out), y ~ Eraser - | q == 0 && r == 0 => (*L)Linear(y, s, t) ~ Materialize(out), x ~ Eraser - | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, r) ~ Materialize(out_y), out ~ TermAdd(out_x, out_y); + | (q == 0) && (r == 0) && (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser + | (s == 0) && (t == 0) => Linear(x, q, r) ~ Materialize(out), y ~ Eraser + | (q == 0) && (r == 0) => (*L)Linear(y, s, t) ~ Materialize(out), x ~ Eraser + | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermAdd(out_x, out_y), 1, 0); Concrete(int j) >< AddCheckLinear(out, x, int q, int r) => out ~ Linear(x, q, r + j); @@ -19,20 +35,20 @@ Concrete(int j) >< AddCheckConcrete(out, int k) | j == 0 => out ~ Concrete(k) | _ => out ~ Concrete(k + j); -Linear(x, int q, int r) >< Mul(out, b) => b ~ MulCheck(out, x, q, r); +Linear(x, int q, int r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r); Concrete(int k) >< Mul(out, b) | k == 0 => b ~ Eraser, out ~ (*L)Concrete(0) | k == 1 => out ~ b | _ => b ~ MulCheckConcrete(out, k); -Linear(y, int s, int t) >< MulCheck(out, x, int q, int r) - | q == 0 && r == 0 && s == 0 && t == 0 => out ~ Concrete(0), x ~ Eraser, y ~ Eraser - | s == 0 && t == 0 => Linear(x, q, r) ~ Materialize(out), y ~ Eraser - | q == 0 && r == 0 => (*L)Linear(y, s, t) ~ Materialize(out), x ~ Eraser - | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, r) ~ Materialize(out_y), out ~ TermMul(out_x, out_y); +Linear(y, int s, int t) >< MulCheckLinear(out, x, int q, int r) + | (q == 0) && (r == 0) && (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser + | (s == 0) && (t == 0) => Linear(x, q, r) ~ Materialize(out), y ~ Eraser + | (q == 0) && (r == 0) => (*L)Linear(y, s, t) ~ Materialize(out), x ~ Eraser + | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermMul(out_x, out_y), 1, 0); -Concrete(int j) >< MulCheck(out, x, int q, int r) => out ~ Linear(x, q * j, r * j); +Concrete(int j) >< MulCheckLinear(out, x, int q, int r) => out ~ Linear(x, q * j, r * j); Linear(y, int s, int t) >< MulCheckConcrete(out, int k) => out ~ Linear(y, s * k, t * k); @@ -41,32 +57,35 @@ Concrete(int j) >< MulCheckConcrete(out, int k) | j == 1 => out ~ Concrete(k) | _ => out ~ Concrete(k * j); -Linear(x, int q, int r) >< ReLU(out) => (*L)Linear(x, q, r) ~ Materialize(out_x), out ~ TermReLU(out_x); +Linear(x, int q, int r) >< ReLU(out) => (*L)Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0); + Concrete(int k) >< ReLU(out) | k > 0 => out ~ (*L)Concrete(k) | _ => out ~ Concrete(0); Linear(x, int q, int r) >< Materialize(out) - | ? => out ~ Concrete(r), x ~ Eraser - | ? => out ~ Symbolic(x) - | ? => out ~ TermAdd(x, Concrete(r)) - | ? => out ~ TermMul(Concrete(q), x) + | (q == 0) && (r == 0) => out ~ Concrete(r), x ~ Eraser + | (q == 1) && (r == 0) => out ~ x + | (q == 1) && (r != 0) => out ~ TermAdd(x, Concrete(r)) + | (q != 0) && (r == 0) => out ~ TermMul(Concrete(q), x) | _ => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r)); -// Net -Linear(X, 1, 0) ~ Add(a, b); -Concrete(0) ~ Mul(b, Linear(Y, 1, 0)); + +// Net testing +Linear(Symbolic(X), 1, 0) ~ Add(a, b); +Concrete(0) ~ Mul(b, Linear(Symbolic(Y), 1, 0)); a ~ Materialize(out); out; // Symbolic(X) free out a b; -Linear(X, 1, 0) ~ Mul(a, Concrete(0)); +Linear(Symbolic(X), 1, 0) ~ Mul(a, Concrete(0)); a ~ Materialize(out); out; // Concrete(0) free out a; -Linear(X, 1, 0) ~ Mul(a, b); -Linear(Y, 1, 0) ~ Add(b, Concrete(1)); +Linear(Symbolic(X), 1, 0) ~ Mul(a, b); +Linear(Symbolic(Y), 1, 0) ~ Add(b, Concrete(1)); +a ~ Materialize(out); out; // Mul(Symbolic(X),Add(Symbolic(Y),Concrete(1))) free out a b; @@ -76,13 +95,13 @@ Concrete(3) ~ Mul(c, Concrete(2)); out; // Concrete(9) free out a b c; -Linear(X, 1, 0) ~ Mul(a, Linear(W1, 1, 0)); -Linear(Y, 1, 0) ~ Mul(b, Linear(W2, 1, 0)); -b ~ Add(c, Linear(B, 1, 0)); +Linear(Symbolic(X), 1, 0) ~ Mul(a, Linear(Symbolic(W1), 1, 0)); +Linear(Symbolic(Y), 1, 0) ~ Mul(b, Linear(Symbolic(W2), 1, 0)); +b ~ Add(c, Linear(Symbolic(B), 1, 0)); a ~ Add(d, c); -d ~ ReLU(out); +d ~ ReLU(Materialize(out)); out; // ReLU(Add(Mul(Symbolic(X),Symbolic(W1)),Add(Mul(Symbolic(Y),Symbolic(W2)),Symbolic(B)))) free out a b c d; -Linear(X, 2, 10) ~ Materialize(out); +Linear(Symbolic(X), 2, 10) ~ Materialize(out); out; diff --git a/parser.py b/parser.py deleted file mode 100644 index e66f9d2..0000000 --- a/parser.py +++ /dev/null @@ -1,53 +0,0 @@ - -class Concrete: - def __init__(self, val): self.val = val - def __str__(self): return str(self.val) - -class Symbolic: - def __init__(self, id): self.id = id - def __str__(self): return f"x_{self.id}" - -class TermAdd: - def __init__(self, a, b): self.a, self.b = a, b - def __str__(self): return f"(+ {self.a} {self.b})" - -class TermMul: - def __init__(self, a, b): self.a, self.b = a, b - def __str__(self): return f"(* {self.a} {self.b})" - -class TermReLU: - def __init__(self, a): self.a = a - def __str__(self): return f"(ite (> {self.a} 0) {self.a} 0)" - -def generate_z3(net_a_str, net_b_str, epsilon=1e-5): - context = { - 'Concrete': Concrete, - 'Symbolic': Symbolic, - 'TermAdd': TermAdd, - 'TermMul': TermMul, - 'TermReLU': TermReLU - } - - try: - tree_a = eval(net_a_str, context) - tree_b = eval(net_b_str, context) - except Exception as e: - print(f"; Error parsing Inpla output: {e}") - return - - print("(declare-const x_0 Real)") - print("(declare-const x_1 Real)") - - print(f"(define-fun net_a () Real {tree_a})") - print(f"(define-fun net_b () Real {tree_b})") - - print(f"(assert (> (abs (- net_a net_b)) {epsilon:.10f}))") - - print("(check-sat)") - print("(get-model)") - -if __name__ == "__main__": - output_net_a = "TermAdd(TermMul(Symbolic(0), Concrete(2)), Concrete(3))" # 2x + 3 - output_net_b = "TermAdd(Concrete(3), TermMul(Concrete(2), Symbolic(0)))" # 3 + 2x - - generate_z3(output_net_a, output_net_b) diff --git a/prover.py b/prover.py new file mode 100644 index 0000000..7ac8c5d --- /dev/null +++ b/prover.py @@ -0,0 +1,66 @@ +import z3 + +syms = {} +def Symbolic(id): + id = f"x_{id}" + if id not in syms: + syms[id] = z3.Real(id) + return syms[id] + +def Concrete(val): return z3.RealVal(val) +def TermAdd(a, b): return a + b +def TermMul(a, b): return a * b +def TermReLU(x): return z3.If(x > 0, x, 0) + +context = { + 'Concrete': Concrete, + 'Symbolic': Symbolic, + 'TermAdd': TermAdd, + 'TermMul': TermMul, + 'TermReLU': TermReLU +} + +def equivalence(net_a, net_b): + solver = z3.Solver() + + solver.add(net_a != net_b) + + result = solver.check() + + if result == z3.unsat: + print("VERIFIED: The networks are equivalent.") + elif result == z3.sat: + print("FAILED: The networks are different.") + print("Counter-example input:") + print(solver.model()) + else: + print("UNKNOWN: Solver could not decide.") + +def epsilon_equivalence(net_a, net_b, epsilon): + solver = z3.Solver() + + solver.add(z3.Abs(net_a - net_b) > epsilon) + + result = solver.check() + + if result == z3.unsat: + print(f"VERIFIED: The networks are epsilon equivalent, with epsilon={epsilon}.") + elif result == z3.sat: + print("FAILED: The networks are different.") + print("Counter-example input:") + print(solver.model()) + else: + print("UNKNOWN: Solver could not decide.") + +if __name__ == "__main__": + net_a_str = "TermAdd(TermMul(Symbolic(0), Concrete(2)), Concrete(3))" # 2x + 3 + net_b_str = "TermAdd(Concrete(3), TermMul(Concrete(2), Symbolic(0)))" # 3 + 2x + + try: + net_a = eval(net_a_str, context) + net_b = eval(net_b_str, context) + equivalence(net_a, net_b) + epsilon_equivalence(net_a, net_b, 1e-5) + except Exception as e: + print(f"; Error parsing Inpla output: {e}") + diff --git a/test.smt2 b/test.smt2 deleted file mode 100644 index 79d05c6..0000000 --- a/test.smt2 +++ /dev/null @@ -1,7 +0,0 @@ -(declare-const x_0 Real) -(declare-const x_1 Real) -(define-fun net_a () Real (+ (* x_0 2) 3)) -(define-fun net_b () Real (+ 3 (* 2 x_0))) -(assert (> (abs (- net_a net_b)) 0.0000100000)) -(check-sat) -(get-model) |
