From a0b1e7f6a8c11ed98ae20ac484e2fe9f75b9b85f Mon Sep 17 00:00:00 2001 From: ericmarin Date: Fri, 13 Mar 2026 16:42:00 +0100 Subject: defined modules --- prover.py | 108 -------------------------------------------------------------- 1 file changed, 108 deletions(-) delete mode 100644 prover.py (limited to 'prover.py') diff --git a/prover.py b/prover.py deleted file mode 100644 index 614cdb6..0000000 --- a/prover.py +++ /dev/null @@ -1,108 +0,0 @@ -import z3 -import re -import sys - -syms = {} -def Symbolic(id): - if id not in syms: - syms[id] = z3.Real(id) - return syms[id] - -def Concrete(val): return z3.RealVal(val) - -def TermAdd(a, b): return a + b -def TermMul(a, b): return a * b -def TermReLU(x): return z3.If(x > 0, x, 0) - -context = { - 'Concrete': Concrete, - 'Symbolic': Symbolic, - 'TermAdd': TermAdd, - 'TermMul': TermMul, - 'TermReLU': TermReLU -} - -def equivalence(net_a, net_b): - solver = z3.Solver() - - for sym in syms.values(): - solver.add(z3.Or(sym == 0, sym == 1)) - - solver.add(net_a != net_b) - - result = solver.check() - - if result == z3.unsat: - print("VERIFIED: The networks are equivalent.") - elif result == z3.sat: - print("FAILED: The networks are different.") - print("Counter-example input:") - print(solver.model()) - else: - print("UNKNOWN: Solver could not decide.") - -def epsilon_equivalence(net_a, net_b, epsilon): - solver = z3.Solver() - - for sym in syms.values(): - solver.add(z3.Or(sym == 0, sym == 1)) - - solver.add(z3.Abs(net_a - net_b) > epsilon) - - result = solver.check() - - if result == z3.unsat: - print(f"VERIFIED: The networks are epsilon equivalent, with epsilon={epsilon}.") - elif result == z3.sat: - print("FAILED: The networks are different.") - print("Counter-example input:") - print(solver.model()) - else: - print("UNKNOWN: Solver could not decide.") - -def argmax_equivalence(net_a, net_b): - solver = z3.Solver() - - for sym in syms.values(): - solver.add(z3.Or(sym == 0, sym == 1)) - - solver.add((net_a > 0.5) != (net_b > 0.5)) - - result = solver.check() - - if result == z3.unsat: - print("VERIFIED: The networks are argmax equivalent.") - elif result == z3.sat: - print("FAILED: The networks are classification-different.") - print("Counter-example input:") - print(solver.model()) - else: - print("UNKNOWN: Solver could not decide.") - -if __name__ == "__main__": - lines = [line.strip() for line in sys.stdin if line.strip() and not line.startswith("(")] - - if len(lines) < 2: - print(f"; Error: Expected at least 2 Inpla output strings, but got {len(lines)}.") - sys.exit(1) - - try: - wrap = re.compile(r"Symbolic\((.*?)\)") - net_a_str = wrap.sub(r'Symbolic("\1")', lines[-2]); - net_b_str = wrap.sub(r'Symbolic("\1")', lines[-1]); - - print(f"Comparing:\nA: {net_a_str}\n\nB: {net_b_str}") - - net_a = eval(net_a_str, context) - net_b = eval(net_b_str, context) - - print("\nStrict Equivalence") - equivalence(net_a, net_b) - print("\nEpsilon-Equivalence") - epsilon_equivalence(net_a, net_b, 1e-1) - print("\nARGMAX Equivalence") - argmax_equivalence(net_a, net_b) - - except Exception as e: - print(f"; Error parsing Inpla output: {e}") - sys.exit(1) -- cgit v1.2.3