1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
|
// Rules
// The left Add operand is a linear form: park its symbol wire, coefficient
// and offset in AddCheckLinear, which then meets the right operand on `rhs`.
Linear(sym, int coef, int off) >< Add(res, rhs)
  => rhs ~ AddCheckLinear(res, sym, coef, off);
// The left Add operand is the constant n. A non-zero n waits for the right
// operand in AddCheckConcrete; adding 0 simply forwards the right operand.
Concrete(int n) >< Add(res, rhs)
| n != 0 => rhs ~ AddCheckConcrete(res, n)
| _ => res ~ rhs;
// Adds two linear forms: the saved left operand q*x + r (held by
// AddCheckLinear) and the arriving right operand s*y + t.
//  - both identically zero -> Concrete(0), both symbol wires erased;
//  - one side identically zero -> the other side, materialized;
//  - otherwise both sides are materialized separately and joined by TermAdd.
// The right operand is rebuilt with its own offset t; it must not take the
// left operand's offset r.
Linear(y, int s, int t) >< AddCheckLinear(out, x, int q, int r)
| q == 0 && r == 0 && s == 0 && t == 0 => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
| s == 0 && t == 0 => Linear(x, q, r) ~ Materialize(out), y ~ Eraser
| q == 0 && r == 0 => (*L)Linear(y, s, t) ~ Materialize(out), x ~ Eraser
| _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ TermAdd(out_x, out_y);
// The right Add operand turned out to be the constant n: fold it into the
// saved linear form's offset.
Concrete(int n) >< AddCheckLinear(res, sym, int coef, int off)
  => res ~ Linear(sym, coef, off + n);
// The saved left constant n meets a linear right operand: shift its offset by n.
Linear(sym, int coef, int off) >< AddCheckConcrete(res, int n)
  => res ~ Linear(sym, coef, off + n);
// Constant + constant: the saved left value n plus the arriving right value m.
Concrete(int m) >< AddCheckConcrete(res, int n)
| m != 0 => res ~ Concrete(n + m)
| _ => res ~ Concrete(n);
// The left Mul operand is a linear form: park it in MulCheck until the right
// operand arrives on `rhs`.
Linear(sym, int coef, int off) >< Mul(res, rhs)
  => rhs ~ MulCheck(res, sym, coef, off);
// The left Mul operand is the constant n.
//  - 1 is the identity: forward the right operand unchanged.
//  - 0 annihilates: erase the right operand and answer Concrete(0),
//    reusing the left node via (*L).
//  - any other n waits for the right operand in MulCheckConcrete.
Concrete(int n) >< Mul(res, rhs)
| n == 1 => res ~ rhs
| n == 0 => rhs ~ Eraser, res ~ (*L)Concrete(0)
| _ => rhs ~ MulCheckConcrete(res, n);
// Multiplies two linear forms: the saved left operand q*x + r (held by
// MulCheck) and the arriving right operand s*y + t.
//  - if either side is identically zero the product is Concrete(0) and both
//    symbol wires are erased (matching Concrete(0) >< Mul, which annihilates);
//  - otherwise both sides are materialized separately and joined by TermMul.
// The right operand is rebuilt with its own offset t; it must not take the
// left operand's offset r.
Linear(y, int s, int t) >< MulCheck(out, x, int q, int r)
| s == 0 && t == 0 => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
| q == 0 && r == 0 => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
| _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ TermMul(out_x, out_y);
// The right Mul operand is the constant n: scale both the coefficient and
// the offset of the saved linear form by n.
Concrete(int n) >< MulCheck(res, sym, int coef, int off)
  => res ~ Linear(sym, coef * n, off * n);
// The saved left constant n meets a linear right operand: scale its
// coefficient and offset by n.
Linear(sym, int coef, int off) >< MulCheckConcrete(res, int n)
  => res ~ Linear(sym, coef * n, off * n);
