aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorericmarin <maarin.eric@gmail.com>2026-03-10 16:29:33 +0100
committerericmarin <maarin.eric@gmail.com>2026-03-10 17:34:09 +0100
commitaf2b13214579d78827392e762149fa8824526aa9 (patch)
tree81dd6bb8e442668cecff72ec20bcc77e2e9f5c2d
parent7933d744e06337f1d69b7da83f2cee1611556097 (diff)
downloadvein-af2b13214579d78827392e762149fa8824526aa9.tar.gz
vein-af2b13214579d78827392e762149fa8824526aa9.zip
added MLP
-rw-r--r--rules.in71
-rw-r--r--xor.in (renamed from nn.in)69
-rw-r--r--xor.py161
3 files changed, 266 insertions, 35 deletions
diff --git a/rules.in b/rules.in
new file mode 100644
index 0000000..2ff92dd
--- /dev/null
+++ b/rules.in
@@ -0,0 +1,71 @@
// Agents
// Built-in
// Eraser: deletes the other agent (and its whole subtree) recursively
// Dup: duplicates the other agent (and its whole subtree) recursively

// Implemented
// Linear(x, int q, int r): represents the affine form "q*x + r"
// Concrete(int k): represents a concrete integer value k
// Symbolic(id): represents the free variable id
// Add(out, b): represents addition (staged via AddCheckLinear/AddCheckConcrete
//              once the first operand's shape is known)
// Mul(out, b): represents multiplication (staged via MulCheckLinear/MulCheckConcrete)
// ReLU(out): represents "x > 0 ? x : 0"
// Materialize(out): lowers a Linear packet into a final TermAdd/TermMul/TermReLU tree

// TODO: add range information to enable ReLU elimination

// Rules

// Addition: the first operand's shape decides the next step.
Linear(x, int q, int r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r);

// Adding a concrete 0 is the identity; otherwise remember the constant.
Concrete(int k) >< Add(out, b)
    | k == 0 => out ~ b
    | _ => b ~ AddCheckConcrete(out, k);

// Linear + Linear: erase operands that are identically zero (q == 0 && r == 0),
// otherwise materialize both sides into a TermAdd.
Linear(y, int s, int t) >< AddCheckLinear(out, x, int q, int r)
    | (q == 0) && (r == 0) && (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
    | (s == 0) && (t == 0) => Linear(x, q, r) ~ Materialize(out), y ~ Eraser
    | (q == 0) && (r == 0) => (*L)Linear(y, s, t) ~ Materialize(out), x ~ Eraser
    | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermAdd(out_x, out_y), 1, 0);

// Linear + constant: fold the constant into the offset r.
Concrete(int j) >< AddCheckLinear(out, x, int q, int r) => out ~ Linear(x, q, r + j);

// Constant + Linear: fold the constant into the offset t.
Linear(y, int s, int t) >< AddCheckConcrete(out, int k) => out ~ Linear(y, s, t + k);

// Constant + constant: fold at rewrite time.
Concrete(int j) >< AddCheckConcrete(out, int k)
    | j == 0 => out ~ Concrete(k)
    | _ => out ~ Concrete(k + j);

// Multiplication: mirrors the Add staging above.
Linear(x, int q, int r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r);

// Multiplying by 0 erases the other operand; by 1 is the identity.
Concrete(int k) >< Mul(out, b)
    | k == 0 => b ~ Eraser, out ~ (*L)Concrete(0)
    | k == 1 => out ~ b
    | _ => b ~ MulCheckConcrete(out, k);

// Linear * Linear: zero operands collapse; otherwise build a TermMul.
Linear(y, int s, int t) >< MulCheckLinear(out, x, int q, int r)
    | (q == 0) && (r == 0) && (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
    | (s == 0) && (t == 0) => Linear(x, q, r) ~ Materialize(out), y ~ Eraser
    | (q == 0) && (r == 0) => (*L)Linear(y, s, t) ~ Materialize(out), x ~ Eraser
    | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermMul(out_x, out_y), 1, 0);

// Linear * constant: scale both coefficient and offset.
Concrete(int j) >< MulCheckLinear(out, x, int q, int r) => out ~ Linear(x, q * j, r * j);

Linear(y, int s, int t) >< MulCheckConcrete(out, int k) => out ~ Linear(y, s * k, t * k);

