aboutsummaryrefslogtreecommitdiff
path: root/prover.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--prover.py108
1 files changed, 0 insertions, 108 deletions
diff --git a/prover.py b/prover.py
deleted file mode 100644
index 614cdb6..0000000
--- a/prover.py
+++ /dev/null
@@ -1,108 +0,0 @@
-import z3
-import re
-import sys
-
-syms = {}
-def Symbolic(id):
- if id not in syms:
- syms[id] = z3.Real(id)
- return syms[id]
-
-def Concrete(val): return z3.RealVal(val)
-
-def TermAdd(a, b): return a + b
-def TermMul(a, b): return a * b
-def TermReLU(x): return z3.If(x > 0, x, 0)
-
-context = {
- 'Concrete': Concrete,
- 'Symbolic': Symbolic,
- 'TermAdd': TermAdd,
- 'TermMul': TermMul,
- 'TermReLU': TermReLU
-}
-
-def equivalence(net_a, net_b):
- solver = z3.Solver()
-
- for sym in syms.values():
- solver.add(z3.Or(sym == 0, sym == 1))
-
- solver.add(net_a != net_b)
-
- result = solver.check()
-
- if result == z3.unsat:
- print("VERIFIED: The networks are equivalent.")
- elif result == z3.sat:
- print("FAILED: The networks are different.")
- print("Counter-example input:")
- print(solver.model())
- else:
- print("UNKNOWN: Solver could not decide.")
-
-def epsilon_equivalence(net_a, net_b, epsilon):
- solver = z3.Solver()
-
- for sym in syms.values():
- solver.add(z3.Or(sym == 0, sym == 1))
-
- solver.add(z3.Abs(net_a - net_b) > epsilon)
-
- result = solver.check()
-
- if result == z3.unsat:
- print(f"VERIFIED: The networks are epsilon equivalent, with epsilon={epsilon}.")
- elif result == z3.sat:
- print("FAILED: The networks are different.")
- print("Counter-example input:")
- print(solver.model())
- else:
- print("UNKNOWN: Solver could not decide.")
-
-def argmax_equivalence(net_a, net_b):
- solver = z3.Solver()
-
- for sym in syms.values():
- solver.add(z3.Or(sym == 0, sym == 1))
-
- solver.add((net_a > 0.5) != (net_b > 0.5))
-
- result = solver.check()
-
- if result == z3.unsat:
- print("VERIFIED: The networks are argmax equivalent.")
- elif result == z3.sat:
- print("FAILED: The networks are classification-different.")
- print("Counter-example input:")
- print(solver.model())
- else:
- print("UNKNOWN: Solver could not decide.")
-
-if __name__ == "__main__":
- lines = [line.strip() for line in sys.stdin if line.strip() and not line.startswith("(")]
-
- if len(lines) < 2:
- print(f"; Error: Expected at least 2 Inpla output strings, but got {len(lines)}.")
- sys.exit(1)
-
- try:
- wrap = re.compile(r"Symbolic\((.*?)\)")
- net_a_str = wrap.sub(r'Symbolic("\1")', lines[-2]);
- net_b_str = wrap.sub(r'Symbolic("\1")', lines[-1]);
-
- print(f"Comparing:\nA: {net_a_str}\n\nB: {net_b_str}")
-
- net_a = eval(net_a_str, context)
- net_b = eval(net_b_str, context)
-
- print("\nStrict Equivalence")
- equivalence(net_a, net_b)
- print("\nEpsilon-Equivalence")
- epsilon_equivalence(net_a, net_b, 1e-1)
- print("\nARGMAX Equivalence")
- argmax_equivalence(net_a, net_b)
-
- except Exception as e:
- print(f"; Error parsing Inpla output: {e}")
- sys.exit(1)