aboutsummaryrefslogtreecommitdiff
path: root/parser.py
blob: e66f9d2295d39486a1de91be8ea021f945c66562 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class Concrete:
    """Numeric constant leaf of a term tree.

    ``str()`` renders the value as an SMT-LIB term. SMT-LIB numerals carry no
    sign and no exponent, so negative numbers become ``(- n)`` and floats are
    written in plain decimal notation (``1e-05`` -> ``0.00001``). Values that
    are not a plain int/float are rendered with ``str()`` unchanged.
    """

    def __init__(self, val):
        self.val = val

    def __str__(self):
        import math
        from decimal import Decimal

        val = self.val
        # bool is an int subclass; keep rendering it (and any non-numeric
        # payload) with plain str().
        if isinstance(val, bool) or not isinstance(val, (int, float)):
            return str(val)
        if isinstance(val, float) and not math.isfinite(val):
            raise ValueError(f"cannot encode non-finite constant {val!r} in SMT-LIB")
        magnitude = abs(val)
        if isinstance(magnitude, float):
            # repr() is the shortest round-trip form but may use an exponent,
            # which SMT-LIB rejects; Decimal's 'f' format expands it exactly.
            text = format(Decimal(repr(float(magnitude))), "f")
        else:
            text = str(magnitude)
        return f"(- {text})" if val < 0 else text

class Symbolic:
    """Free real-valued variable, rendered in SMT-LIB as ``x_<id>``."""

    def __init__(self, id):
        # `id` shadows the builtin, but it is the public keyword name callers use.
        self.id = id

    def __str__(self):
        return "x_{}".format(self.id)

class TermAdd:
    """Binary sum node; renders as the SMT-LIB application ``(+ a b)``."""

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __str__(self):
        return "(+ {} {})".format(self.a, self.b)

class TermMul:
    """Binary product node; renders as the SMT-LIB application ``(* a b)``."""

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __str__(self):
        return "(* {} {})".format(self.a, self.b)

class TermReLU:
    """ReLU node, max(a, 0), rendered as an SMT-LIB ``ite`` term."""

    def __init__(self, a):
        self.a = a

    def __str__(self):
        # Bind the operand once with `let` instead of printing it twice, so a
        # stack of nested ReLUs grows linearly instead of doubling per level.
        # Nested lets reusing `relu_in` are safe: a binding's value is resolved
        # in the enclosing scope, and the body mentions no user symbols.
        return f"(let ((relu_in {self.a})) (ite (> relu_in 0) relu_in 0))"

def _parse_term(text, constructors):
    """Build a term tree from constructor-call text without executing it.

    Accepts only calls to names in *constructors* whose arguments (positional
    or keyword) are int/float literals, optionally signed, or further such
    calls. A bare numeric literal at the top level is wrapped in
    ``constructors["Concrete"]``.

    Raises:
        SyntaxError: *text* is not a single Python expression.
        ValueError: *text* uses any other syntax (names, attributes, ...).
        TypeError: a constructor is called with the wrong arguments.
    """
    import ast

    def build(node):
        # type() rather than isinstance() so that True/False are rejected.
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
            operand = build(node.operand)
            if not isinstance(operand, (int, float)):
                raise ValueError("a sign may only prefix a numeric literal")
            return -operand if isinstance(node.op, ast.USub) else operand
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id in constructors
        ):
            args = [build(arg) for arg in node.args]
            kwargs = {}
            for kw in node.keywords:
                if kw.arg is None:
                    raise ValueError("**kwargs unpacking is not allowed in a term")
                kwargs[kw.arg] = build(kw.value)
            return constructors[node.func.id](*args, **kwargs)
        raise ValueError(f"unsupported syntax in term: {type(node).__name__}")

    # compile() does not strip leading indentation the way eval() of a str does.
    term = build(ast.parse(text.strip(), mode="eval").body)
    if isinstance(term, (int, float)):
        return constructors["Concrete"](term)
    return term


def _symbolic_ids(*trees):
    """Return the sorted ids of every Symbolic leaf reachable from *trees*.

    Walks the ``a``/``b`` child attributes used by TermAdd, TermMul and
    TermReLU, iteratively so deep trees do not hit the recursion limit.
    """
    found = set()
    pending = list(trees)
    while pending:
        node = pending.pop()
        if isinstance(node, Symbolic):
            found.add(node.id)
        for attr in ("a", "b"):
            child = getattr(node, attr, None)
            if child is not None:
                pending.append(child)
    return sorted(found)


def generate_z3(net_a_str, net_b_str, epsilon=1e-5):
    """Print an SMT-LIB query asking whether two networks can differ by > epsilon.

    Each of *net_a_str* / *net_b_str* is a constructor-call expression over
    Concrete, Symbolic, TermAdd, TermMul and TermReLU (Inpla output). The query
    declares every ``x_<id>`` variable the two trees reference, defines
    ``net_a`` and ``net_b``, and asserts ``|net_a - net_b| > epsilon``. An
    ``unsat`` answer therefore means ``|net_a - net_b| <= epsilon`` for every
    assignment of the variables.

    Args:
        net_a_str: Textual term tree for the first network.
        net_b_str: Textual term tree for the second network.
        epsilon: Finite, non-negative tolerance.

    Returns:
        None. The query is printed to stdout. If either string cannot be
        parsed, a single ``; Error parsing Inpla output: ...`` SMT comment is
        printed instead, so the output stays valid SMT-LIB.

    Raises:
        ValueError: If *epsilon* is negative or not finite.
    """
    import math
    from decimal import Decimal

    if not (math.isfinite(epsilon) and epsilon >= 0):
        raise ValueError(f"epsilon must be a finite, non-negative number, got {epsilon!r}")

    context = {
        'Concrete': Concrete,
        'Symbolic': Symbolic,
        'TermAdd': TermAdd,
        'TermMul': TermMul,
        'TermReLU': TermReLU
    }

    # SECURITY: these strings come from an external tool. They are parsed
    # structurally instead of being passed to eval(), which would execute any
    # Python code embedded in them.
    try:
        tree_a = _parse_term(net_a_str, context)
        tree_b = _parse_term(net_b_str, context)
    except (SyntaxError, ValueError, TypeError, RecursionError) as e:
        print(f"; Error parsing Inpla output: {e}")
        return

    # Declare exactly the variables in use; a fixed list leaves any other
    # x_<id> undeclared and the solver rejects the query.
    for var_id in _symbolic_ids(tree_a, tree_b):
        print(f"(declare-const x_{var_id} Real)")

    print(f"(define-fun net_a () Real {tree_a})")
    print(f"(define-fun net_b () Real {tree_b})")

    # Exact plain-decimal epsilon: a fixed ".10f" rounds small tolerances
    # (e.g. 1e-12) down to 0, and SMT-LIB has no exponent notation.
    eps_text = format(Decimal(str(epsilon)), "f")
    print(f"(assert (> (abs (- net_a net_b)) {eps_text}))")

    print("(check-sat)")
    print("(get-model)")

if __name__ == "__main__":
    # Demo: 2*x_0 + 3 versus 3 + 2*x_0. The two are algebraically equal, so a
    # solver should answer unsat to the emitted |a - b| > epsilon query.
    demo_net_a = "TermAdd(TermMul(Symbolic(0), Concrete(2)), Concrete(3))"  # 2x + 3
    demo_net_b = "TermAdd(Concrete(3), TermMul(Concrete(2), Symbolic(0)))"  # 3 + 2x
    generate_z3(demo_net_a, demo_net_b)