author    ericmarin <maarin.eric@gmail.com>  2026-03-10 17:42:02 +0100
committer ericmarin <maarin.eric@gmail.com>  2026-03-10 17:50:08 +0100
commit    0882fc5328127f68a7d79c06d0c7decdee770bb9 (patch)
tree      805051f8c5f830fee28d82fc5c9daedcff15f91f /xor.py
parent    af2b13214579d78827392e762149fa8824526aa9 (diff)
download  vein-0882fc5328127f68a7d79c06d0c7decdee770bb9.tar.gz
          vein-0882fc5328127f68a7d79c06d0c7decdee770bb9.zip
two nets

Make XOR_MLP's hidden width configurable, split the Inpla export into
get_rules() plus export_to_inpla_wiring(), and train two independent
networks whose wirings are written into a single xor.in.
Diffstat (limited to 'xor.py')
-rw-r--r--  xor.py  47
1 file changed, 28 insertions(+), 19 deletions(-)
diff --git a/xor.py b/xor.py
index b2082c7..f7e247f 100644
--- a/xor.py
+++ b/xor.py
@@ -6,12 +6,12 @@ import os
 from typing import List, Dict
 
 class XOR_MLP(nn.Module):
-    def __init__(self):
+    def __init__(self, hidden_dim=4):
         super().__init__()
         self.layers = nn.Sequential(
-            nn.Linear(2, 4),
+            nn.Linear(2, hidden_dim),
             nn.ReLU(),
-            nn.Linear(4, 1)
+            nn.Linear(hidden_dim, 1)
         )
     def forward(self, x):
         return self.layers(x)
@@ -37,7 +37,7 @@ def get_rules() -> str:
             rules_lines.append(line)
     return "".join(rules_lines)
 
-def export_to_inpla(model: nn.Module, input_shape: tuple, scale: int = 1000) -> str:
+def export_to_inpla_wiring(model: nn.Module, input_shape: tuple, scale: int = 1000) -> str:
     traced = fx.symbolic_trace(model)
     name_gen = NameGen()
     script: List[str] = []
@@ -111,7 +111,6 @@ def export_to_inpla(model: nn.Module, input_shape: tuple, scale: int = 1000) ->
             wire_map[node.name] = neuron_wires
 
         elif isinstance(module, nn.ReLU):
-            input_wires = wire_map[node.args[0].name]
             output_wires = []
             for i, w in enumerate(input_wires):
                 r_out = name_gen.next()
@@ -128,10 +127,9 @@ def export_to_inpla(model: nn.Module, input_shape: tuple, scale: int = 1000) ->
         script.append(f"Materialize({res_name}) ~ {w};")
         script.append(f"{res_name};")
 
-    rules = get_rules()
-    return rules + "\n\n// Wiring\n" + "\n".join(script)
+    return "\n".join(script)
 
-if __name__ == "__main__":
+def train_model(name: str):
     X = torch.tensor([[0,0], [0,1], [1,0], [1,1]], dtype=torch.float32)
     Y = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)
 
@@ -139,23 +137,34 @@ if __name__ == "__main__":
     loss_fn = nn.MSELoss()
     optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
 
-    print("Training XOR MLP...")
+    print(f"Training {name}...")
     for epoch in range(1000):
         optimizer.zero_grad()
         out = net(X)
         loss = loss_fn(out, Y)
         loss.backward()
         optimizer.step()
-        if (epoch+1) % 200 == 0:
-            print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
+        if (epoch+1) % 500 == 0:
+            print(f" Epoch {epoch+1}, Loss: {loss.item():.4f}")
+    return net
+
+if __name__ == "__main__":
+    # Train two different models
+    net_a = train_model("Network A")
+    net_b = train_model("Network B")
 
-    print("\nTraining Finished. Predictions:")
-    with torch.no_grad():
-        print(net(X).numpy())
+    print("\nExporting both to xor.in...")
+
+    rules = get_rules()
+    wiring_a = export_to_inpla_wiring(net_a, (2,))
+    wiring_b = export_to_inpla_wiring(net_b, (2,))
 
-    print("\nExporting XOR to Inpla...")
-    net.eval()
-    inpla_script = export_to_inpla(net, (2,))
     with open("xor.in", "w") as f:
-        f.write(inpla_script)
-    print("Exported to xor.in")
+        f.write(rules)
+        f.write("\n\n// Network A\n")
+        f.write(wiring_a)
+        f.write("\nfree ifce;\n")
+        f.write("\n\n// Network B\n")
+        f.write(wiring_b)
+
+    print("Done. Now run: inpla -f xor.in | python3 prover.py")