diff options
| author | ericmarin <maarin.eric@gmail.com> | 2026-03-10 16:29:33 +0100 |
|---|---|---|
| committer | ericmarin <maarin.eric@gmail.com> | 2026-03-10 17:34:09 +0100 |
| commit | af2b13214579d78827392e762149fa8824526aa9 (patch) | |
| tree | 81dd6bb8e442668cecff72ec20bcc77e2e9f5c2d /xor.py | |
| parent | 7933d744e06337f1d69b7da83f2cee1611556097 (diff) | |
| download | vein-af2b13214579d78827392e762149fa8824526aa9.tar.gz vein-af2b13214579d78827392e762149fa8824526aa9.zip | |
added MLP
Diffstat (limited to 'xor.py')
| -rw-r--r-- | xor.py | 161 |
1 files changed, 161 insertions, 0 deletions
@@ -0,0 +1,161 @@ +import torch +import torch.nn as nn +import torch.fx as fx +import numpy as np +import os +from typing import List, Dict + +class XOR_MLP(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(2, 4), + nn.ReLU(), + nn.Linear(4, 1) + ) + def forward(self, x): + return self.layers(x) + +class NameGen: + def __init__(self): + self.counter = 0 + def next(self) -> str: + name = f"v{self.counter}" + self.counter += 1 + return name + +def get_rules() -> str: + rules_path = os.path.join(os.path.dirname(__file__), "rules.in") + if not os.path.exists(rules_path): + return "// Rules not found in rules.in\n" + + rules_lines = [] + with open(rules_path, "r") as f: + for line in f: + if "// Net testing" in line: + break + rules_lines.append(line) + return "".join(rules_lines) + +def export_to_inpla(model: nn.Module, input_shape: tuple, scale: int = 1000) -> str: + traced = fx.symbolic_trace(model) + name_gen = NameGen() + script: List[str] = [] + wire_map: Dict[str, List[str]] = {} + + for node in traced.graph.nodes: + if node.op == 'placeholder': + num_inputs = int(np.prod(input_shape)) + wire_map[node.name] = [f"Linear(Symbolic({i}), 1, 0)" for i in range(num_inputs)] + + elif node.op == 'call_module': + target_str = str(node.target) + module = dict(model.named_modules())[target_str] + + input_node = node.args[0] + if not isinstance(input_node, fx.Node): + continue + input_wires = wire_map[input_node.name] + + if isinstance(module, nn.Flatten): + wire_map[node.name] = input_wires + + elif isinstance(module, nn.Linear): + W = (module.weight.data.detach().cpu().numpy() * scale).astype(int) + B = (module.bias.data.detach().cpu().numpy() * scale).astype(int) + out_dim, in_dim = W.shape + + neuron_wires = [f"Concrete({B[j]})" for j in range(out_dim)] + + for i in range(in_dim): + in_term = input_wires[i] + if out_dim == 1: + weight = int(W[0, i]) + if weight == 0: + script.append(f"Eraser ~ {in_term};") + elif weight == 1: + new_s = name_gen.next() + script.append(f"Add({new_s}, {in_term}) ~ {neuron_wires[0]};") + neuron_wires[0] = new_s + else: + mul_out = name_gen.next() + new_s = name_gen.next() + script.append(f"Mul({mul_out}, Concrete({weight})) ~ {in_term};") + script.append(f"Add({new_s}, {mul_out}) ~ {neuron_wires[0]};") + neuron_wires[0] = new_s + else: + branch_wires = [name_gen.next() for _ in range(out_dim)] + + def nest_dups(names: List[str]) -> str: + if len(names) == 1: return names[0] + if len(names) == 2: return f"Dup({names[0]}, {names[1]})" + return f"Dup({names[0]}, {nest_dups(names[1:])})" + + script.append(f"{nest_dups(branch_wires)} ~ {in_term};") + + for j in range(out_dim): + weight = int(W[j, i]) + if weight == 0: + script.append(f"Eraser ~ {branch_wires[j]};") + elif weight == 1: + new_s = name_gen.next() + script.append(f"Add({new_s}, {branch_wires[j]}) ~ {neuron_wires[j]};") + neuron_wires[j] = new_s + else: + mul_out = name_gen.next() + new_s = name_gen.next() + script.append(f"Mul({mul_out}, Concrete({weight})) ~ {branch_wires[j]};") + script.append(f"Add({new_s}, {mul_out}) ~ {neuron_wires[j]};") + neuron_wires[j] = new_s + + wire_map[node.name] = neuron_wires + + elif isinstance(module, nn.ReLU): + input_wires = wire_map[node.args[0].name] + output_wires = [] + for i, w in enumerate(input_wires): + r_out = name_gen.next() + script.append(f"ReLU({r_out}) ~ {w};") + output_wires.append(r_out) + wire_map[node.name] = output_wires + + elif node.op == 'output': + output_node = node.args[0] + if isinstance(output_node, fx.Node): + final_wires = wire_map[output_node.name] + for i, w in enumerate(final_wires): + res_name = f"result{i}" + script.append(f"Materialize({res_name}) ~ {w};") + script.append(f"{res_name};") + + rules = get_rules() + return rules + "\n\n// Wiring\n" + "\n".join(script) + +if __name__ == "__main__": + X = torch.tensor([[0,0], [0,1], [1,0], [1,1]], dtype=torch.float32) + Y = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32) + + net = XOR_MLP() + loss_fn = nn.MSELoss() + optimizer = torch.optim.Adam(net.parameters(), lr=0.01) + + print("Training XOR MLP...") + for epoch in range(1000): + optimizer.zero_grad() + out = net(X) + loss = loss_fn(out, Y) + loss.backward() + optimizer.step() + if (epoch+1) % 200 == 0: + print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}") + + print("\nTraining Finished. Predictions:") + with torch.no_grad(): + print(net(X).numpy()) + + print("\nExporting XOR to Inpla...") + net.eval() + inpla_script = export_to_inpla(net, (2,)) + with open("xor.in", "w") as f: + f.write(inpla_script) + print("Exported to xor.in") |