// Constant * constant: fold at rewrite time.
Concrete(int j) >< MulCheckConcrete(out, int k)
    | j == 0 => out ~ Concrete(0)
    | j == 1 => out ~ Concrete(k)
    | _ => out ~ Concrete(k * j);

// ReLU of a symbolic linear form cannot be folded: materialize and wrap.
Linear(x, int q, int r) >< ReLU(out) => (*L)Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0);

// ReLU of a constant is folded immediately.
Concrete(int k) >< ReLU(out)
    | k > 0 => out ~ (*L)Concrete(k)
    | _ => out ~ Concrete(0);

// Lowering: emit the smallest term tree that denotes q*x + r.
Linear(x, int q, int r) >< Materialize(out)
    | (q == 0) => out ~ Concrete(r), x ~ Eraser
    | (q == 1) && (r == 0) => out ~ x
    | (q == 1) && (r != 0) => out ~ TermAdd(x, Concrete(r))
    | (q != 0) && (r == 0) => out ~ TermMul(Concrete(q), x)
    | _ => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r));
diff --git a/nn.in b/xor.in
index 3cc9e09..b7e3b6d 100644
--- a/nn.in
+++ b/xor.in
@@ -64,44 +64,43 @@ Concrete(int k) >< ReLU(out)
| _ => out ~ Concrete(0);
Linear(x, int q, int r) >< Materialize(out)
- | (q == 0) && (r == 0) => out ~ Concrete(r), x ~ Eraser
+ | (q == 0) => out ~ Concrete(r), x ~ Eraser
| (q == 1) && (r == 0) => out ~ x
| (q == 1) && (r != 0) => out ~ TermAdd(x, Concrete(r))
| (q != 0) && (r == 0) => out ~ TermMul(Concrete(q), x)
| _ => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r));
-// Net testing
-Linear(Symbolic(X), 1, 0) ~ Add(a, b);
-Concrete(0) ~ Mul(b, Linear(Symbolic(Y), 1, 0));
-a ~ Materialize(out);
-out; // Symbolic(X)
-free out a b;
-
-Linear(Symbolic(X), 1, 0) ~ Mul(a, Concrete(0));
-a ~ Materialize(out);
-out; // Concrete(0)
-free out a;
-
-Linear(Symbolic(X), 1, 0) ~ Mul(a, b);
-Linear(Symbolic(Y), 1, 0) ~ Add(b, Concrete(1));
-a ~ Materialize(out);
-out; // Mul(Symbolic(X),Add(Symbolic(Y),Concrete(1)))
-free out a b;
-
-Concrete(1) ~ Add(out, b);
-Concrete(2) ~ Add(b, c);
-Concrete(3) ~ Mul(c, Concrete(2));
-out; // Concrete(9)
-free out a b c;
-
-Linear(Symbolic(X), 1, 0) ~ Mul(a, Linear(Symbolic(W1), 1, 0));
-Linear(Symbolic(Y), 1, 0) ~ Mul(b, Linear(Symbolic(W2), 1, 0));
-b ~ Add(c, Linear(Symbolic(B), 1, 0));
-a ~ Add(d, c);
-d ~ ReLU(Materialize(out));
-out; // ReLU(Add(Mul(Symbolic(X),Symbolic(W1)),Add(Mul(Symbolic(Y),Symbolic(W2)),Symbolic(B))))
-free out a b c d;
-
-Linear(Symbolic(X), 2, 10) ~ Materialize(out);
-out;
+// Wiring
+Dup(v0, Dup(v1, Dup(v2, v3))) ~ Linear(Symbolic(0), 1, 0);
+Mul(v4, Concrete(-693)) ~ v0;
+Add(v5, v4) ~ Concrete(-692);
+Mul(v6, Concrete(-78)) ~ v1;
+Add(v7, v6) ~ Concrete(916);
+Mul(v8, Concrete(235)) ~ v2;
+Add(v9, v8) ~ Concrete(-424);
+Mul(v10, Concrete(181)) ~ v3;
+Add(v11, v10) ~ Concrete(202);
+Dup(v12, Dup(v13, Dup(v14, v15))) ~ Linear(Symbolic(1), 1, 0);
+Mul(v16, Concrete(-674)) ~ v12;
+Add(v17, v16) ~ v5;
+Mul(v18, Concrete(-97)) ~ v13;
+Add(v19, v18) ~ v7;
+Mul(v20, Concrete(-572)) ~ v14;
+Add(v21, v20) ~ v9;
+Mul(v22, Concrete(224)) ~ v15;
+Add(v23, v22) ~ v11;
+ReLU(v24) ~ v17;
+ReLU(v25) ~ v19;
+ReLU(v26) ~ v21;
+ReLU(v27) ~ v23;
+Mul(v28, Concrete(-318)) ~ v24;
+Add(v29, v28) ~ Concrete(-89);
+Mul(v30, Concrete(587)) ~ v25;
+Add(v31, v30) ~ v29;
+Mul(v32, Concrete(-250)) ~ v26;
+Add(v33, v32) ~ v31;
+Mul(v34, Concrete(254)) ~ v27;
+Add(v35, v34) ~ v33;
+Materialize(result0) ~ v35;
+result0; \ No newline at end of file
diff --git a/xor.py b/xor.py
new file mode 100644
index 0000000..b2082c7
--- /dev/null
+++ b/xor.py
@@ -0,0 +1,161 @@
+import torch
+import torch.nn as nn
+import torch.fx as fx
+import numpy as np
+import os
+from typing import List, Dict
+
class XOR_MLP(nn.Module):
    """Tiny multilayer perceptron (2 -> 4 -> 1, ReLU hidden layer) for XOR."""

    def __init__(self):
        super().__init__()
        hidden_width = 4
        # Keep the stack under the attribute name `layers` so fx-traced
        # module paths (layers.0, layers.1, ...) stay stable.
        self.layers = nn.Sequential(
            nn.Linear(2, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, 1),
        )

    def forward(self, x):
        # Single sequential pass; no extra reshaping needed.
        return self.layers(x)
