diff options
| author | ericmarin <maarin.eric@gmail.com> | 2026-03-13 16:42:00 +0100 |
|---|---|---|
| committer | ericmarin <maarin.eric@gmail.com> | 2026-03-16 10:23:02 +0100 |
| commit | a0b1e7f6a8c11ed98ae20ac484e2fe9f75b9b85f (patch) | |
| tree | 57d6aa106daf4a46d9132832eec88fdb79fc5543 /prover.py | |
| parent | 19652ec48be4c6faf3f7815a9281b611aed94727 (diff) | |
| download | vein-a0b1e7f6a8c11ed98ae20ac484e2fe9f75b9b85f.tar.gz vein-a0b1e7f6a8c11ed98ae20ac484e2fe9f75b9b85f.zip | |
defined modules
Diffstat (limited to 'prover.py')
| -rw-r--r-- | prover.py | 108 |
1 files changed, 0 insertions, 108 deletions
diff --git a/prover.py b/prover.py deleted file mode 100644 index 614cdb6..0000000 --- a/prover.py +++ /dev/null @@ -1,108 +0,0 @@ -import z3 -import re -import sys - -syms = {} -def Symbolic(id): - if id not in syms: - syms[id] = z3.Real(id) - return syms[id] - -def Concrete(val): return z3.RealVal(val) - -def TermAdd(a, b): return a + b -def TermMul(a, b): return a * b -def TermReLU(x): return z3.If(x > 0, x, 0) - -context = { - 'Concrete': Concrete, - 'Symbolic': Symbolic, - 'TermAdd': TermAdd, - 'TermMul': TermMul, - 'TermReLU': TermReLU -} - -def equivalence(net_a, net_b): - solver = z3.Solver() - - for sym in syms.values(): - solver.add(z3.Or(sym == 0, sym == 1)) - - solver.add(net_a != net_b) - - result = solver.check() - - if result == z3.unsat: - print("VERIFIED: The networks are equivalent.") - elif result == z3.sat: - print("FAILED: The networks are different.") - print("Counter-example input:") - print(solver.model()) - else: - print("UNKNOWN: Solver could not decide.") - -def epsilon_equivalence(net_a, net_b, epsilon): - solver = z3.Solver() - - for sym in syms.values(): - solver.add(z3.Or(sym == 0, sym == 1)) - - solver.add(z3.Abs(net_a - net_b) > epsilon) - - result = solver.check() - - if result == z3.unsat: - print(f"VERIFIED: The networks are epsilon equivalent, with epsilon={epsilon}.") - elif result == z3.sat: - print("FAILED: The networks are different.") - print("Counter-example input:") - print(solver.model()) - else: - print("UNKNOWN: Solver could not decide.") - -def argmax_equivalence(net_a, net_b): - solver = z3.Solver() - - for sym in syms.values(): - solver.add(z3.Or(sym == 0, sym == 1)) - - solver.add((net_a > 0.5) != (net_b > 0.5)) - - result = solver.check() - - if result == z3.unsat: - print("VERIFIED: The networks are argmax equivalent.") - elif result == z3.sat: - print("FAILED: The networks are classification-different.") - print("Counter-example input:") - print(solver.model()) - else: - print("UNKNOWN: Solver could not decide.") - -if __name__ == "__main__": - lines = [line.strip() for line in sys.stdin if line.strip() and not line.startswith("(")] - - if len(lines) < 2: - print(f"; Error: Expected at least 2 Inpla output strings, but got {len(lines)}.") - sys.exit(1) - - try: - wrap = re.compile(r"Symbolic\((.*?)\)") - net_a_str = wrap.sub(r'Symbolic("\1")', lines[-2]); - net_b_str = wrap.sub(r'Symbolic("\1")', lines[-1]); - - print(f"Comparing:\nA: {net_a_str}\n\nB: {net_b_str}") - - net_a = eval(net_a_str, context) - net_b = eval(net_b_str, context) - - print("\nStrict Equivalence") - equivalence(net_a, net_b) - print("\nEpsilon-Equivalence") - epsilon_equivalence(net_a, net_b, 1e-1) - print("\nARGMAX Equivalence") - argmax_equivalence(net_a, net_b) - - except Exception as e: - print(f"; Error parsing Inpla output: {e}") - sys.exit(1) |
