diff options
| author | ericmarin <maarin.eric@gmail.com> | 2026-03-10 16:29:33 +0100 |
|---|---|---|
| committer | ericmarin <maarin.eric@gmail.com> | 2026-03-10 17:34:09 +0100 |
| commit | af2b13214579d78827392e762149fa8824526aa9 (patch) | |
| tree | 81dd6bb8e442668cecff72ec20bcc77e2e9f5c2d | |
| parent | 7933d744e06337f1d69b7da83f2cee1611556097 (diff) | |
| download | vein-af2b13214579d78827392e762149fa8824526aa9.tar.gz vein-af2b13214579d78827392e762149fa8824526aa9.zip | |
added MLP
| -rw-r--r-- | rules.in | 71 | ||||
| -rw-r--r-- | xor.in (renamed from nn.in) | 69 | ||||
| -rw-r--r-- | xor.py | 161 |
3 files changed, 266 insertions, 35 deletions
diff --git a/rules.in b/rules.in new file mode 100644 index 0000000..2ff92dd --- /dev/null +++ b/rules.in @@ -0,0 +1,71 @@ +// Agents +// Built-in +// Eraser: delete other agents recursively +// Dup: duplicates other agents recursively + +// Implemented +// Linear(x, int q, int r): represent "q*x + r" +// Concrete(int k): represent a concrete value k +// Symbolic(id): represent the variable id +// Add(out, b): represent the addition (has various steps AddCheckLinear/AddCheckConcrete) +// Mul(out, b): represent the multiplication (has various steps MulCheckLinear/MulCheckConcrete) +// ReLU(out): represent "if x > 0 ? x ; 0" +// Materialize(out): transforms a Linear packet into a final representation of TermAdd/TermMul/TermReLU + +// TODO: add range information to enable ReLU elimination + +// Rules +Linear(x, int q, int r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r); + +Concrete(int k) >< Add(out, b) + | k == 0 => out ~ b + | _ => b ~ AddCheckConcrete(out, k); + +Linear(y, int s, int t) >< AddCheckLinear(out, x, int q, int r) + | (q == 0) && (r == 0) && (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser + | (s == 0) && (t == 0) => Linear(x, q, r) ~ Materialize(out), y ~ Eraser + | (q == 0) && (r == 0) => (*L)Linear(y, s, t) ~ Materialize(out), x ~ Eraser + | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermAdd(out_x, out_y), 1, 0); + +Concrete(int j) >< AddCheckLinear(out, x, int q, int r) => out ~ Linear(x, q, r + j); + +Linear(y, int s, int t) >< AddCheckConcrete(out, int k) => out ~ Linear(y, s, t + k); + +Concrete(int j) >< AddCheckConcrete(out, int k) + | j == 0 => out ~ Concrete(k) + | _ => out ~ Concrete(k + j); + +Linear(x, int q, int r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r); + +Concrete(int k) >< Mul(out, b) + | k == 0 => b ~ Eraser, out ~ (*L)Concrete(0) + | k == 1 => out ~ b + | _ => b ~ MulCheckConcrete(out, k); + +Linear(y, int s, int t) >< MulCheckLinear(out, x, int q, int r) + | (q == 0) && (r == 0) && (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser + | (s == 0) && (t == 0) => Linear(x, q, r) ~ Materialize(out), y ~ Eraser + | (q == 0) && (r == 0) => (*L)Linear(y, s, t) ~ Materialize(out), x ~ Eraser + | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermMul(out_x, out_y), 1, 0); + +Concrete(int j) >< MulCheckLinear(out, x, int q, int r) => out ~ Linear(x, q * j, r * j); + +Linear(y, int s, int t) >< MulCheckConcrete(out, int k) => out ~ Linear(y, s * k, t * k); + +Concrete(int j) >< MulCheckConcrete(out, int k) + | j == 0 => out ~ Concrete(0) + | j == 1 => out ~ Concrete(k) + | _ => out ~ Concrete(k * j); + +Linear(x, int q, int r) >< ReLU(out) => (*L)Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0); + +Concrete(int k) >< ReLU(out) + | k > 0 => out ~ (*L)Concrete(k) + | _ => out ~ Concrete(0); + +Linear(x, int q, int r) >< Materialize(out) + | (q == 0) => out ~ Concrete(r), x ~ Eraser + | (q == 1) && (r == 0) => out ~ x + | (q == 1) && (r != 0) => out ~ TermAdd(x, Concrete(r)) + | (q != 0) && (r == 0) => out ~ TermMul(Concrete(q), x) + | _ => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r)); @@ -64,44 +64,43 @@ Concrete(int k) >< ReLU(out) | _ => out ~ Concrete(0); Linear(x, int q, int r) >< Materialize(out) - | (q == 0) && (r == 0) => out ~ Concrete(r), x ~ Eraser + | (q == 0) => out ~ Concrete(r), x ~ Eraser | (q == 1) && (r == 0) => out ~ x | (q == 1) && (r != 0) => out ~ TermAdd(x, Concrete(r)) | (q != 0) && (r == 0) => out ~ TermMul(Concrete(q), x) | _ => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r)); -// Net testing -Linear(Symbolic(X), 1, 0) ~ Add(a, b); -Concrete(0) ~ Mul(b, Linear(Symbolic(Y), 1, 0)); -a ~ Materialize(out); -out; // Symbolic(X) -free out a b; - -Linear(Symbolic(X), 1, 0) ~ Mul(a, Concrete(0)); -a ~ Materialize(out); -out; // Concrete(0) -free out a; - -Linear(Symbolic(X), 1, 0) ~ Mul(a, b); -Linear(Symbolic(Y), 1, 0) ~ Add(b, Concrete(1)); -a ~ Materialize(out); -out; // Mul(Symbolic(X),Add(Symbolic(Y),Concrete(1))) -free out a b; - -Concrete(1) ~ Add(out, b); -Concrete(2) ~ Add(b, c); -Concrete(3) ~ Mul(c, Concrete(2)); -out; // Concrete(9) -free out a b c; - -Linear(Symbolic(X), 1, 0) ~ Mul(a, Linear(Symbolic(W1), 1, 0)); -Linear(Symbolic(Y), 1, 0) ~ Mul(b, Linear(Symbolic(W2), 1, 0)); -b ~ Add(c, Linear(Symbolic(B), 1, 0)); -a ~ Add(d, c); -d ~ ReLU(Materialize(out)); -out; // ReLU(Add(Mul(Symbolic(X),Symbolic(W1)),Add(Mul(Symbolic(Y),Symbolic(W2)),Symbolic(B)))) -free out a b c d; - -Linear(Symbolic(X), 2, 10) ~ Materialize(out); -out; +// Wiring +Dup(v0, Dup(v1, Dup(v2, v3))) ~ Linear(Symbolic(0), 1, 0); +Mul(v4, Concrete(-693)) ~ v0; +Add(v5, v4) ~ Concrete(-692); +Mul(v6, Concrete(-78)) ~ v1; +Add(v7, v6) ~ Concrete(916); +Mul(v8, Concrete(235)) ~ v2; +Add(v9, v8) ~ Concrete(-424); +Mul(v10, Concrete(181)) ~ v3; +Add(v11, v10) ~ Concrete(202); +Dup(v12, Dup(v13, Dup(v14, v15))) ~ Linear(Symbolic(1), 1, 0); +Mul(v16, Concrete(-674)) ~ v12; +Add(v17, v16) ~ v5; +Mul(v18, Concrete(-97)) ~ v13; +Add(v19, v18) ~ v7; +Mul(v20, Concrete(-572)) ~ v14; +Add(v21, v20) ~ v9; +Mul(v22, Concrete(224)) ~ v15; +Add(v23, v22) ~ v11; +ReLU(v24) ~ v17; +ReLU(v25) ~ v19; +ReLU(v26) ~ v21; +ReLU(v27) ~ v23; +Mul(v28, Concrete(-318)) ~ v24; +Add(v29, v28) ~ Concrete(-89); +Mul(v30, Concrete(587)) ~ v25; +Add(v31, v30) ~ v29; +Mul(v32, Concrete(-250)) ~ v26; +Add(v33, v32) ~ v31; +Mul(v34, Concrete(254)) ~ v27; +Add(v35, v34) ~ v33; +Materialize(result0) ~ v35; +result0;
\ No newline at end of file @@ -0,0 +1,161 @@ +import torch +import torch.nn as nn +import torch.fx as fx +import numpy as np +import os +from typing import List, Dict + +class XOR_MLP(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(2, 4), + nn.ReLU(), + nn.Linear(4, 1) + ) + def forward(self, x): + return self.layers(x) + +class NameGen: + def __init__(self): + self.counter = 0 + def next(self) -> str: + name = f"v{self.counter}" + self.counter += 1 + return name + +def get_rules() -> str: + rules_path = os.path.join(os.path.dirname(__file__), "rules.in") + if not os.path.exists(rules_path): + return "// Rules not found in rules.in\n" + + rules_lines = [] + with open(rules_path, "r") as f: + for line in f: + if "// Net testing" in line: + break + rules_lines.append(line) + return "".join(rules_lines) + +def export_to_inpla(model: nn.Module, input_shape: tuple, scale: int = 1000) -> str: + traced = fx.symbolic_trace(model) + name_gen = NameGen() + script: List[str] = [] + wire_map: Dict[str, List[str]] = {} + + for node in traced.graph.nodes: + if node.op == 'placeholder': + num_inputs = int(np.prod(input_shape)) + wire_map[node.name] = [f"Linear(Symbolic({i}), 1, 0)" for i in range(num_inputs)] + + elif node.op == 'call_module': + target_str = str(node.target) + module = dict(model.named_modules())[target_str] + + input_node = node.args[0] + if not isinstance(input_node, fx.Node): + continue + input_wires = wire_map[input_node.name] + + if isinstance(module, nn.Flatten): + wire_map[node.name] = input_wires + + elif isinstance(module, nn.Linear): + W = (module.weight.data.detach().cpu().numpy() * scale).astype(int) + B = (module.bias.data.detach().cpu().numpy() * scale).astype(int) + out_dim, in_dim = W.shape + + neuron_wires = [f"Concrete({B[j]})" for j in range(out_dim)] + + for i in range(in_dim): + in_term = input_wires[i] + if out_dim == 1: + weight = int(W[0, i]) + if weight == 0: + script.append(f"Eraser ~ {in_term};") + elif weight == 1: + new_s = name_gen.next() + script.append(f"Add({new_s}, {in_term}) ~ {neuron_wires[0]};") + neuron_wires[0] = new_s + else: + mul_out = name_gen.next() + new_s = name_gen.next() + script.append(f"Mul({mul_out}, Concrete({weight})) ~ {in_term};") + script.append(f"Add({new_s}, {mul_out}) ~ {neuron_wires[0]};") + neuron_wires[0] = new_s + else: + branch_wires = [name_gen.next() for _ in range(out_dim)] + + def nest_dups(names: List[str]) -> str: + if len(names) == 1: return names[0] + if len(names) == 2: return f"Dup({names[0]}, {names[1]})" + return f"Dup({names[0]}, {nest_dups(names[1:])})" + + script.append(f"{nest_dups(branch_wires)} ~ {in_term};") + + for j in range(out_dim): + weight = int(W[j, i]) + if weight == 0: + script.append(f"Eraser ~ {branch_wires[j]};") + elif weight == 1: + new_s = name_gen.next() + script.append(f"Add({new_s}, {branch_wires[j]}) ~ {neuron_wires[j]};") + neuron_wires[j] = new_s + else: + mul_out = name_gen.next() + new_s = name_gen.next() + script.append(f"Mul({mul_out}, Concrete({weight})) ~ {branch_wires[j]};") + script.append(f"Add({new_s}, {mul_out}) ~ {neuron_wires[j]};") + neuron_wires[j] = new_s + + wire_map[node.name] = neuron_wires + + elif isinstance(module, nn.ReLU): + input_wires = wire_map[node.args[0].name] + output_wires = [] + for i, w in enumerate(input_wires): + r_out = name_gen.next() + script.append(f"ReLU({r_out}) ~ {w};") + output_wires.append(r_out) + wire_map[node.name] = output_wires + + elif node.op == 'output': + output_node = node.args[0] + if isinstance(output_node, fx.Node): + final_wires = wire_map[output_node.name] + for i, w in enumerate(final_wires): + res_name = f"result{i}" + script.append(f"Materialize({res_name}) ~ {w};") + script.append(f"{res_name};") + + rules = get_rules() + return rules + "\n\n// Wiring\n" + "\n".join(script) + +if __name__ == "__main__": + X = torch.tensor([[0,0], [0,1], [1,0], [1,1]], dtype=torch.float32) + Y = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32) + + net = XOR_MLP() + loss_fn = nn.MSELoss() + optimizer = torch.optim.Adam(net.parameters(), lr=0.01) + + print("Training XOR MLP...") + for epoch in range(1000): + optimizer.zero_grad() + out = net(X) + loss = loss_fn(out, Y) + loss.backward() + optimizer.step() + if (epoch+1) % 200 == 0: + print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}") + + print("\nTraining Finished. Predictions:") + with torch.no_grad(): + print(net(X).numpy()) + + print("\nExporting XOR to Inpla...") + net.eval() + inpla_script = export_to_inpla(net, (2,)) + with open("xor.in", "w") as f: + f.write(inpla_script) + print("Exported to xor.in") |