+
class NameGen:
    """Hands out fresh wire names: v0, v1, v2, ..."""

    def __init__(self):
        # Next unused suffix; exposed so callers can inspect progress.
        self.counter = 0

    def next(self) -> str:
        fresh = f"v{self.counter}"
        self.counter += 1
        return fresh
+
def get_rules() -> str:
    """Load the Inpla rule set from rules.in (next to this script).

    Returns everything up to, but excluding, the line containing the
    "// Net testing" marker; if rules.in is missing, returns a comment
    placeholder instead of raising.
    """
    rules_path = os.path.join(os.path.dirname(__file__), "rules.in")
    if not os.path.exists(rules_path):
        return "// Rules not found in rules.in\n"

    with open(rules_path, "r") as handle:
        content = handle.read()

    kept = []
    for raw_line in content.splitlines(keepends=True):
        if "// Net testing" in raw_line:
            break
        kept.append(raw_line)
    return "".join(kept)
+
def export_to_inpla(model: nn.Module, input_shape: tuple, scale: int = 1000) -> str:
    """Compile a (fx-traceable) feed-forward model into an Inpla net script.

    Walks the torch.fx graph node by node and emits one wiring statement per
    weighted edge. Float weights/biases are fixed-point quantized by
    multiplying with `scale` and truncating to int. Returns the rule set
    (from get_rules()) followed by a "// Wiring" section.

    Only Flatten, Linear and ReLU call_module nodes are handled; other node
    kinds are silently passed over.
    """
    traced = fx.symbolic_trace(model)
    name_gen = NameGen()
    # Emitted Inpla statements, in order.
    script: List[str] = []
    # fx node name -> list of Inpla terms/wire names, one per scalar feature.
    wire_map: Dict[str, List[str]] = {}

    for node in traced.graph.nodes:
        if node.op == 'placeholder':
            # Model input: one Symbolic(i) wire per scalar element.
            num_inputs = int(np.prod(input_shape))
            wire_map[node.name] = [f"Linear(Symbolic({i}), 1, 0)" for i in range(num_inputs)]

        elif node.op == 'call_module':
            target_str = str(node.target)
            module = dict(model.named_modules())[target_str]

            input_node = node.args[0]
            if not isinstance(input_node, fx.Node):
                continue
            input_wires = wire_map[input_node.name]

            if isinstance(module, nn.Flatten):
                # Wires are already flat; Flatten is a no-op pass-through.
                wire_map[node.name] = input_wires

            elif isinstance(module, nn.Linear):
                # Fixed-point quantized weight matrix and bias vector.
                W = (module.weight.data.detach().cpu().numpy() * scale).astype(int)
                B = (module.bias.data.detach().cpu().numpy() * scale).astype(int)
                out_dim, in_dim = W.shape

                # Running accumulator term per output neuron, seeded with its bias.
                neuron_wires = [f"Concrete({B[j]})" for j in range(out_dim)]

                for i in range(in_dim):
                    in_term = input_wires[i]
                    if out_dim == 1:
                        # Single consumer: the input wire can be used directly,
                        # no Dup fan-out needed.
                        weight = int(W[0, i])
                        if weight == 0:
                            # Zero weight: the input contributes nothing; erase it.
                            script.append(f"Eraser ~ {in_term};")
                        elif weight == 1:
                            # Unit weight: skip the Mul, accumulate directly.
                            new_s = name_gen.next()
                            script.append(f"Add({new_s}, {in_term}) ~ {neuron_wires[0]};")
                            neuron_wires[0] = new_s
                        else:
                            # General case: Mul by the weight, then accumulate.
                            mul_out = name_gen.next()
                            new_s = name_gen.next()
                            script.append(f"Mul({mul_out}, Concrete({weight})) ~ {in_term};")
                            script.append(f"Add({new_s}, {mul_out}) ~ {neuron_wires[0]};")
                            neuron_wires[0] = new_s
                    else:
                        # Multiple consumers: fan the input out with a chain of
                        # Dup agents, one branch wire per output neuron.
                        branch_wires = [name_gen.next() for _ in range(out_dim)]

                        def nest_dups(names: List[str]) -> str:
                            # Right-nested Dup chain: Dup(a, Dup(b, ... z)).
                            if len(names) == 1: return names[0]
                            if len(names) == 2: return f"Dup({names[0]}, {names[1]})"
                            return f"Dup({names[0]}, {nest_dups(names[1:])})"

                        script.append(f"{nest_dups(branch_wires)} ~ {in_term};")

                        for j in range(out_dim):
                            # Same 0/1/general weight handling as above, per branch.
                            weight = int(W[j, i])
                            if weight == 0:
                                script.append(f"Eraser ~ {branch_wires[j]};")
                            elif weight == 1:
                                new_s = name_gen.next()
                                script.append(f"Add({new_s}, {branch_wires[j]}) ~ {neuron_wires[j]};")
                                neuron_wires[j] = new_s
                            else:
                                mul_out = name_gen.next()
                                new_s = name_gen.next()
                                script.append(f"Mul({mul_out}, Concrete({weight})) ~ {branch_wires[j]};")
                                script.append(f"Add({new_s}, {mul_out}) ~ {neuron_wires[j]};")
                                neuron_wires[j] = new_s

                wire_map[node.name] = neuron_wires

            elif isinstance(module, nn.ReLU):
                # Element-wise: one ReLU agent per incoming wire.
                input_wires = wire_map[node.args[0].name]
                output_wires = []
                for i, w in enumerate(input_wires):
                    r_out = name_gen.next()
                    script.append(f"ReLU({r_out}) ~ {w};")
                    output_wires.append(r_out)
                wire_map[node.name] = output_wires

        elif node.op == 'output':
            # Materialize each final wire into a printable result term.
            output_node = node.args[0]
            if isinstance(output_node, fx.Node):
                final_wires = wire_map[output_node.name]
                for i, w in enumerate(final_wires):
                    res_name = f"result{i}"
                    script.append(f"Materialize({res_name}) ~ {w};")
                    script.append(f"{res_name};")

    rules = get_rules()
    return rules + "\n\n// Wiring\n" + "\n".join(script)
+
+if __name__ == "__main__":
+ X = torch.tensor([[0,0], [0,1], [1,0], [1,1]], dtype=torch.float32)
+ Y = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)
+
+ net = XOR_MLP()
+ loss_fn = nn.MSELoss()
+ optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
+
+ print("Training XOR MLP...")
+ for epoch in range(1000):
+ optimizer.zero_grad()
+ out = net(X)
+ loss = loss_fn(out, Y)
+ loss.backward()
+ optimizer.step()
+ if (epoch+1) % 200 == 0:
+ print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
+
+ print("\nTraining Finished. Predictions:")
+ with torch.no_grad():
+ print(net(X).numpy())
+
+ print("\nExporting XOR to Inpla...")
+ net.eval()
+ inpla_script = export_to_inpla(net, (2,))
+ with open("xor.in", "w") as f:
+ f.write(inpla_script)
+ print("Exported to xor.in")