1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
|
// Agents
// Built-in
// Eraser: deletes other agents recursively
// Dup: duplicates other agents recursively
// Implemented
// Linear(x, int q, int r): represents "q*x + r"
// Concrete(int k): represents the concrete value k
// Symbolic(id): represents the variable id
// Add(out, b): represents addition (evaluated in steps via AddCheckLinear/AddCheckConcrete)
// Mul(out, b): represents multiplication (evaluated in steps via MulCheckLinear/MulCheckConcrete)
// ReLU(out): represents "x > 0 ? x : 0"
// Materialize(out): transforms a Linear packet into a final representation built from TermAdd/TermMul/TermReLU
// TODO: add range information to enable ReLU elimination
// Rules
// Add, step 1: the first operand has reached Add. The second operand (rhs) is not
// inspected yet; it is handed to a helper agent that remembers the first operand.
// A Linear first operand is unpacked into AddCheckLinear.
Linear(term, int coef, int offs) >< Add(res, rhs) => rhs ~ AddCheckLinear(res, term, coef, offs);
// A Concrete first operand other than 0 is remembered in AddCheckConcrete;
// Concrete(0) is the additive identity, so the second operand becomes the result unchanged.
Concrete(int c) >< Add(res, rhs)
| c != 0 => rhs ~ AddCheckConcrete(res, c)
| _ => res ~ rhs;
// Add, step 2, both operands are Linear packets:
//   first operand  = q*x + r   (remembered in AddCheckLinear)
//   second operand = s*y + t   (the active Linear agent)
// The result on `out` is always a Linear or Concrete packet, because every consumer of an
// Add result (Add, Mul, ReLU, Materialize) only has rules for those two agents.
//   - both operands are the zero packet  -> Concrete(0)
//   - second operand is the zero packet  -> the first operand is forwarded as a Linear packet
//   - first operand is the zero packet   -> the second operand is forwarded as a Linear packet
//   - otherwise both operands are materialized and wrapped as 1*TermAdd(..) + 0
// The zero-operand branches used to route the survivor through Materialize, which put a raw
// term (Symbolic/TermAdd/TermMul) on `out` and left the next operation without a matching rule.
Linear(y, int s, int t) >< AddCheckLinear(out, x, int q, int r)
| (q == 0) && (r == 0) && (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
| (s == 0) && (t == 0) => out ~ Linear(x, q, r), y ~ Eraser
| (q == 0) && (r == 0) => out ~ (*L)Linear(y, s, t), x ~ Eraser
| _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermAdd(out_x, out_y), 1, 0);
// Add, step 2, at least one operand is a constant.
// Linear + constant (either order): the constant is folded into the packet's offset.
Concrete(int c) >< AddCheckLinear(res, term, int coef, int offs) => res ~ Linear(term, coef, offs + c);
Linear(term, int coef, int offs) >< AddCheckConcrete(res, int c) => res ~ Linear(term, coef, offs + c);
// constant + constant: folded into a single Concrete (adding 0 leaves the held value unchanged).
Concrete(int c) >< AddCheckConcrete(res, int held) => res ~ Concrete(held + c);
// Mul, step 1: the first operand has reached Mul. The second operand (rhs) is handed
// to a helper agent that remembers the first operand.
// A Linear first operand is unpacked into MulCheckLinear.
Linear(term, int coef, int offs) >< Mul(res, rhs) => rhs ~ MulCheckLinear(res, term, coef, offs);
// A Concrete first operand:
//   0 annihilates: the second operand is erased and the result is Concrete(0)
//   1 is the identity: the second operand becomes the result unchanged
//   any other constant is remembered in MulCheckConcrete
Concrete(int c) >< Mul(res, rhs)
| c == 0 => rhs ~ Eraser, res ~ (*L)Concrete(0)
| c == 1 => res ~ rhs
| _ => rhs ~ MulCheckConcrete(res, c);
// Mul, step 2, both operands are Linear packets:
//   first operand  = q*x + r   (remembered in MulCheckLinear)
//   second operand = s*y + t   (the active Linear agent)
// The result on `out` is always a Linear or Concrete packet, because every consumer of a
// Mul result (Add, Mul, ReLU, Materialize) only has rules for those two agents.
//   - either operand is the zero packet -> the product is Concrete(0); both terms are erased
//   - otherwise both operands are materialized and wrapped as 1*TermMul(..) + 0
// The zero-operand branches used to return the *other* operand (copied from the Add rule),
// i.e. they computed 0 * e = e instead of 0 * e = 0.
Linear(y, int s, int t) >< MulCheckLinear(out, x, int q, int r)
| (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
| (q == 0) && (r == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
| _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermMul(out_x, out_y), 1, 0);
// Mul, step 2, at least one operand is a constant.
// Linear * constant (either order): both the coefficient and the offset are scaled.
Concrete(int c) >< MulCheckLinear(res, term, int coef, int offs) => res ~ Linear(term, coef * c, offs * c);
Linear(term, int coef, int offs) >< MulCheckConcrete(res, int c) => res ~ Linear(term, coef * c, offs * c);
// constant * constant: folded into a single Concrete (a factor of 0 gives 0, a factor of 1 gives the held value).
Concrete(int c) >< MulCheckConcrete(res, int held) => res ~ Concrete(held * c);
// ReLU of a Linear packet cannot be decided symbolically: the packet is materialized
// and the result is the opaque term TermReLU(..), re-wrapped as 1*TermReLU(..) + 0.
Linear(term, int coef, int offs) >< ReLU(res) => (*L)Linear(term, coef, offs) ~ Materialize(arg), res ~ Linear(TermReLU(arg), 1, 0);
// ReLU of a constant is evaluated directly: non-positive values clamp to 0,
// positive values pass through.
Concrete(int c) >< ReLU(res)
| c <= 0 => res ~ Concrete(0)
| _ => res ~ (*L)Concrete(c);
// Materialize turns the packet coef*term + offs into the smallest explicit term.
// Guards are tried top to bottom, so each one may rely on the earlier ones having failed:
//   coef == 0              -> Concrete(offs); the term is unused and erased
//   coef == 1, offs == 0   -> term
//   coef == 1 (offs != 0)  -> TermAdd(term, Concrete(offs))
//   offs == 0 (coef != 0)  -> TermMul(Concrete(coef), term)
//   otherwise              -> TermAdd(TermMul(Concrete(coef), term), Concrete(offs))
Linear(term, int coef, int offs) >< Materialize(res)
| (coef == 0) => res ~ Concrete(offs), term ~ Eraser
| (coef == 1) && (offs == 0) => res ~ term
| (coef == 1) => res ~ TermAdd(term, Concrete(offs))
| (offs == 0) => res ~ TermMul(Concrete(coef), term)
| _ => res ~ TermAdd(TermMul(Concrete(coef), term), Concrete(offs));
// Network A
// Builds one symbolic expression over two inputs, Symbolic(0) and Symbolic(1):
// four scaled-and-summed combinations of the inputs, a ReLU on each, and a final
// scaled sum of the four ReLU results, materialized into result0.
// (Presumably a 2-input / 4-unit / 1-output network with integer weights - confirm with the generator.)
// Input 0: the packet 1*Symbolic(0) + 0 is fanned out into four copies v0..v3.
Dup(v0, Dup(v1, Dup(v2, v3))) ~ Linear(Symbolic(0), 1, 0);
// Each copy is scaled by a constant (Mul) and added to a starting constant (Add),
// producing the partial sums v5, v7, v9 and v11.
Mul(v4, Concrete(865)) ~ v0;
Add(v5, v4) ~ Concrete(0);
Mul(v6, Concrete(1029)) ~ v1;
Add(v7, v6) ~ Concrete(0);
Mul(v8, Concrete(1087)) ~ v2;
Add(v9, v8) ~ Concrete(1086);
Mul(v10, Concrete(676)) ~ v3;
Add(v11, v10) ~ Concrete(-693);
// Input 1: the packet 1*Symbolic(1) + 0 is fanned out into four copies v12..v15.
Dup(v12, Dup(v13, Dup(v14, v15))) ~ Linear(Symbolic(1), 1, 0);
// Each scaled copy is added onto the matching partial sum from input 0,
// producing v17, v19, v21 and v23.
Mul(v16, Concrete(-865)) ~ v12;
Add(v17, v16) ~ v5;
Mul(v18, Concrete(-1029)) ~ v13;
Add(v19, v18) ~ v7;
Mul(v20, Concrete(-1087)) ~ v14;
Add(v21, v20) ~ v9;
Mul(v22, Concrete(-378)) ~ v15;
Add(v23, v22) ~ v11;
// ReLU is applied to each of the four sums.
ReLU(v24) ~ v17;
ReLU(v25) ~ v19;
ReLU(v26) ~ v21;
ReLU(v27) ~ v23;
// Output: starting from Concrete(1000), each ReLU result is scaled by a constant
// and accumulated, ending in v35.
Mul(v28, Concrete(1153)) ~ v24;
Add(v29, v28) ~ Concrete(1000);
Mul(v30, Concrete(974)) ~ v25;
Add(v31, v30) ~ v29;
Mul(v32, Concrete(-920)) ~ v26;
Add(v33, v32) ~ v31;
Mul(v34, Concrete(367)) ~ v27;
Add(v35, v34) ~ v33;
// Convert the final packet into an explicit Term* expression bound to result0.
Materialize(result0) ~ v35;
// Display the net connected to result0.
result0;
// Inpla command; presumably releases the nets held by the interface names so that
// v0..v35 and result0 can be reused by the next network - confirm against the Inpla manual.
free ifce;
// Network B
// Same wiring shape as Network A with different constants: two inputs, Symbolic(0) and
// Symbolic(1), four scaled-and-summed combinations, a ReLU on each, and a final scaled
// sum materialized into result0. The names v0..v35 and result0 are reused here.
// Input 0: the packet 1*Symbolic(0) + 0 is fanned out into four copies v0..v3.
Dup(v0, Dup(v1, Dup(v2, v3))) ~ Linear(Symbolic(0), 1, 0);
// Each copy is scaled by a constant (Mul) and added to a starting constant (Add),
// producing the partial sums v5, v7, v9 and v11.
Mul(v4, Concrete(-238)) ~ v0;
Add(v5, v4) ~ Concrete(-704);
Mul(v6, Concrete(-111)) ~ v1;
Add(v7, v6) ~ Concrete(-515);
Mul(v8, Concrete(-1232)) ~ v2;
Add(v9, v8) ~ Concrete(-8);
Mul(v10, Concrete(1113)) ~ v3;
Add(v11, v10) ~ Concrete(189);
// Input 1: the packet 1*Symbolic(1) + 0 is fanned out into four copies v12..v15.
Dup(v12, Dup(v13, Dup(v14, v15))) ~ Linear(Symbolic(1), 1, 0);
// Each scaled copy is added onto the matching partial sum from input 0,
// producing v17, v19, v21 and v23.
Mul(v16, Concrete(639)) ~ v12;
Add(v17, v16) ~ v5;
Mul(v18, Concrete(66)) ~ v13;
Add(v19, v18) ~ v7;
Mul(v20, Concrete(1226)) ~ v14;
Add(v21, v20) ~ v9;
Mul(v22, Concrete(-1113)) ~ v15;
Add(v23, v22) ~ v11;
// ReLU is applied to each of the four sums.
ReLU(v24) ~ v17;
ReLU(v25) ~ v19;
ReLU(v26) ~ v21;
ReLU(v27) ~ v23;
// Output: starting from Concrete(-170), each ReLU result is scaled by a constant
// and accumulated, ending in v35.
Mul(v28, Concrete(111)) ~ v24;
Add(v29, v28) ~ Concrete(-170);
Mul(v30, Concrete(239)) ~ v25;
Add(v31, v30) ~ v29;
Mul(v32, Concrete(961)) ~ v26;
Add(v33, v32) ~ v31;
Mul(v34, Concrete(897)) ~ v27;
Add(v35, v34) ~ v33;
// Convert the final packet into an explicit Term* expression bound to result0.
Materialize(result0) ~ v35;
// Display the net connected to result0.
result0;
|