aboutsummaryrefslogtreecommitdiff
path: root/prover.py
blob: 7ac8c5dc9108d786fb7400bf92e04d50952b0861 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import z3

# Cache of Z3 real variables keyed by their generated name ("x_<id>").
syms = {}


def Symbolic(id):
    """Return the Z3 real variable standing for symbolic input ``id``.

    The variable is named ``x_<id>``. The first request for a name creates
    it with ``z3.Real``. Later requests return the same cached object from
    ``syms``.
    """
    name = f"x_{id}"
    var = syms.get(name)
    if var is None:
        var = z3.Real(name)
        syms[name] = var
    return var

def Concrete(val):
    """Wrap ``val`` as a Z3 real-valued constant via ``z3.RealVal``."""
    constant = z3.RealVal(val)
    return constant
def TermAdd(a, b):
    """Build the sum term ``a + b`` (left operand first, as written)."""
    total = a + b
    return total
def TermMul(a, b):
    """Build the product term ``a * b`` (left operand first, as written)."""
    product = a * b
    return product
def TermReLU(x):
    """Rectified linear unit as a Z3 term: ``x`` when ``x > 0``, else ``0``."""
    is_positive = x > 0
    return z3.If(is_positive, x, 0)

# Namespace handed to eval(): maps every constructor name that may appear in
# a network term string to the function that builds the corresponding term.
context = dict(
    Concrete=Concrete,
    Symbolic=Symbolic,
    TermAdd=TermAdd,
    TermMul=TermMul,
    TermReLU=TermReLU,
)

def equivalence(net_a, net_b):
    """Check whether two networks compute exactly the same function.

    Asks Z3 for an input on which ``net_a`` and ``net_b`` disagree.

    - ``unsat``: no such input exists, so the networks are equivalent.
    - ``sat``: the model found is printed as a counter-example.
    - Anything else: the solver could not decide, and Z3's reason is printed.

    Args:
        net_a: Expression for the first network, presumably built from the
            constructors in ``context``.
        net_b: Expression for the second network.

    Returns:
        The ``z3.CheckSatResult`` from the solver call (``z3.unsat``,
        ``z3.sat`` or ``z3.unknown``). Callers can act on the verdict
        instead of only reading the printed report.
    """
    solver = z3.Solver()

    # A counter-example is any assignment where the two outputs differ.
    solver.add(net_a != net_b)

    result = solver.check()

    if result == z3.unsat:
        print("VERIFIED: The networks are equivalent.")
    elif result == z3.sat:
        print("FAILED: The networks are different.")
        print("Counter-example input:")
        print(solver.model())
    else:
        print("UNKNOWN: Solver could not decide.")
        # Surface Z3's own explanation so the failure is actionable.
        print(f"Reason: {solver.reason_unknown()}")

    return result

def epsilon_equivalence(net_a, net_b, epsilon):
    """Check whether two networks agree to within ``epsilon`` on every input.

    Asks Z3 for an input on which ``|net_a - net_b| > epsilon``.

    - ``unsat``: no such input exists, so the networks are epsilon-equivalent.
    - ``sat``: the model found is printed as a counter-example.
    - Anything else: the solver could not decide, and Z3's reason is printed.

    Args:
        net_a: Expression for the first network, presumably built from the
            constructors in ``context``.
        net_b: Expression for the second network.
        epsilon: Tolerance. A difference exactly equal to ``epsilon`` still
            counts as equivalent, because the search uses a strict ``>``.

    Returns:
        The ``z3.CheckSatResult`` from the solver call (``z3.unsat``,
        ``z3.sat`` or ``z3.unknown``). Callers can act on the verdict
        instead of only reading the printed report.
    """
    solver = z3.Solver()

    # A counter-example is any input whose outputs differ by more than epsilon.
    solver.add(z3.Abs(net_a - net_b) > epsilon)

    result = solver.check()

    if result == z3.unsat:
        print(f"VERIFIED: The networks are epsilon equivalent, with epsilon={epsilon}.")
    elif result == z3.sat:
        print("FAILED: The networks are different.")
        print("Counter-example input:")
        print(solver.model())
    else:
        print("UNKNOWN: Solver could not decide.")
        # Surface Z3's own explanation so the failure is actionable.
        print(f"Reason: {solver.reason_unknown()}")

    return result

if __name__ == "__main__":
    # Each network is a term string over the constructor names in `context`.
    net_a_str = "TermAdd(TermMul(Symbolic(0), Concrete(2)), Concrete(3))"  # 2x + 3
    net_b_str = "TermAdd(Concrete(3), TermMul(Concrete(2), Symbolic(0)))"  # 3 + 2x

    # SECURITY NOTE(review): eval executes arbitrary Python. Supplying an empty
    # __builtins__ stops Python from injecting the real builtins (open,
    # __import__, ...) into the evaluated expression. This is NOT a sandbox:
    # only evaluate term strings from a trusted producer.
    eval_globals = {"__builtins__": {}}

    # Keep only parsing inside the try, so that solver failures are not
    # misreported as parse errors.
    try:
        net_a = eval(net_a_str, eval_globals, context)
        net_b = eval(net_b_str, eval_globals, context)
    except Exception as e:
        print(f"; Error parsing Inpla output: {e}")
        # Non-zero exit status so calling scripts can detect the failure.
        raise SystemExit(1)

    equivalence(net_a, net_b)
    epsilon_equivalence(net_a, net_b, 1e-5)