aboutsummaryrefslogtreecommitdiff
path: root/verify_example.py
diff options
context:
space:
mode:
authorericmarin <maarin.eric@gmail.com>2026-04-10 15:06:56 +0200
committerericmarin <maarin.eric@gmail.com>2026-04-13 10:55:06 +0200
commit8f4f24523235965cfa2041ed00cc40fc0b4bd367 (patch)
tree716862b4c898431861c27ab69165edfb245467dc /verify_example.py
parent9fb816496d392638fa6981e71800466d71434680 (diff)
downloadvein-8f4f24523235965cfa2041ed00cc40fc0b4bd367.tar.gz
vein-8f4f24523235965cfa2041ed00cc40fc0b4bd367.zip
added MNIST, changed cache and parser
Diffstat (limited to '')
-rw-r--r--verify_example.py79
1 files changed, 79 insertions, 0 deletions
diff --git a/verify_example.py b/verify_example.py
new file mode 100644
index 0000000..1bf00fa
--- /dev/null
+++ b/verify_example.py
@@ -0,0 +1,79 @@
+import sys
+import vein
+
+def check_property(onnx_a, onnx_b, vnnlib):
+ solver = vein.Solver()
+
+ print(f"--- Checking {vnnlib} ---")
+
+ solver.load_onnx(onnx_a)
+ solver.load_onnx(onnx_b)
+ solver.load_vnnlib(vnnlib)
+
+ result = solver.check()
+
+ if result == vein.unsat:
+ print("VERIFIED (UNSAT): The networks are equivalent under this property.", end="\n\n")
+ elif result == vein.sat:
+ print("FAILED (SAT): The networks are NOT equivalent.")
+ print("Counter-example input:")
+ print(solver.model(), end="\n\n")
+ # m = solver.model()
+ # sorted_symbols = sorted([s for s in m.decls() if s.name().startswith("X_")], key=lambda s: s.name())
+ # for s in sorted_symbols:
+ # print(f" {s.name()} = {m[s]}")
+ else:
+ print("UNKNOWN", end="\n\n")
+
+if __name__ == "__main__":
+ if len(sys.argv) <= 1:
+ print("Net not provided")
+ print("Available Nets: 'xor', 'mnist', 'iris', 'acasxu', 'pendulum', 'double_integrator'")
+ sys.exit()
+
+ match sys.argv[1]:
+ case "xor":
+ net_a = "./examples/xor/xor_a.onnx"
+ net_b = "./examples/xor/xor_b.onnx"
+ strict = "./examples/xor/xor_strict.vnnlib"
+ epsilon = "./examples/xor/xor_epsilon.vnnlib"
+ argmax = "./examples/xor/xor_argmax.vnnlib"
+ case "mnist":
+ net_a = "./examples/mnist/mnist_a.onnx"
+ net_b = "./examples/mnist/mnist_b.onnx"
+ strict = "./examples/mnist/mnist_strict.vnnlib"
+ epsilon = "./examples/mnist/mnist_epsilon.vnnlib"
+ argmax = "./examples/mnist/mnist_argmax.vnnlib"
+ case "iris":
+ net_a = "./examples/iris/iris_a.onnx"
+ net_b = "./examples/iris/iris_b.onnx"
+ strict = "./examples/iris/iris_strict.vnnlib"
+ epsilon = "./examples/iris/iris_epsilon.vnnlib"
+ argmax = "./examples/iris/iris_argmax.vnnlib"
+ case "acasxu":
+ net_a = "./examples/ACASXU/ACASXU_run2a_1_1_batch_2000.onnx"
+ net_b = "./examples/ACASXU/ACASXU_run2a_1_1_batch_2000.onnx"
+ strict = "./examples/ACASXU/ACASXU_strict.vnnlib"
+ epsilon = "./examples/ACASXU/ACASXU_epsilon.vnnlib"
+ argmax = "./examples/ACASXU/ACASXU_argmax.vnnlib"
+ case "pendulum":
+ net_a = "./examples/pendulum/pendulum_finetune_con.onnx"
+ net_b = "./examples/pendulum/pendulum_finetune_con.onnx"
+ strict = "./examples/pendulum/pendulum_strict.vnnlib"
+ epsilon = "./examples/pendulum/pendulum_epsilon.vnnlib"
+ argmax = "./examples/pendulum/pendulum_argmax.vnnlib"
+ case "double_integrator":
+ net_a = "./examples/double_integrator/double_integrator_finetune_inv.onnx"
+ net_b = "./examples/double_integrator/double_integrator_finetune_inv.onnx"
+ strict = "./examples/double_integrator/double_integrator_strict.vnnlib"
+ epsilon = "./examples/double_integrator/double_integrator_epsilon.vnnlib"
+ argmax = "./examples/double_integrator/double_integrator_argmax.vnnlib"
+ case _:
+ print("Available Nets: 'xor', 'mnist', 'iris', 'acasxu', 'pendulum', 'double_integrator'")
+ sys.exit()
+
+ print(f"=== Comparing {net_a} and {net_b} ===", end="\n\n")
+
+ check_property(net_a, net_b, strict)
+ check_property(net_a, net_b, epsilon)
+ check_property(net_a, net_b, argmax)