1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
|
import z3
# Cache of z3 Real variables, keyed by variable name ("x_<id>").
syms = {}


def Symbolic(id):
    """Return the z3 Real variable for input index *id*, creating it once.

    The variable is named ``x_<id>``. It is memoised in ``syms``, so every
    occurrence of the same index in a network expression refers to the very
    same z3 variable object.
    """
    name = f"x_{id}"
    try:
        return syms[name]
    except KeyError:
        # First reference to this input: create and remember the variable.
        var = syms[name] = z3.Real(name)
        return var
def Concrete(val):
    """Wrap the literal *val* (anything z3.RealVal accepts) as a z3 real constant."""
    constant = z3.RealVal(val)
    return constant
def TermAdd(a, b):
    """Return the sum term ``a + b``.

    Operand order is preserved, since z3 builds structurally different
    terms for ``a + b`` and ``b + a``.
    """
    total = a + b
    return total
def TermMul(a, b):
    """Return the product term ``a * b``, keeping operand order unchanged."""
    product = a * b
    return product
def TermReLU(x):
    """Return the z3 term for ReLU(x): ``x`` when ``x > 0``, else ``0``."""
    is_positive = x > 0
    return z3.If(is_positive, x, 0)
# Namespace handed to eval(): maps each constructor name that may appear in a
# serialized network expression to the function building its z3 term.
context = dict(
    Concrete=Concrete,
    Symbolic=Symbolic,
    TermAdd=TermAdd,
    TermMul=TermMul,
    TermReLU=TermReLU,
)
def equivalence(net_a, net_b):
    """Check whether two networks agree on every input, and report the verdict.

    Asks z3 for an assignment of the symbolic inputs under which ``net_a`` and
    ``net_b`` take different values. ``unsat`` means no such input exists (the
    networks are equivalent). ``sat`` yields a counter-example model, which is
    printed.

    Args:
        net_a: z3 arithmetic expression for the first network (e.g. obtained by
            evaluating a network string against ``context``).
        net_b: z3 arithmetic expression for the second network.

    Returns:
        The z3 check result (``z3.sat``, ``z3.unsat`` or ``z3.unknown``), so
        callers can act on the verdict instead of only reading stdout.
    """
    solver = z3.Solver()
    # Search for a distinguishing input: equivalence holds iff this is unsat.
    solver.add(net_a != net_b)
    result = solver.check()
    if result == z3.unsat:
        print("VERIFIED: The networks are equivalent.")
    elif result == z3.sat:
        print("FAILED: The networks are different.")
        print("Counter-example input:")
        print(solver.model())
    else:
        print("UNKNOWN: Solver could not decide.")
    return result
def epsilon_equivalence(net_a, net_b, epsilon):
    """Check whether two networks agree to within *epsilon* on every input.

    Asks z3 for an input where ``|net_a - net_b| > epsilon``. ``unsat`` means
    the outputs never differ by more than *epsilon*. A difference of exactly
    *epsilon* is accepted, because the comparison is strict. ``sat`` yields a
    counter-example model, which is printed.

    Args:
        net_a: z3 arithmetic expression for the first network.
        net_b: z3 arithmetic expression for the second network.
        epsilon: Tolerance on the absolute output difference. It is presumably
            meant to be non-negative; a negative value makes every input a
            counter-example.

    Returns:
        The z3 check result (``z3.sat``, ``z3.unsat`` or ``z3.unknown``), so
        callers can act on the verdict instead of only reading stdout.
    """
    solver = z3.Solver()
    # Look for an input whose outputs are further apart than the tolerance.
    solver.add(z3.Abs(net_a - net_b) > epsilon)
    result = solver.check()
    if result == z3.unsat:
        print(f"VERIFIED: The networks are epsilon equivalent, with epsilon={epsilon}.")
    elif result == z3.sat:
        print("FAILED: The networks are different.")
        print("Counter-example input:")
        print(solver.model())
    else:
        print("UNKNOWN: Solver could not decide.")
    return result
if __name__ == "__main__":
    net_a_str = "TermAdd(TermMul(Symbolic(0), Concrete(2)), Concrete(3))"  # 2x + 3
    net_b_str = "TermAdd(Concrete(3), TermMul(Concrete(2), Symbolic(0)))"  # 3 + 2x

    # NOTE(review): eval() executes arbitrary Python. Blanking __builtins__
    # restricts expressions to the constructors in `context`, but it is NOT a
    # sandbox: only feed it network strings from a trusted producer.
    # A fresh dict per call keeps eval() from writing into the shared `context`.
    namespace = {"__builtins__": {}, **context}
    try:
        # Only the parsing step belongs here. Verification errors must surface
        # on their own instead of being misreported as parse errors.
        net_a = eval(net_a_str, dict(namespace))
        net_b = eval(net_b_str, dict(namespace))
    except Exception as e:
        print(f"; Error parsing Inpla output: {e}")
    else:
        equivalence(net_a, net_b)
        epsilon_equivalence(net_a, net_b, 1e-5)
|