diff options
Diffstat (limited to 'xor.py')
| -rw-r--r-- | xor.py | 144 |
1 file changed, 11 insertions, 133 deletions
@@ -1,11 +1,8 @@ import torch import torch.nn as nn -import torch.fx as fx -import numpy as np -import os -from typing import List, Dict +import nneq -class XOR_MLP(nn.Module): +class xor_mlp(nn.Module): def __init__(self, hidden_dim=4): super().__init__() self.layers = nn.Sequential( @@ -16,124 +13,11 @@ class XOR_MLP(nn.Module): def forward(self, x): return self.layers(x) -class NameGen: - def __init__(self): - self.counter = 0 - def next(self) -> str: - name = f"v{self.counter}" - self.counter += 1 - return name - -def get_rules() -> str: - rules_path = os.path.join(os.path.dirname(__file__), "rules.in") - if not os.path.exists(rules_path): - return "// Rules not found in rules.in\n" - - rules_lines = [] - with open(rules_path, "r") as f: - for line in f: - if "// Net testing" in line: - break - rules_lines.append(line) - return "".join(rules_lines) - -def export_to_inpla_wiring(model: nn.Module, input_shape: tuple) -> str: - traced = fx.symbolic_trace(model) - name_gen = NameGen() - script: List[str] = [] - wire_map: Dict[str, List[str]] = {} - - for node in traced.graph.nodes: - if node.op == 'placeholder': - num_inputs = int(np.prod(input_shape)) - wire_map[node.name] = [f"Linear(Symbolic(X_{i}), 1.0, 0.0)" for i in range(num_inputs)] - - elif node.op == 'call_module': - target_str = str(node.target) - module = dict(model.named_modules())[target_str] - - input_node = node.args[0] - if not isinstance(input_node, fx.Node): - continue - input_wires = wire_map[input_node.name] - - if isinstance(module, nn.Flatten): - wire_map[node.name] = input_wires - - elif isinstance(module, nn.Linear): - W = (module.weight.data.detach().cpu().numpy()).astype(float) - B = (module.bias.data.detach().cpu().numpy()).astype(float) - out_dim, in_dim = W.shape - - neuron_wires = [f"Concrete({B[j]})" for j in range(out_dim)] - - for i in range(in_dim): - in_term = input_wires[i] - if out_dim == 1: - weight = float(W[0, i]) - if weight == 0: - script.append(f"Eraser ~ {in_term};") - 
elif weight == 1: - new_s = name_gen.next() - script.append(f"Add({new_s}, {in_term}) ~ {neuron_wires[0]};") - neuron_wires[0] = new_s - else: - mul_out = name_gen.next() - new_s = name_gen.next() - script.append(f"Mul({mul_out}, Concrete({weight})) ~ {in_term};") - script.append(f"Add({new_s}, {mul_out}) ~ {neuron_wires[0]};") - neuron_wires[0] = new_s - else: - branch_wires = [name_gen.next() for _ in range(out_dim)] - - def nest_dups(names: List[str]) -> str: - if len(names) == 1: return names[0] - if len(names) == 2: return f"Dup({names[0]}, {names[1]})" - return f"Dup({names[0]}, {nest_dups(names[1:])})" - - script.append(f"{nest_dups(branch_wires)} ~ {in_term};") - - for j in range(out_dim): - weight = float(W[j, i]) - if weight == 0: - script.append(f"Eraser ~ {branch_wires[j]};") - elif weight == 1: - new_s = name_gen.next() - script.append(f"Add({new_s}, {branch_wires[j]}) ~ {neuron_wires[j]};") - neuron_wires[j] = new_s - else: - mul_out = name_gen.next() - new_s = name_gen.next() - script.append(f"Mul({mul_out}, Concrete({weight})) ~ {branch_wires[j]};") - script.append(f"Add({new_s}, {mul_out}) ~ {neuron_wires[j]};") - neuron_wires[j] = new_s - - wire_map[node.name] = neuron_wires - - elif isinstance(module, nn.ReLU): - output_wires = [] - for i, w in enumerate(input_wires): - r_out = name_gen.next() - script.append(f"ReLU({r_out}) ~ {w};") - output_wires.append(r_out) - wire_map[node.name] = output_wires - - elif node.op == 'output': - output_node = node.args[0] - if isinstance(output_node, fx.Node): - final_wires = wire_map[output_node.name] - for i, w in enumerate(final_wires): - res_name = f"result{i}" - script.append(f"Materialize({res_name}) ~ {w};") - script.append(f"{res_name};") - - return "\n".join(script) - def train_model(name: str): X = torch.tensor([[0,0], [0,1], [1,0], [1,1]], dtype=torch.float32) Y = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32) - net = XOR_MLP() + net = xor_mlp() loss_fn = nn.MSELoss() optimizer = 
torch.optim.Adam(net.parameters(), lr=0.01) @@ -149,22 +33,16 @@ def train_model(name: str): return net if __name__ == "__main__": - # Train two different models net_a = train_model("Network A") net_b = train_model("Network B") - print("\nExporting both to xor.in...") + z3_net_a = nneq.net(net_a, (2,)) + z3_net_b = nneq.net(net_b, (2,)) - rules = get_rules() - wiring_a = export_to_inpla_wiring(net_a, (2,)) - wiring_b = export_to_inpla_wiring(net_b, (2,)) + print("") + nneq.strict_equivalence(z3_net_a, z3_net_b) + print("") + nneq.epsilon_equivalence(z3_net_a, z3_net_b, 0.1) + print("") + nneq.argmax_equivalence(z3_net_a, z3_net_b) - with open("xor.in", "w") as f: - f.write(rules) - f.write("\n\n// Network A\n") - f.write(wiring_a) - f.write("\nfree ifce;\n") - f.write("\n\n// Network B\n") - f.write(wiring_b) - - print("Done. Now run: inpla -f xor.in | python3 prover.py") |