// Constant * constant: the saved left value n times the arriving right value m,
// with 1 and 0 short-circuited.
Concrete(int m) >< MulCheckConcrete(res, int n)
| m == 1 => res ~ Concrete(n)
| m == 0 => res ~ Concrete(0)
| _ => res ~ Concrete(n * m);
// ReLU of a linear form stays symbolic: materialize the form (reusing the
// Linear node via (*L)) and wrap the resulting term in TermReLU.
Linear(sym, int coef, int off) >< ReLU(res)
  => (*L)Linear(sym, coef, off) ~ Materialize(term), res ~ TermReLU(term);
// ReLU of a constant: non-positive values clamp to Concrete(0); positive
// values pass through, reusing the original node via (*L).
Concrete(int n) >< ReLU(res)
| n <= 0 => res ~ Concrete(0)
| _ => res ~ (*L)Concrete(n);
// Converts the linear form q*x + r into an explicit term, picking the smallest
// shape for the given coefficient q and offset r:
//   q == 0          -> Concrete(r)       (the symbol wire x is erased)
//   q == 1, r == 0  -> Symbolic(x)
//   q == 1          -> x + r
//   r == 0          -> q * x
//   otherwise       -> q * x + r
// The symbol is always wrapped in Symbolic(...) so every branch produces the
// same kind of leaf.
Linear(x, int q, int r) >< Materialize(out)
| q == 0 => out ~ Concrete(r), x ~ Eraser
| q == 1 && r == 0 => out ~ Symbolic(x)
| q == 1 => out ~ TermAdd(Symbolic(x), Concrete(r))
| r == 0 => out ~ TermMul(Concrete(q), Symbolic(x))
| _ => out ~ TermAdd(TermMul(Concrete(q), Symbolic(x)), Concrete(r));
// Net
// Each group builds a small expression, prints the wire holding its result
// (the trailing comment is the expected output) and frees its names.

// x + 0 * y  ==>  x
Linear(X, 1, 0) ~ Add(a, b);
Concrete(0) ~ Mul(b, Linear(Y, 1, 0));
a ~ Materialize(out);
out; // Symbolic(X)
free out a b;

// x * 0  ==>  0
Linear(X, 1, 0) ~ Mul(a, Concrete(0));
a ~ Materialize(out);
out; // Concrete(0)
free out a;

// x * (y + 1): multiplying two symbolic linear forms already yields a term on
// `a`, so print `a` directly (`out` is not connected in this group).
Linear(X, 1, 0) ~ Mul(a, b);
Linear(Y, 1, 0) ~ Add(b, Concrete(1));
a; // TermMul(Symbolic(X),TermAdd(Symbolic(Y),Concrete(1)))
free a b;

// 1 + (2 + 3 * 2)  ==>  9
Concrete(1) ~ Add(out, b);
Concrete(2) ~ Add(b, c);
Concrete(3) ~ Mul(c, Concrete(2));
out; // Concrete(9)
free out b c;

// relu(x*w1 + (y*w2 + b))
// NOTE(review): Mul of two symbolic Linear forms produces a TermMul, and no
// TermMul >< Add (or other Term* >< Add/ReLU) rule is visible in this file,
// so this net presumably stops before reaching the term below. TODO: add
// term-level Add/ReLU rules or confirm they are defined elsewhere.
Linear(X, 1, 0) ~ Mul(a, Linear(W1, 1, 0));
Linear(Y, 1, 0) ~ Mul(b, Linear(W2, 1, 0));
b ~ Add(c, Linear(B, 1, 0));
a ~ Add(d, c);
d ~ ReLU(out);
out; // ReLU(Add(Mul(Symbolic(X),Symbolic(W1)),Add(Mul(Symbolic(Y),Symbolic(W2)),Symbolic(B))))
free out a b c d;

// 2*x + 10 exercises the general Materialize branch.
Linear(X, 2, 10) ~ Materialize(out);
out; // TermAdd(TermMul(Concrete(2),Symbolic(X)),Concrete(10))
|