From 8619ee7a61bafb8c401087508b886e37779be07b Mon Sep 17 00:00:00 2001 From: ericmarin Date: Tue, 10 Mar 2026 17:57:36 +0100 Subject: added scale --- prover.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) (limited to 'prover.py') diff --git a/prover.py b/prover.py index d4a398f..fb9658a 100644 --- a/prover.py +++ b/prover.py @@ -1,6 +1,9 @@ import z3 import sys +# Scale used during export to Inpla (must match the 'scale' parameter in the exporter) +SCALE = 1000.0 + syms = {} def Symbolic(id): id = f"x_{id}" @@ -8,7 +11,8 @@ def Symbolic(id): syms[id] = z3.Real(id) return syms[id] -def Concrete(val): return z3.RealVal(val) +def Concrete(val): return z3.RealVal(val) / SCALE + def TermAdd(a, b): return a + b def TermMul(a, b): return a * b def TermReLU(x): return z3.If(x > 0, x, 0) @@ -56,20 +60,19 @@ def epsilon_equivalence(net_a, net_b, epsilon): def argmax_equivalence(net_a, net_b): solver = z3.Solver() - solver.add(z3.IsInt(net_a > 0.5) != z3.IsInt(net_b > 0.5)) + solver.add((net_a > 0.5) != (net_b > 0.5)) result = solver.check() if result == z3.unsat: - print(f"VERIFIED: The networks are argmax equivalent.") + print("VERIFIED: The networks are argmax equivalent (binary).") elif result == z3.sat: - print("FAILED: The networks are different.") + print("FAILED: The networks are classification-different.") print("Counter-example input:") print(solver.model()) else: print("UNKNOWN: Solver could not decide.") - if __name__ == "__main__": lines = [line.strip() for line in sys.stdin if line.strip() and not line.startswith("(")] @@ -81,7 +84,7 @@ if __name__ == "__main__": net_a_str = lines[-2] net_b_str = lines[-1] - print(f"Comparing:\nA: {net_a_str}\nB: {net_b_str}") + print(f"Comparing:\nA: {net_a_str}\n\nB: {net_b_str}") net_a = eval(net_a_str, context) net_b = eval(net_b_str, context) -- cgit v1.2.3