aboutsummaryrefslogtreecommitdiff
path: root/prover.py
diff options
context:
space:
mode:
authorericmarin <maarin.eric@gmail.com>2026-03-10 17:42:02 +0100
committerericmarin <maarin.eric@gmail.com>2026-03-10 17:50:08 +0100
commit0882fc5328127f68a7d79c06d0c7decdee770bb9 (patch)
tree805051f8c5f830fee28d82fc5c9daedcff15f91f /prover.py
parentaf2b13214579d78827392e762149fa8824526aa9 (diff)
downloadvein-0882fc5328127f68a7d79c06d0c7decdee770bb9.tar.gz
vein-0882fc5328127f68a7d79c06d0c7decdee770bb9.zip
two nets
Diffstat (limited to 'prover.py')
-rw-r--r--prover.py38
1 files changed, 35 insertions, 3 deletions
diff --git a/prover.py b/prover.py
index 7ac8c5d..d4a398f 100644
--- a/prover.py
+++ b/prover.py
@@ -1,4 +1,5 @@
import z3
+import sys
syms = {}
def Symbolic(id):
@@ -52,15 +53,46 @@ def epsilon_equivalence(net_a, net_b, epsilon):
else:
print("UNKNOWN: Solver could not decide.")
+def argmax_equivalence(net_a, net_b):
+ solver = z3.Solver()
+
+ solver.add(z3.IsInt(net_a > 0.5) != z3.IsInt(net_b > 0.5))
+
+ result = solver.check()
+
+ if result == z3.unsat:
+ print(f"VERIFIED: The networks are argmax equivalent.")
+ elif result == z3.sat:
+ print("FAILED: The networks are different.")
+ print("Counter-example input:")
+ print(solver.model())
+ else:
+ print("UNKNOWN: Solver could not decide.")
+
+
if __name__ == "__main__":
- net_a_str = "TermAdd(TermMul(Symbolic(0), Concrete(2)), Concrete(3))" # 2x + 3
- net_b_str = "TermAdd(Concrete(3), TermMul(Concrete(2), Symbolic(0)))" # 3 + 2x
+ lines = [line.strip() for line in sys.stdin if line.strip() and not line.startswith("(")]
+
+ if len(lines) < 2:
+ print(f"; Error: Expected at least 2 Inpla output strings, but got {len(lines)}.")
+ sys.exit(1)
try:
+ net_a_str = lines[-2]
+ net_b_str = lines[-1]
+
+ print(f"Comparing:\nA: {net_a_str}\nB: {net_b_str}")
+
net_a = eval(net_a_str, context)
net_b = eval(net_b_str, context)
+
+ print("\nStrict Equivalence")
equivalence(net_a, net_b)
+ print("\nEpsilon-Equivalence")
epsilon_equivalence(net_a, net_b, 1e-5)
+ print("\nARGMAX Equivalence")
+ argmax_equivalence(net_a, net_b)
+
except Exception as e:
print(f"; Error parsing Inpla output: {e}")
-
+ sys.exit(1)