aboutsummaryrefslogtreecommitdiff
path: root/xor.py
diff options
context:
space:
mode:
authorericmarin <maarin.eric@gmail.com>2026-03-10 16:29:33 +0100
committerericmarin <maarin.eric@gmail.com>2026-03-10 17:34:09 +0100
commitaf2b13214579d78827392e762149fa8824526aa9 (patch)
tree81dd6bb8e442668cecff72ec20bcc77e2e9f5c2d /xor.py
parent7933d744e06337f1d69b7da83f2cee1611556097 (diff)
downloadvein-af2b13214579d78827392e762149fa8824526aa9.tar.gz
vein-af2b13214579d78827392e762149fa8824526aa9.zip
added MLP
Diffstat (limited to '')
-rw-r--r--xor.py161
1 files changed, 161 insertions, 0 deletions
diff --git a/xor.py b/xor.py
new file mode 100644
index 0000000..b2082c7
--- /dev/null
+++ b/xor.py
@@ -0,0 +1,161 @@
+import torch
+import torch.nn as nn
+import torch.fx as fx
+import numpy as np
+import os
+from typing import List, Dict
+
class XOR_MLP(nn.Module):
    """A minimal multilayer perceptron that can learn the XOR function.

    Architecture: 2 inputs -> 4 hidden units with ReLU -> 1 output.
    The layers are kept in an ``nn.Sequential`` named ``layers`` so the
    fx-traced module targets remain ``layers.0`` / ``layers.1`` / ``layers.2``.
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(2, 4),
            nn.ReLU(),
            nn.Linear(4, 1),
        )

    def forward(self, x):
        """Run a batch of shape (N, 2) through the network; returns (N, 1)."""
        return self.layers(x)
+
class NameGen:
    """Hands out unique intermediate wire names: ``v0``, ``v1``, ``v2``, ..."""

    def __init__(self):
        # Next unused numeric suffix.
        self.counter = 0

    def next(self) -> str:
        """Return the current name and advance the counter."""
        current = self.counter
        self.counter = current + 1
        return f"v{current}"
+
def get_rules() -> str:
    """Load the interaction-net rules from ``rules.in`` next to this script.

    Returns every line up to (but not including) the first line containing
    the ``// Net testing`` marker. If ``rules.in`` does not exist, a
    placeholder comment string is returned instead.
    """
    rules_path = os.path.join(os.path.dirname(__file__), "rules.in")
    if not os.path.exists(rules_path):
        return "// Rules not found in rules.in\n"

    collected = []
    with open(rules_path, "r") as handle:
        for raw_line in handle:
            # Everything from the testing marker onward is excluded.
            if "// Net testing" in raw_line:
                break
            collected.append(raw_line)
    return "".join(collected)
+
def export_to_inpla(model: nn.Module, input_shape: tuple, scale: int = 1000) -> str:
    """Compile a PyTorch model into an Inpla interaction-net script string.

    The model is symbolically traced with torch.fx and each graph node is
    lowered to textual agents (Add, Mul, Dup, ReLU, Eraser, Materialize, ...)
    wired together by generated names. Weights and biases are fixed-point
    quantized by multiplying with ``scale`` and converting to int.

    Supported modules: nn.Flatten (pass-through), nn.Linear, nn.ReLU.
    Returns the rules from ``get_rules()`` followed by the generated wiring.

    NOTE(review): ``astype(int)`` truncates toward zero rather than rounding —
    presumably intentional quantization, but confirm rounding policy.
    """
    traced = fx.symbolic_trace(model)
    name_gen = NameGen()
    script: List[str] = []
    # Maps an fx node name to the list of wire terms carrying its outputs,
    # one term per scalar element.
    wire_map: Dict[str, List[str]] = {}

    for node in traced.graph.nodes:
        if node.op == 'placeholder':
            # One symbolic input wire per scalar input element.
            num_inputs = int(np.prod(input_shape))
            wire_map[node.name] = [f"Linear(Symbolic({i}), 1, 0)" for i in range(num_inputs)]

        elif node.op == 'call_module':
            target_str = str(node.target)
            module = dict(model.named_modules())[target_str]

            input_node = node.args[0]
            if not isinstance(input_node, fx.Node):
                continue
            input_wires = wire_map[input_node.name]

            if isinstance(module, nn.Flatten):
                # Wires are already flat (one per scalar); just forward them.
                wire_map[node.name] = input_wires

            elif isinstance(module, nn.Linear):
                # Fixed-point quantized weight matrix (out_dim x in_dim) and bias.
                W = (module.weight.data.detach().cpu().numpy() * scale).astype(int)
                B = (module.bias.data.detach().cpu().numpy() * scale).astype(int)
                out_dim, in_dim = W.shape

                # Each output neuron starts as its bias; weighted inputs are
                # accumulated onto it via Add agents below.
                neuron_wires = [f"Concrete({B[j]})" for j in range(out_dim)]

                for i in range(in_dim):
                    in_term = input_wires[i]
                    if out_dim == 1:
                        # Single output: the input wire feeds one neuron directly,
                        # so no Dup fan-out is needed.
                        weight = int(W[0, i])
                        if weight == 0:
                            # Zero weight: the input contributes nothing; consume it.
                            script.append(f"Eraser ~ {in_term};")
                        elif weight == 1:
                            # Unit weight: skip the multiply, accumulate directly.
                            new_s = name_gen.next()
                            script.append(f"Add({new_s}, {in_term}) ~ {neuron_wires[0]};")
                            neuron_wires[0] = new_s
                        else:
                            # Multiply by the quantized weight, then accumulate.
                            mul_out = name_gen.next()
                            new_s = name_gen.next()
                            script.append(f"Mul({mul_out}, Concrete({weight})) ~ {in_term};")
                            script.append(f"Add({new_s}, {mul_out}) ~ {neuron_wires[0]};")
                            neuron_wires[0] = new_s
                    else:
                        # Fan the input wire out to every output neuron with a
                        # right-nested chain of Dup agents.
                        branch_wires = [name_gen.next() for _ in range(out_dim)]

                        def nest_dups(names: List[str]) -> str:
                            # Build Dup(a, Dup(b, ...)) for an arbitrary fan-out.
                            if len(names) == 1: return names[0]
                            if len(names) == 2: return f"Dup({names[0]}, {names[1]})"
                            return f"Dup({names[0]}, {nest_dups(names[1:])})"

                        script.append(f"{nest_dups(branch_wires)} ~ {in_term};")

                        # Same zero / unit / general weight cases as above, once
                        # per output neuron on its duplicated branch.
                        for j in range(out_dim):
                            weight = int(W[j, i])
                            if weight == 0:
                                script.append(f"Eraser ~ {branch_wires[j]};")
                            elif weight == 1:
                                new_s = name_gen.next()
                                script.append(f"Add({new_s}, {branch_wires[j]}) ~ {neuron_wires[j]};")
                                neuron_wires[j] = new_s
                            else:
                                mul_out = name_gen.next()
                                new_s = name_gen.next()
                                script.append(f"Mul({mul_out}, Concrete({weight})) ~ {branch_wires[j]};")
                                script.append(f"Add({new_s}, {mul_out}) ~ {neuron_wires[j]};")
                                neuron_wires[j] = new_s

                wire_map[node.name] = neuron_wires

            elif isinstance(module, nn.ReLU):
                # One ReLU agent per scalar wire.
                input_wires = wire_map[node.args[0].name]
                output_wires = []
                for i, w in enumerate(input_wires):
                    r_out = name_gen.next()
                    script.append(f"ReLU({r_out}) ~ {w};")
                    output_wires.append(r_out)
                wire_map[node.name] = output_wires

        elif node.op == 'output':
            # Materialize each final wire into a named result and emit it.
            output_node = node.args[0]
            if isinstance(output_node, fx.Node):
                final_wires = wire_map[output_node.name]
                for i, w in enumerate(final_wires):
                    res_name = f"result{i}"
                    script.append(f"Materialize({res_name}) ~ {w};")
                    script.append(f"{res_name};")

    rules = get_rules()
    return rules + "\n\n// Wiring\n" + "\n".join(script)
+
+if __name__ == "__main__":
+ X = torch.tensor([[0,0], [0,1], [1,0], [1,1]], dtype=torch.float32)
+ Y = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)
+
+ net = XOR_MLP()
+ loss_fn = nn.MSELoss()
+ optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
+
+ print("Training XOR MLP...")
+ for epoch in range(1000):
+ optimizer.zero_grad()
+ out = net(X)
+ loss = loss_fn(out, Y)
+ loss.backward()
+ optimizer.step()
+ if (epoch+1) % 200 == 0:
+ print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
+
+ print("\nTraining Finished. Predictions:")
+ with torch.no_grad():
+ print(net(X).numpy())
+
+ print("\nExporting XOR to Inpla...")
+ net.eval()
+ inpla_script = export_to_inpla(net, (2,))
+ with open("xor.in", "w") as f:
+ f.write(inpla_script)
+ print("Exported to xor.in")