diff options
| author | ericmarin <maarin.eric@gmail.com> | 2026-03-10 17:42:02 +0100 |
|---|---|---|
| committer | ericmarin <maarin.eric@gmail.com> | 2026-03-10 17:50:08 +0100 |
| commit | 0882fc5328127f68a7d79c06d0c7decdee770bb9 (patch) | |
| tree | 805051f8c5f830fee28d82fc5c9daedcff15f91f /prover.py | |
| parent | af2b13214579d78827392e762149fa8824526aa9 (diff) | |
| download | vein-0882fc5328127f68a7d79c06d0c7decdee770bb9.tar.gz vein-0882fc5328127f68a7d79c06d0c7decdee770bb9.zip | |
two nets
Diffstat (limited to 'prover.py')
| -rw-r--r-- | prover.py | 38 |
1 files changed, 35 insertions, 3 deletions
@@ -1,4 +1,5 @@ import z3 +import sys syms = {} def Symbolic(id): @@ -52,15 +53,46 @@ def epsilon_equivalence(net_a, net_b, epsilon): else: print("UNKNOWN: Solver could not decide.") +def argmax_equivalence(net_a, net_b): + solver = z3.Solver() + + solver.add(z3.IsInt(net_a > 0.5) != z3.IsInt(net_b > 0.5)) + + result = solver.check() + + if result == z3.unsat: + print(f"VERIFIED: The networks are argmax equivalent.") + elif result == z3.sat: + print("FAILED: The networks are different.") + print("Counter-example input:") + print(solver.model()) + else: + print("UNKNOWN: Solver could not decide.") + + if __name__ == "__main__": - net_a_str = "TermAdd(TermMul(Symbolic(0), Concrete(2)), Concrete(3))" # 2x + 3 - net_b_str = "TermAdd(Concrete(3), TermMul(Concrete(2), Symbolic(0)))" # 3 + 2x + lines = [line.strip() for line in sys.stdin if line.strip() and not line.startswith("(")] + + if len(lines) < 2: + print(f"; Error: Expected at least 2 Inpla output strings, but got {len(lines)}.") + sys.exit(1) try: + net_a_str = lines[-2] + net_b_str = lines[-1] + + print(f"Comparing:\nA: {net_a_str}\nB: {net_b_str}") + net_a = eval(net_a_str, context) net_b = eval(net_b_str, context) + + print("\nStrict Equivalence") equivalence(net_a, net_b) + print("\nEpsilon-Equivalence") epsilon_equivalence(net_a, net_b, 1e-5) + print("\nARGMAX Equivalence") + argmax_equivalence(net_a, net_b) + except Exception as e: print(f"; Error parsing Inpla output: {e}") - + sys.exit(1) |
