aboutsummaryrefslogtreecommitdiff
path: root/prover.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--prover.py15
1 files changed, 9 insertions, 6 deletions
diff --git a/prover.py b/prover.py
index d4a398f..fb9658a 100644
--- a/prover.py
+++ b/prover.py
@@ -1,6 +1,9 @@
import z3
import sys
+# Scale used during export to Inpla (must match the 'scale' parameter in the exporter)
+SCALE = 1000.0
+
syms = {}
def Symbolic(id):
id = f"x_{id}"
@@ -8,7 +11,8 @@ def Symbolic(id):
syms[id] = z3.Real(id)
return syms[id]
-def Concrete(val): return z3.RealVal(val)
+def Concrete(val): return z3.RealVal(val) / SCALE
+
def TermAdd(a, b): return a + b
def TermMul(a, b): return a * b
def TermReLU(x): return z3.If(x > 0, x, 0)
@@ -56,20 +60,19 @@ def epsilon_equivalence(net_a, net_b, epsilon):
def argmax_equivalence(net_a, net_b):
solver = z3.Solver()
- solver.add(z3.IsInt(net_a > 0.5) != z3.IsInt(net_b > 0.5))
+ solver.add((net_a > 0.5) != (net_b > 0.5))
result = solver.check()
if result == z3.unsat:
- print(f"VERIFIED: The networks are argmax equivalent.")
+ print("VERIFIED: The networks are argmax equivalent (binary).")
elif result == z3.sat:
- print("FAILED: The networks are different.")
+ print("FAILED: The networks are classification-different.")
print("Counter-example input:")
print(solver.model())
else:
print("UNKNOWN: Solver could not decide.")
-
if __name__ == "__main__":
lines = [line.strip() for line in sys.stdin if line.strip() and not line.startswith("(")]
@@ -81,7 +84,7 @@ if __name__ == "__main__":
net_a_str = lines[-2]
net_b_str = lines[-1]
- print(f"Comparing:\nA: {net_a_str}\nB: {net_b_str}")
+ print(f"Comparing:\nA: {net_a_str}\n\nB: {net_b_str}")
net_a = eval(net_a_str, context)
net_b = eval(net_b_str, context)