diff options
| author | ericmarin <maarin.eric@gmail.com> | 2026-03-10 17:42:02 +0100 |
|---|---|---|
| committer | ericmarin <maarin.eric@gmail.com> | 2026-03-10 17:50:08 +0100 |
| commit | 0882fc5328127f68a7d79c06d0c7decdee770bb9 (patch) | |
| tree | 805051f8c5f830fee28d82fc5c9daedcff15f91f | |
| parent | af2b13214579d78827392e762149fa8824526aa9 (diff) | |
| download | vein-0882fc5328127f68a7d79c06d0c7decdee770bb9.tar.gz vein-0882fc5328127f68a7d79c06d0c7decdee770bb9.zip | |
two nets
| -rw-r--r-- | prover.py | 38 | ||||
| -rw-r--r-- | xor.in | 72 | ||||
| -rw-r--r-- | xor.py | 47 |
3 files changed, 117 insertions, 40 deletions
@@ -1,4 +1,5 @@ import z3 +import sys syms = {} def Symbolic(id): @@ -52,15 +53,46 @@ def epsilon_equivalence(net_a, net_b, epsilon): else: print("UNKNOWN: Solver could not decide.") +def argmax_equivalence(net_a, net_b): + solver = z3.Solver() + + solver.add(z3.IsInt(net_a > 0.5) != z3.IsInt(net_b > 0.5)) + + result = solver.check() + + if result == z3.unsat: + print(f"VERIFIED: The networks are argmax equivalent.") + elif result == z3.sat: + print("FAILED: The networks are different.") + print("Counter-example input:") + print(solver.model()) + else: + print("UNKNOWN: Solver could not decide.") + + if __name__ == "__main__": - net_a_str = "TermAdd(TermMul(Symbolic(0), Concrete(2)), Concrete(3))" # 2x + 3 - net_b_str = "TermAdd(Concrete(3), TermMul(Concrete(2), Symbolic(0)))" # 3 + 2x + lines = [line.strip() for line in sys.stdin if line.strip() and not line.startswith("(")] + + if len(lines) < 2: + print(f"; Error: Expected at least 2 Inpla output strings, but got {len(lines)}.") + sys.exit(1) try: + net_a_str = lines[-2] + net_b_str = lines[-1] + + print(f"Comparing:\nA: {net_a_str}\nB: {net_b_str}") + net_a = eval(net_a_str, context) net_b = eval(net_b_str, context) + + print("\nStrict Equivalence") equivalence(net_a, net_b) + print("\nEpsilon-Equivalence") epsilon_equivalence(net_a, net_b, 1e-5) + print("\nARGMAX Equivalence") + argmax_equivalence(net_a, net_b) + except Exception as e: print(f"; Error parsing Inpla output: {e}") - + sys.exit(1) @@ -71,36 +71,72 @@ Linear(x, int q, int r) >< Materialize(out) | _ => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r)); -// Wiring +// Network A Dup(v0, Dup(v1, Dup(v2, v3))) ~ Linear(Symbolic(0), 1, 0); -Mul(v4, Concrete(-693)) ~ v0; -Add(v5, v4) ~ Concrete(-692); -Mul(v6, Concrete(-78)) ~ v1; -Add(v7, v6) ~ Concrete(916); -Mul(v8, Concrete(235)) ~ v2; -Add(v9, v8) ~ Concrete(-424); -Mul(v10, Concrete(181)) ~ v3; -Add(v11, v10) ~ Concrete(202); +Mul(v4, Concrete(865)) ~ v0; +Add(v5, v4) ~ Concrete(0); +Mul(v6, Concrete(1029)) ~ v1; +Add(v7, v6) ~ Concrete(0); +Mul(v8, Concrete(1087)) ~ v2; +Add(v9, v8) ~ Concrete(1086); +Mul(v10, Concrete(676)) ~ v3; +Add(v11, v10) ~ Concrete(-693); Dup(v12, Dup(v13, Dup(v14, v15))) ~ Linear(Symbolic(1), 1, 0); -Mul(v16, Concrete(-674)) ~ v12; +Mul(v16, Concrete(-865)) ~ v12; Add(v17, v16) ~ v5; -Mul(v18, Concrete(-97)) ~ v13; +Mul(v18, Concrete(-1029)) ~ v13; Add(v19, v18) ~ v7; -Mul(v20, Concrete(-572)) ~ v14; +Mul(v20, Concrete(-1087)) ~ v14; Add(v21, v20) ~ v9; -Mul(v22, Concrete(224)) ~ v15; +Mul(v22, Concrete(-378)) ~ v15; Add(v23, v22) ~ v11; ReLU(v24) ~ v17; ReLU(v25) ~ v19; ReLU(v26) ~ v21; ReLU(v27) ~ v23; -Mul(v28, Concrete(-318)) ~ v24; -Add(v29, v28) ~ Concrete(-89); -Mul(v30, Concrete(587)) ~ v25; +Mul(v28, Concrete(1153)) ~ v24; +Add(v29, v28) ~ Concrete(1000); +Mul(v30, Concrete(974)) ~ v25; Add(v31, v30) ~ v29; -Mul(v32, Concrete(-250)) ~ v26; +Mul(v32, Concrete(-920)) ~ v26; Add(v33, v32) ~ v31; -Mul(v34, Concrete(254)) ~ v27; +Mul(v34, Concrete(367)) ~ v27; +Add(v35, v34) ~ v33; +Materialize(result0) ~ v35; +result0; +free ifce; + + +// Network B +Dup(v0, Dup(v1, Dup(v2, v3))) ~ Linear(Symbolic(0), 1, 0); +Mul(v4, Concrete(-238)) ~ v0; +Add(v5, v4) ~ Concrete(-704); +Mul(v6, Concrete(-111)) ~ v1; +Add(v7, v6) ~ Concrete(-515); +Mul(v8, Concrete(-1232)) ~ v2; +Add(v9, v8) ~ Concrete(-8); +Mul(v10, Concrete(1113)) ~ v3; +Add(v11, v10) ~ Concrete(189); +Dup(v12, Dup(v13, Dup(v14, v15))) ~ Linear(Symbolic(1), 1, 0); +Mul(v16, Concrete(639)) ~ v12; +Add(v17, v16) ~ v5; +Mul(v18, Concrete(66)) ~ v13; +Add(v19, v18) ~ v7; +Mul(v20, Concrete(1226)) ~ v14; +Add(v21, v20) ~ v9; +Mul(v22, Concrete(-1113)) ~ v15; +Add(v23, v22) ~ v11; +ReLU(v24) ~ v17; +ReLU(v25) ~ v19; +ReLU(v26) ~ v21; +ReLU(v27) ~ v23; +Mul(v28, Concrete(111)) ~ v24; +Add(v29, v28) ~ Concrete(-170); +Mul(v30, Concrete(239)) ~ v25; +Add(v31, v30) ~ v29; +Mul(v32, Concrete(961)) ~ v26; +Add(v33, v32) ~ v31; +Mul(v34, Concrete(897)) ~ v27; Add(v35, v34) ~ v33; Materialize(result0) ~ v35; result0;
\ No newline at end of file @@ -6,12 +6,12 @@ import os from typing import List, Dict class XOR_MLP(nn.Module): - def __init__(self): + def __init__(self, hidden_dim=4): super().__init__() self.layers = nn.Sequential( - nn.Linear(2, 4), + nn.Linear(2, hidden_dim), nn.ReLU(), - nn.Linear(4, 1) + nn.Linear(hidden_dim, 1) ) def forward(self, x): return self.layers(x) @@ -37,7 +37,7 @@ def get_rules() -> str: rules_lines.append(line) return "".join(rules_lines) -def export_to_inpla(model: nn.Module, input_shape: tuple, scale: int = 1000) -> str: +def export_to_inpla_wiring(model: nn.Module, input_shape: tuple, scale: int = 1000) -> str: traced = fx.symbolic_trace(model) name_gen = NameGen() script: List[str] = [] @@ -111,7 +111,6 @@ def export_to_inpla(model: nn.Module, input_shape: tuple, scale: int = 1000) -> wire_map[node.name] = neuron_wires elif isinstance(module, nn.ReLU): - input_wires = wire_map[node.args[0].name] output_wires = [] for i, w in enumerate(input_wires): r_out = name_gen.next() @@ -128,10 +127,9 @@ def export_to_inpla(model: nn.Module, input_shape: tuple, scale: int = 1000) -> script.append(f"Materialize({res_name}) ~ {w};") script.append(f"{res_name};") - rules = get_rules() - return rules + "\n\n// Wiring\n" + "\n".join(script) + return "\n".join(script) -if __name__ == "__main__": +def train_model(name: str): X = torch.tensor([[0,0], [0,1], [1,0], [1,1]], dtype=torch.float32) Y = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32) @@ -139,23 +137,34 @@ if __name__ == "__main__": loss_fn = nn.MSELoss() optimizer = torch.optim.Adam(net.parameters(), lr=0.01) - print("Training XOR MLP...") + print(f"Training {name}...") for epoch in range(1000): optimizer.zero_grad() out = net(X) loss = loss_fn(out, Y) loss.backward() optimizer.step() - if (epoch+1) % 200 == 0: - print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}") + if (epoch+1) % 500 == 0: + print(f" Epoch {epoch+1}, Loss: {loss.item():.4f}") + return net + +if __name__ == "__main__": + # Train two different models + net_a = train_model("Network A") + net_b = train_model("Network B") - print("\nTraining Finished. Predictions:") - with torch.no_grad(): - print(net(X).numpy()) + print("\nExporting both to xor.in...") + + rules = get_rules() + wiring_a = export_to_inpla_wiring(net_a, (2,)) + wiring_b = export_to_inpla_wiring(net_b, (2,)) - print("\nExporting XOR to Inpla...") - net.eval() - inpla_script = export_to_inpla(net, (2,)) with open("xor.in", "w") as f: - f.write(inpla_script) - print("Exported to xor.in") + f.write(rules) + f.write("\n\n// Network A\n") + f.write(wiring_a) + f.write("\nfree ifce;\n") + f.write("\n\n// Network B\n") + f.write(wiring_b) + + print("Done. Now run: inpla -f xor.in | python3 prover.py") |
