aboutsummaryrefslogtreecommitdiff
path: root/nn.in
blob: fb9a863d228aafe9f351ab301fbd9c48d01527d6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
// Rules
//
// Add(out, b): the agent met on the principal port is the first operand,
// `b` is the second operand and `out` receives the result.  This group only
// inspects the first operand; the second one is examined by AddCheck /
// AddCheckConcrete.

// Symbolic first operand: hand it to AddCheck, which waits for the second operand.
Symbolic(id) >< Add(out, b) => b ~ AddCheck(out, (*L)Symbolic(id));
// Literal first operand: 0 + b is just b; any other literal is carried along
// as the int attribute of AddCheckConcrete so it can be folded with a literal b.
Concrete(int k) >< Add(out, b)
	| k == 0 => out ~ b
	| _ => b ~ AddCheckConcrete(out, k);
// Any other first operand (e.g. an already built Add/Mul/ReLU term) is treated
// exactly like a symbolic one.  Going through AddCheck, instead of building
// Add(other, b) directly, lets `t + 0` collapse to `t` for compound terms too.
other >< Add(out, b) => b ~ AddCheck(out, (*L)other);

// AddCheck(res, lhs): `lhs` is the first operand of an addition (already known
// not to be a literal); the agent met on the principal port is the second
// operand.  A literal 0 on the right is dropped, everything else yields the
// term Add(lhs, <second operand>) on `res`.
Symbolic(sym) >< AddCheck(res, lhs) => res ~ Add(lhs, (*L)Symbolic(sym));
Concrete(int n) >< AddCheck(res, lhs)
	| n == 0 => res ~ lhs
	| _ => res ~ Add(lhs, (*L)Concrete(n));
other >< AddCheck(res, lhs) => res ~ Add(lhs, (*L)other);

// AddCheckConcrete(res, n): the first operand was the non-zero literal `n`;
// the agent met on the principal port is the second operand.  Two literals
// are folded into a single Concrete, anything else yields
// Add(Concrete(n), <second operand>) on `res`.
Symbolic(sym) >< AddCheckConcrete(res, int n) => res ~ Add(Concrete(n), (*L)Symbolic(sym));
Concrete(int m) >< AddCheckConcrete(res, int n)
	| m == 0 => res ~ Concrete(n)
	| _ => res ~ Concrete(n + m);
other >< AddCheckConcrete(res, int n) => res ~ Add(Concrete(n), (*L)other);

// Mul(out, b): the agent met on the principal port is the first operand,
// `b` is the second operand and `out` receives the result.  This group only
// inspects the first operand; the second one is examined by MulCheck /
// MulCheckConcrete.

// Symbolic first operand: hand it to MulCheck, which waits for the second operand.
Symbolic(id) >< Mul(out, b) => b ~ MulCheck(out, (*L)Symbolic(id));
// Literal first operand: 0 * b is 0 (b is discarded via Eraser), 1 * b is b;
// any other literal is carried along as the int attribute of MulCheckConcrete.
Concrete(int k) >< Mul(out, b)
	| k == 0 => b ~ Eraser, out ~ (*L)Concrete(0)
	| k == 1 => out ~ b
	| _ => b ~ MulCheckConcrete(out, k);
// Any other first operand (e.g. an already built Add/Mul/ReLU term) is treated
// exactly like a symbolic one.  Going through MulCheck, instead of building
// Mul(other, b) directly, lets `t * 0` and `t * 1` simplify for compound
// terms too.
other >< Mul(out, b) => b ~ MulCheck(out, (*L)other);

// MulCheck(res, lhs): `lhs` is the first operand of a multiplication (already
// known not to be a literal); the agent met on the principal port is the
// second operand.  A literal 0 wins outright (lhs is discarded via Eraser),
// a literal 1 is dropped, everything else yields Mul(lhs, <second operand>).
Symbolic(sym) >< MulCheck(res, lhs) => res ~ Mul(lhs, (*L)Symbolic(sym));
Concrete(int n) >< MulCheck(res, lhs)
	| n == 0 => lhs ~ Eraser, res ~ (*L)Concrete(0)
	| n == 1 => res ~ lhs
	| _ => res ~ Mul(lhs, (*L)Concrete(n));
