diff options
| author | ericmarin <maarin.eric@gmail.com> | 2026-03-10 11:05:26 +0100 |
|---|---|---|
| committer | ericmarin <maarin.eric@gmail.com> | 2026-03-10 12:05:35 +0100 |
| commit | 7933d744e06337f1d69b7da83f2cee1611556097 (patch) | |
| tree | 2a1594ab6f8d2cfeb784dee80e2d830d9999e4a5 /prover.py | |
| parent | 51cd389b4e322313671dd0e53513ce84b72a1652 (diff) | |
| download | vein-7933d744e06337f1d69b7da83f2cee1611556097.tar.gz vein-7933d744e06337f1d69b7da83f2cee1611556097.zip | |
added prover script
Diffstat (limited to 'prover.py')
| -rw-r--r-- | prover.py | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/prover.py b/prover.py new file mode 100644 index 0000000..7ac8c5d --- /dev/null +++ b/prover.py @@ -0,0 +1,66 @@ +import z3 + +syms = {} +def Symbolic(id): + id = f"x_{id}" + if id not in syms: + syms[id] = z3.Real(id) + return syms[id] + +def Concrete(val): return z3.RealVal(val) +def TermAdd(a, b): return a + b +def TermMul(a, b): return a * b +def TermReLU(x): return z3.If(x > 0, x, 0) + +context = { + 'Concrete': Concrete, + 'Symbolic': Symbolic, + 'TermAdd': TermAdd, + 'TermMul': TermMul, + 'TermReLU': TermReLU +} + +def equivalence(net_a, net_b): + solver = z3.Solver() + + solver.add(net_a != net_b) + + result = solver.check() + + if result == z3.unsat: + print("VERIFIED: The networks are equivalent.") + elif result == z3.sat: + print("FAILED: The networks are different.") + print("Counter-example input:") + print(solver.model()) + else: + print("UNKNOWN: Solver could not decide.") + +def epsilon_equivalence(net_a, net_b, epsilon): + solver = z3.Solver() + + solver.add(z3.Abs(net_a - net_b) > epsilon) + + result = solver.check() + + if result == z3.unsat: + print(f"VERIFIED: The networks are epsilon equivalent, with epsilon={epsilon}.") + elif result == z3.sat: + print("FAILED: The networks are different.") + print("Counter-example input:") + print(solver.model()) + else: + print("UNKNOWN: Solver could not decide.") + +if __name__ == "__main__": + net_a_str = "TermAdd(TermMul(Symbolic(0), Concrete(2)), Concrete(3))" # 2x + 3 + net_b_str = "TermAdd(Concrete(3), TermMul(Concrete(2), Symbolic(0)))" # 3 + 2x + + try: + net_a = eval(net_a_str, context) + net_b = eval(net_b_str, context) + equivalence(net_a, net_b) + epsilon_equivalence(net_a, net_b, 1e-5) + except Exception as e: + print(f"; Error parsing Inpla output: {e}") + |
