aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorericmarin <maarin.eric@gmail.com>2026-03-11 16:07:04 +0100
committerericmarin <maarin.eric@gmail.com>2026-03-11 16:07:04 +0100
commitfb544e2089e0c52bd83ffe56f2f4e8d7176564ee (patch)
tree465c8a957baa001ea3e4733774b380dc2cc79e22
parent8619ee7a61bafb8c401087508b886e37779be07b (diff)
downloadvein-fb544e2089e0c52bd83ffe56f2f4e8d7176564ee.tar.gz
vein-fb544e2089e0c52bd83ffe56f2f4e8d7176564ee.zip
added constraint to prover
-rw-r--r--appunti7
-rw-r--r--prover.py13
2 files changed, 18 insertions, 2 deletions
diff --git a/appunti b/appunti
new file mode 100644
index 0000000..b64b282
--- /dev/null
+++ b/appunti
@@ -0,0 +1,7 @@
+scalability
+
+soundness of translated nn
+
+compatibility with other types of network
+
+comparison with other tool
diff --git a/prover.py b/prover.py
index fb9658a..d6d1fcc 100644
--- a/prover.py
+++ b/prover.py
@@ -28,6 +28,9 @@ context = {
def equivalence(net_a, net_b):
solver = z3.Solver()
+ for sym in syms.values():
+ solver.add(z3.Or(sym == 0, sym == 1))
+
solver.add(net_a != net_b)
result = solver.check()
@@ -43,6 +46,9 @@ def equivalence(net_a, net_b):
def epsilon_equivalence(net_a, net_b, epsilon):
solver = z3.Solver()
+
+ for sym in syms.values():
+ solver.add(z3.Or(sym == 0, sym == 1))
solver.add(z3.Abs(net_a - net_b) > epsilon)
@@ -59,13 +65,16 @@ def epsilon_equivalence(net_a, net_b, epsilon):
def argmax_equivalence(net_a, net_b):
solver = z3.Solver()
+
+ for sym in syms.values():
+ solver.add(z3.Or(sym == 0, sym == 1))
solver.add((net_a > 0.5) != (net_b > 0.5))
result = solver.check()
if result == z3.unsat:
- print("VERIFIED: The networks are argmax equivalent (binary).")
+ print("VERIFIED: The networks are argmax equivalent.")
elif result == z3.sat:
print("FAILED: The networks are classification-different.")
print("Counter-example input:")
@@ -92,7 +101,7 @@ if __name__ == "__main__":
print("\nStrict Equivalence")
equivalence(net_a, net_b)
print("\nEpsilon-Equivalence")
- epsilon_equivalence(net_a, net_b, 1e-5)
+ epsilon_equivalence(net_a, net_b, 1e-2)
print("\nARGMAX Equivalence")
argmax_equivalence(net_a, net_b)