aboutsummaryrefslogtreecommitdiff
path: root/examples/verify_fashion_mnist.py
blob: 4de0d1e62833846e56c82ec0ea70a77e2c78cc22 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Copyright (C) 2026 Eric Marin
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import z3
import nneq

def check_property(onnx_a, onnx_b, vnnlib):
    """Check two ONNX networks against a VNN-LIB property and report the verdict.

    A fresh ``nneq.Solver`` is created for every call. Both networks and the
    property file are loaded into it, the query is checked, and a
    human-readable verdict is printed:

    * ``z3.unsat`` -> the networks are reported equivalent under the property;
    * ``z3.sat``   -> the networks are reported NOT equivalent and the
      solver's model is printed as the counter-example;
    * anything else -> ``UNKNOWN``.

    Args:
        onnx_a: Path to the first ONNX model.
        onnx_b: Path to the second ONNX model.
        vnnlib: Path to the VNN-LIB file describing the property to check.

    Returns:
        Whatever ``solver.check()`` returned, so callers can act on the
        verdict (e.g. derive an exit status). Callers that ignore the return
        value behave exactly as before.
    """
    solver = nneq.Solver()

    print(f"--- Checking {vnnlib} ---")

    # Order matters only insofar as nneq expects both networks to be loaded
    # before the property that relates them — NOTE(review): confirm with nneq.
    solver.load_onnx(onnx_a)
    solver.load_onnx(onnx_b)
    solver.load_vnnlib(vnnlib)

    result = solver.check()

    if result == z3.unsat:
        print("VERIFIED (UNSAT): The networks are equivalent under this property.")
    elif result == z3.sat:
        print("FAILED (SAT): The networks are NOT equivalent.")
        print("Counter-example input:")
        # NOTE(review): this prints the entire model; it may contain symbols
        # other than the input variables. Filter here if only inputs are wanted.
        print(solver.model())
    else:
        print("UNKNOWN")
    print("")

    return result

if __name__ == "__main__":
    # Resolve the example data relative to this script instead of the current
    # working directory. Previously the "./examples/..." paths only worked when
    # the script was launched from the repository root.
    examples_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fashion_mnist")
    model_a = os.path.join(examples_dir, "fashion_mnist_a.onnx")
    model_b = os.path.join(examples_dir, "fashion_mnist_b.onnx")

    # Check the same pair of networks against each property file in turn,
    # in the original order: strict, epsilon, argmax.
    for prop_name in ("strict", "epsilon", "argmax"):
        vnnlib_path = os.path.join(examples_dir, f"fashion_mnist_{prop_name}.vnnlib")
        check_property(model_a, model_b, vnnlib_path)