blob: 90623a0ef741a9777340de361984a7c4765a4485 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
|
import z3
import nneq
def check_property(onnx_a, onnx_b, vnnlib):
    """Check one equivalence property between two ONNX networks and report it.

    Loads both networks and the VNN-LIB property into a fresh ``nneq.Solver``,
    runs the check, and prints a human-readable verdict. When the result is
    SAT, the solver's model is printed as the counter-example input.

    Args:
        onnx_a: Path to the first ONNX network.
        onnx_b: Path to the second ONNX network.
        vnnlib: Path to the VNN-LIB property file to check.

    Returns:
        The value returned by ``solver.check()``. It is compared here against
        ``z3.sat`` / ``z3.unsat``. It is returned so that callers can act on
        the verdict programmatically and not only read the printed output.
    """
    solver = nneq.Solver()
    print(f"--- Checking {vnnlib} ---")
    solver.load_onnx(onnx_a)
    solver.load_onnx(onnx_b)
    solver.load_vnnlib(vnnlib)

    result = solver.check()
    if result == z3.unsat:
        print("VERIFIED (UNSAT): The networks are equivalent under this property.")
    elif result == z3.sat:
        print("FAILED (SAT): The networks are NOT equivalent.")
        print("Counter-example input:")
        print(solver.model())
    else:
        # Neither sat nor unsat. This is presumably z3.unknown (e.g. a
        # timeout); the exact cause is not reported here.
        print("UNKNOWN")
    print("")
    return result
if __name__ == "__main__":
    # Check every XOR property file against the same pair of networks, in order.
    net_a = "./xor/xor_a.onnx"
    net_b = "./xor/xor_b.onnx"
    for prop in (
        "./xor/xor_strict.vnnlib",
        "./xor/xor_epsilon.vnnlib",
        "./xor/xor_argmax.vnnlib",
    ):
        check_property(net_a, net_b, prop)
|