other >< MulCheck(res, lhs) => res ~ Mul(lhs, (*L)other);

// MulCheckConcrete(res, n): the first operand was the literal `n` (neither 0
// nor 1 when created by the Mul rules); the agent met on the principal port is
// the second operand.  Two literals are folded into a single Concrete,
// anything else yields Mul(Concrete(n), <second operand>) on `res`.
Symbolic(sym) >< MulCheckConcrete(res, int n) => res ~ Mul(Concrete(n), (*L)Symbolic(sym));
Concrete(int m) >< MulCheckConcrete(res, int n)
	| m == 0 => res ~ Concrete(0)
	| m == 1 => res ~ Concrete(n)
	| _ => res ~ Concrete(n * m);
other >< MulCheckConcrete(res, int n) => res ~ Mul(Concrete(n), (*L)other);

// ReLU(res): the agent met on the principal port is the argument.  A literal
// is clamped at zero (positive values pass through unchanged); a symbolic or
// compound argument is left as the unevaluated term ReLU(<argument>).
Symbolic(sym) >< ReLU(res) => res ~ ReLU((*L)Symbolic(sym));
Concrete(int n) >< ReLU(res)
	| n > 0 => res ~ (*L)Concrete(n)
	| _ => res ~ Concrete(0);
other >< ReLU(res) => res ~ ReLU((*L)other);

// Net
//
// Example nets exercising the rules above.  Each net connects an expression
// to the name `out` and then evaluates `out;` (presumably printing the
// reduced term -- confirm against the interpreter).

// X + (0 * Y): the product reduces to Concrete(0) and Y is erased, then
// X + 0 collapses, so `out` should end up as Symbolic(X).
Symbolic(X) ~ Add(out, b);
Concrete(0) ~ Mul(b, Symbolic(Y));
out;

// X * 0: MulCheck sees the literal 0, erases Symbolic(X) and `out` should
// end up as Concrete(0).  `free out;` releases the name used by the previous net.
free out;
Symbolic(X) ~ Mul(out, Concrete(0));
out;

// X * (Y + 1): nothing can be folded, so `out` should end up as the term
// Mul(Symbolic(X), Add(Symbolic(Y), Concrete(1))).
// NOTE(review): this relies on the `other >< MulCheck` rule matching the
// Add(...) result term -- confirm how the interpreter resolves `other`.
free out;
Symbolic(X) ~ Mul(out, b);
Symbolic(Y) ~ Add(b, Concrete(1));
out;

// 1 + (2 + (3 * 2)): all operands are literals, so the Concrete rules fold
// 3 * 2 = 6, 2 + 6 = 8, 1 + 8 = 9 and `out` should end up as Concrete(9).
free out;
Concrete(1) ~ Add(out, b);
Concrete(2) ~ Add(b, c);
Concrete(3) ~ Mul(c, Concrete(2));
out;

// ReLU((X * W1) + ((Y * W2) + B)): a single two-input neuron with every
// value symbolic, so no folding applies and `out` should end up as
// ReLU(Add(Mul(Symbolic(X), Symbolic(W1)),
//          Add(Mul(Symbolic(Y), Symbolic(W2)), Symbolic(B)))).
// NOTE(review): the Mul(...) result terms `a` and `b` meet Add operators here,
// so both `other >< Add` and `other >< Mul` could apply to the same pair --
// confirm which side the interpreter binds to `other`.
free out;
Symbolic(X) ~ Mul(a, Symbolic(W1));
Symbolic(Y) ~ Mul(b, Symbolic(W2));
b ~ Add(c, Symbolic(B));
a ~ Add(d, c);
d ~ ReLU(out);
out;

// X * (X + 1): the same input is used twice.  Dup is not defined in this
// file -- presumably the interpreter's built-in duplicator; confirm.
// With two copies of Symbolic(X), `out` should end up as
// Mul(Symbolic(X), Add(Symbolic(X), Concrete(1))).
free out;
Symbolic(X) ~ Dup(a, b);
a ~ Add(c, Concrete(1));
b ~ Mul(out, c);
out;