aboutsummaryrefslogtreecommitdiff
path: root/xor.py
diff options
context:
space:
mode:
authorericmarin <maarin.eric@gmail.com>2026-03-13 16:42:00 +0100
committerericmarin <maarin.eric@gmail.com>2026-03-16 10:23:02 +0100
commita0b1e7f6a8c11ed98ae20ac484e2fe9f75b9b85f (patch)
tree57d6aa106daf4a46d9132832eec88fdb79fc5543 /xor.py
parent19652ec48be4c6faf3f7815a9281b611aed94727 (diff)
downloadvein-a0b1e7f6a8c11ed98ae20ac484e2fe9f75b9b85f.tar.gz
vein-a0b1e7f6a8c11ed98ae20ac484e2fe9f75b9b85f.zip
defined modules
Diffstat (limited to '')
-rw-r--r--xor.py144
1 files changed, 11 insertions, 133 deletions
diff --git a/xor.py b/xor.py
index 493eaef..9ab7be7 100644
--- a/xor.py
+++ b/xor.py
@@ -1,11 +1,8 @@
import torch
import torch.nn as nn
-import torch.fx as fx
-import numpy as np
-import os
-from typing import List, Dict
+import nneq
-class XOR_MLP(nn.Module):
+class xor_mlp(nn.Module):
def __init__(self, hidden_dim=4):
super().__init__()
self.layers = nn.Sequential(
@@ -16,124 +13,11 @@ class XOR_MLP(nn.Module):
def forward(self, x):
return self.layers(x)
-class NameGen:
- def __init__(self):
- self.counter = 0
- def next(self) -> str:
- name = f"v{self.counter}"
- self.counter += 1
- return name
-
-def get_rules() -> str:
- rules_path = os.path.join(os.path.dirname(__file__), "rules.in")
- if not os.path.exists(rules_path):
- return "// Rules not found in rules.in\n"
-
- rules_lines = []
- with open(rules_path, "r") as f:
- for line in f:
- if "// Net testing" in line:
- break
- rules_lines.append(line)
- return "".join(rules_lines)
-
-def export_to_inpla_wiring(model: nn.Module, input_shape: tuple) -> str:
- traced = fx.symbolic_trace(model)
- name_gen = NameGen()
- script: List[str] = []
- wire_map: Dict[str, List[str]] = {}
-
- for node in traced.graph.nodes:
- if node.op == 'placeholder':
- num_inputs = int(np.prod(input_shape))
- wire_map[node.name] = [f"Linear(Symbolic(X_{i}), 1.0, 0.0)" for i in range(num_inputs)]
-
- elif node.op == 'call_module':
- target_str = str(node.target)
- module = dict(model.named_modules())[target_str]
-
- input_node = node.args[0]
- if not isinstance(input_node, fx.Node):
- continue
- input_wires = wire_map[input_node.name]
-
- if isinstance(module, nn.Flatten):
- wire_map[node.name] = input_wires
-
- elif isinstance(module, nn.Linear):
- W = (module.weight.data.detach().cpu().numpy()).astype(float)
- B = (module.bias.data.detach().cpu().numpy()).astype(float)
- out_dim, in_dim = W.shape
-
- neuron_wires = [f"Concrete({B[j]})" for j in range(out_dim)]
-
- for i in range(in_dim):
- in_term = input_wires[i]
- if out_dim == 1:
- weight = float(W[0, i])
- if weight == 0:
- script.append(f"Eraser ~ {in_term};")
- elif weight == 1:
- new_s = name_gen.next()
- script.append(f"Add({new_s}, {in_term}) ~ {neuron_wires[0]};")
- neuron_wires[0] = new_s
- else:
- mul_out = name_gen.next()
- new_s = name_gen.next()
- script.append(f"Mul({mul_out}, Concrete({weight})) ~ {in_term};")
- script.append(f"Add({new_s}, {mul_out}) ~ {neuron_wires[0]};")
- neuron_wires[0] = new_s
- else:
- branch_wires = [name_gen.next() for _ in range(out_dim)]
-
- def nest_dups(names: List[str]) -> str:
- if len(names) == 1: return names[0]
- if len(names) == 2: return f"Dup({names[0]}, {names[1]})"
- return f"Dup({names[0]}, {nest_dups(names[1:])})"
-
- script.append(f"{nest_dups(branch_wires)} ~ {in_term};")
-
- for j in range(out_dim):
- weight = float(W[j, i])
- if weight == 0:
- script.append(f"Eraser ~ {branch_wires[j]};")
- elif weight == 1:
- new_s = name_gen.next()
- script.append(f"Add({new_s}, {branch_wires[j]}) ~ {neuron_wires[j]};")
- neuron_wires[j] = new_s
- else:
- mul_out = name_gen.next()
- new_s = name_gen.next()
- script.append(f"Mul({mul_out}, Concrete({weight})) ~ {branch_wires[j]};")
- script.append(f"Add({new_s}, {mul_out}) ~ {neuron_wires[j]};")
- neuron_wires[j] = new_s
-
- wire_map[node.name] = neuron_wires
-
- elif isinstance(module, nn.ReLU):
- output_wires = []
- for i, w in enumerate(input_wires):
- r_out = name_gen.next()
- script.append(f"ReLU({r_out}) ~ {w};")
- output_wires.append(r_out)
- wire_map[node.name] = output_wires
-
- elif node.op == 'output':
- output_node = node.args[0]
- if isinstance(output_node, fx.Node):
- final_wires = wire_map[output_node.name]
- for i, w in enumerate(final_wires):
- res_name = f"result{i}"
- script.append(f"Materialize({res_name}) ~ {w};")
- script.append(f"{res_name};")
-
- return "\n".join(script)
-
def train_model(name: str):
X = torch.tensor([[0,0], [0,1], [1,0], [1,1]], dtype=torch.float32)
Y = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)
- net = XOR_MLP()
+ net = xor_mlp()
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
@@ -149,22 +33,16 @@ def train_model(name: str):
return net
if __name__ == "__main__":
- # Train two different models
net_a = train_model("Network A")
net_b = train_model("Network B")
- print("\nExporting both to xor.in...")
+ z3_net_a = nneq.net(net_a, (2,))
+ z3_net_b = nneq.net(net_b, (2,))
- rules = get_rules()
- wiring_a = export_to_inpla_wiring(net_a, (2,))
- wiring_b = export_to_inpla_wiring(net_b, (2,))
+ print("")
+ nneq.strict_equivalence(z3_net_a, z3_net_b)
+ print("")
+ nneq.epsilon_equivalence(z3_net_a, z3_net_b, 0.1)
+ print("")
+ nneq.argmax_equivalence(z3_net_a, z3_net_b)
- with open("xor.in", "w") as f:
- f.write(rules)
- f.write("\n\n// Network A\n")
- f.write(wiring_a)
- f.write("\nfree ifce;\n")
- f.write("\n\n// Network B\n")
- f.write(wiring_b)
-
- print("Done. Now run: inpla -f xor.in | python3 prover.py")