From fb544e2089e0c52bd83ffe56f2f4e8d7176564ee Mon Sep 17 00:00:00 2001 From: ericmarin Date: Wed, 11 Mar 2026 16:07:04 +0100 Subject: added constraint to prover --- appunti | 7 +++++++ prover.py | 13 +++++++++++-- 2 files changed, 18 insertions(+), 2 deletions(-) create mode 100644 appunti diff --git a/appunti b/appunti new file mode 100644 index 0000000..b64b282 --- /dev/null +++ b/appunti @@ -0,0 +1,7 @@ +scalability + +soundness of translated nn + +compatibility with other types of network + +comparison with other tool diff --git a/prover.py b/prover.py index fb9658a..d6d1fcc 100644 --- a/prover.py +++ b/prover.py @@ -28,6 +28,9 @@ context = { def equivalence(net_a, net_b): solver = z3.Solver() + for sym in syms.values(): + solver.add(z3.Or(sym == 0, sym == 1)) + solver.add(net_a != net_b) result = solver.check() @@ -43,6 +46,9 @@ def equivalence(net_a, net_b): def epsilon_equivalence(net_a, net_b, epsilon): solver = z3.Solver() + + for sym in syms.values(): + solver.add(z3.Or(sym == 0, sym == 1)) solver.add(z3.Abs(net_a - net_b) > epsilon) @@ -59,13 +65,16 @@ def epsilon_equivalence(net_a, net_b, epsilon): def argmax_equivalence(net_a, net_b): solver = z3.Solver() + + for sym in syms.values(): + solver.add(z3.Or(sym == 0, sym == 1)) solver.add((net_a > 0.5) != (net_b > 0.5)) result = solver.check() if result == z3.unsat: - print("VERIFIED: The networks are argmax equivalent (binary).") + print("VERIFIED: The networks are argmax equivalent.") elif result == z3.sat: print("FAILED: The networks are classification-different.") print("Counter-example input:") @@ -92,7 +101,7 @@ if __name__ == "__main__": print("\nStrict Equivalence") equivalence(net_a, net_b) print("\nEpsilon-Equivalence") - epsilon_equivalence(net_a, net_b, 1e-5) + epsilon_equivalence(net_a, net_b, 1e-2) print("\nARGMAX Equivalence") argmax_equivalence(net_a, net_b) -- cgit v1.2.3