From fb544e2089e0c52bd83ffe56f2f4e8d7176564ee Mon Sep 17 00:00:00 2001 From: ericmarin Date: Wed, 11 Mar 2026 16:07:04 +0100 Subject: added constraint to prover --- prover.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) (limited to 'prover.py') diff --git a/prover.py b/prover.py index fb9658a..d6d1fcc 100644 --- a/prover.py +++ b/prover.py @@ -28,6 +28,9 @@ context = { def equivalence(net_a, net_b): solver = z3.Solver() + for sym in syms.values(): + solver.add(z3.Or(sym == 0, sym == 1)) + solver.add(net_a != net_b) result = solver.check() @@ -43,6 +46,9 @@ def equivalence(net_a, net_b): def epsilon_equivalence(net_a, net_b, epsilon): solver = z3.Solver() + + for sym in syms.values(): + solver.add(z3.Or(sym == 0, sym == 1)) solver.add(z3.Abs(net_a - net_b) > epsilon) @@ -59,13 +65,16 @@ def epsilon_equivalence(net_a, net_b, epsilon): def argmax_equivalence(net_a, net_b): solver = z3.Solver() + + for sym in syms.values(): + solver.add(z3.Or(sym == 0, sym == 1)) solver.add((net_a > 0.5) != (net_b > 0.5)) result = solver.check() if result == z3.unsat: - print("VERIFIED: The networks are argmax equivalent (binary).") + print("VERIFIED: The networks are argmax equivalent.") elif result == z3.sat: print("FAILED: The networks are classification-different.") print("Counter-example input:") @@ -92,7 +101,7 @@ if __name__ == "__main__": print("\nStrict Equivalence") equivalence(net_a, net_b) print("\nEpsilon-Equivalence") - epsilon_equivalence(net_a, net_b, 1e-5) + epsilon_equivalence(net_a, net_b, 1e-2) print("\nARGMAX Equivalence") argmax_equivalence(net_a, net_b) -- cgit v1.2.3