aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorericmarin <maarin.eric@gmail.com>2026-03-10 17:42:02 +0100
committerericmarin <maarin.eric@gmail.com>2026-03-10 17:50:08 +0100
commit0882fc5328127f68a7d79c06d0c7decdee770bb9 (patch)
tree805051f8c5f830fee28d82fc5c9daedcff15f91f
parentaf2b13214579d78827392e762149fa8824526aa9 (diff)
downloadvein-0882fc5328127f68a7d79c06d0c7decdee770bb9.tar.gz
vein-0882fc5328127f68a7d79c06d0c7decdee770bb9.zip
two nets
-rw-r--r--prover.py38
-rw-r--r--xor.in72
-rw-r--r--xor.py47
3 files changed, 117 insertions, 40 deletions
diff --git a/prover.py b/prover.py
index 7ac8c5d..d4a398f 100644
--- a/prover.py
+++ b/prover.py
@@ -1,4 +1,5 @@
import z3
+import sys
syms = {}
def Symbolic(id):
@@ -52,15 +53,46 @@ def epsilon_equivalence(net_a, net_b, epsilon):
else:
print("UNKNOWN: Solver could not decide.")
+def argmax_equivalence(net_a, net_b):
+ solver = z3.Solver()
+
+ solver.add(z3.IsInt(net_a > 0.5) != z3.IsInt(net_b > 0.5))
+
+ result = solver.check()
+
+ if result == z3.unsat:
+ print(f"VERIFIED: The networks are argmax equivalent.")
+ elif result == z3.sat:
+ print("FAILED: The networks are different.")
+ print("Counter-example input:")
+ print(solver.model())
+ else:
+ print("UNKNOWN: Solver could not decide.")
+
+
if __name__ == "__main__":
- net_a_str = "TermAdd(TermMul(Symbolic(0), Concrete(2)), Concrete(3))" # 2x + 3
- net_b_str = "TermAdd(Concrete(3), TermMul(Concrete(2), Symbolic(0)))" # 3 + 2x
+ lines = [line.strip() for line in sys.stdin if line.strip() and not line.startswith("(")]
+
+ if len(lines) < 2:
+ print(f"; Error: Expected at least 2 Inpla output strings, but got {len(lines)}.")
+ sys.exit(1)
try:
+ net_a_str = lines[-2]
+ net_b_str = lines[-1]
+
+ print(f"Comparing:\nA: {net_a_str}\nB: {net_b_str}")
+
net_a = eval(net_a_str, context)
net_b = eval(net_b_str, context)
+
+ print("\nStrict Equivalence")
equivalence(net_a, net_b)
+ print("\nEpsilon-Equivalence")
epsilon_equivalence(net_a, net_b, 1e-5)
+ print("\nARGMAX Equivalence")
+ argmax_equivalence(net_a, net_b)
+
except Exception as e:
print(f"; Error parsing Inpla output: {e}")
-
+ sys.exit(1)
diff --git a/xor.in b/xor.in
index b7e3b6d..b3fb968 100644
--- a/xor.in
+++ b/xor.in
@@ -71,36 +71,72 @@ Linear(x, int q, int r) >< Materialize(out)
| _ => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r));
-// Wiring
+// Network A
Dup(v0, Dup(v1, Dup(v2, v3))) ~ Linear(Symbolic(0), 1, 0);
-Mul(v4, Concrete(-693)) ~ v0;
-Add(v5, v4) ~ Concrete(-692);
-Mul(v6, Concrete(-78)) ~ v1;
-Add(v7, v6) ~ Concrete(916);
-Mul(v8, Concrete(235)) ~ v2;
-Add(v9, v8) ~ Concrete(-424);
-Mul(v10, Concrete(181)) ~ v3;
-Add(v11, v10) ~ Concrete(202);
+Mul(v4, Concrete(865)) ~ v0;
+Add(v5, v4) ~ Concrete(0);
+Mul(v6, Concrete(1029)) ~ v1;
+Add(v7, v6) ~ Concrete(0);
+Mul(v8, Concrete(1087)) ~ v2;
+Add(v9, v8) ~ Concrete(1086);
+Mul(v10, Concrete(676)) ~ v3;
+Add(v11, v10) ~ Concrete(-693);
Dup(v12, Dup(v13, Dup(v14, v15))) ~ Linear(Symbolic(1), 1, 0);
-Mul(v16, Concrete(-674)) ~ v12;
+Mul(v16, Concrete(-865)) ~ v12;
Add(v17, v16) ~ v5;
-Mul(v18, Concrete(-97)) ~ v13;
+Mul(v18, Concrete(-1029)) ~ v13;
Add(v19, v18) ~ v7;
-Mul(v20, Concrete(-572)) ~ v14;
+Mul(v20, Concrete(-1087)) ~ v14;
Add(v21, v20) ~ v9;
-Mul(v22, Concrete(224)) ~ v15;
+Mul(v22, Concrete(-378)) ~ v15;
Add(v23, v22) ~ v11;
ReLU(v24) ~ v17;
ReLU(v25) ~ v19;
ReLU(v26) ~ v21;
ReLU(v27) ~ v23;
-Mul(v28, Concrete(-318)) ~ v24;
-Add(v29, v28) ~ Concrete(-89);
-Mul(v30, Concrete(587)) ~ v25;
+Mul(v28, Concrete(1153)) ~ v24;
+Add(v29, v28) ~ Concrete(1000);
+Mul(v30, Concrete(974)) ~ v25;
Add(v31, v30) ~ v29;
-Mul(v32, Concrete(-250)) ~ v26;
+Mul(v32, Concrete(-920)) ~ v26;
Add(v33, v32) ~ v31;
-Mul(v34, Concrete(254)) ~ v27;
+Mul(v34, Concrete(367)) ~ v27;
+Add(v35, v34) ~ v33;
+Materialize(result0) ~ v35;
+result0;
+free ifce;
+
+
+// Network B
+Dup(v0, Dup(v1, Dup(v2, v3))) ~ Linear(Symbolic(0), 1, 0);
+Mul(v4, Concrete(-238)) ~ v0;
+Add(v5, v4) ~ Concrete(-704);
+Mul(v6, Concrete(-111)) ~ v1;
+Add(v7, v6) ~ Concrete(-515);
+Mul(v8, Concrete(-1232)) ~ v2;
+Add(v9, v8) ~ Concrete(-8);
+Mul(v10, Concrete(1113)) ~ v3;
+Add(v11, v10) ~ Concrete(189);
+Dup(v12, Dup(v13, Dup(v14, v15))) ~ Linear(Symbolic(1), 1, 0);
+Mul(v16, Concrete(639)) ~ v12;
+Add(v17, v16) ~ v5;
+Mul(v18, Concrete(66)) ~ v13;
+Add(v19, v18) ~ v7;
+Mul(v20, Concrete(1226)) ~ v14;
+Add(v21, v20) ~ v9;
+Mul(v22, Concrete(-1113)) ~ v15;
+Add(v23, v22) ~ v11;
+ReLU(v24) ~ v17;
+ReLU(v25) ~ v19;
+ReLU(v26) ~ v21;
+ReLU(v27) ~ v23;
+Mul(v28, Concrete(111)) ~ v24;
+Add(v29, v28) ~ Concrete(-170);
+Mul(v30, Concrete(239)) ~ v25;
+Add(v31, v30) ~ v29;
+Mul(v32, Concrete(961)) ~ v26;
+Add(v33, v32) ~ v31;
+Mul(v34, Concrete(897)) ~ v27;
Add(v35, v34) ~ v33;
Materialize(result0) ~ v35;
result0; \ No newline at end of file
diff --git a/xor.py b/xor.py
index b2082c7..f7e247f 100644
--- a/xor.py
+++ b/xor.py
@@ -6,12 +6,12 @@ import os
from typing import List, Dict
class XOR_MLP(nn.Module):
- def __init__(self):
+ def __init__(self, hidden_dim=4):
super().__init__()
self.layers = nn.Sequential(
- nn.Linear(2, 4),
+ nn.Linear(2, hidden_dim),
nn.ReLU(),
- nn.Linear(4, 1)
+ nn.Linear(hidden_dim, 1)
)
def forward(self, x):
return self.layers(x)
@@ -37,7 +37,7 @@ def get_rules() -> str:
rules_lines.append(line)
return "".join(rules_lines)
-def export_to_inpla(model: nn.Module, input_shape: tuple, scale: int = 1000) -> str:
+def export_to_inpla_wiring(model: nn.Module, input_shape: tuple, scale: int = 1000) -> str:
traced = fx.symbolic_trace(model)
name_gen = NameGen()
script: List[str] = []
@@ -111,7 +111,6 @@ def export_to_inpla(model: nn.Module, input_shape: tuple, scale: int = 1000) ->
wire_map[node.name] = neuron_wires
elif isinstance(module, nn.ReLU):
- input_wires = wire_map[node.args[0].name]
output_wires = []
for i, w in enumerate(input_wires):
r_out = name_gen.next()
@@ -128,10 +127,9 @@ def export_to_inpla(model: nn.Module, input_shape: tuple, scale: int = 1000) ->
script.append(f"Materialize({res_name}) ~ {w};")
script.append(f"{res_name};")
- rules = get_rules()
- return rules + "\n\n// Wiring\n" + "\n".join(script)
+ return "\n".join(script)
-if __name__ == "__main__":
+def train_model(name: str):
X = torch.tensor([[0,0], [0,1], [1,0], [1,1]], dtype=torch.float32)
Y = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)
@@ -139,23 +137,34 @@ if __name__ == "__main__":
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
- print("Training XOR MLP...")
+ print(f"Training {name}...")
for epoch in range(1000):
optimizer.zero_grad()
out = net(X)
loss = loss_fn(out, Y)
loss.backward()
optimizer.step()
- if (epoch+1) % 200 == 0:
- print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
+ if (epoch+1) % 500 == 0:
+ print(f" Epoch {epoch+1}, Loss: {loss.item():.4f}")
+ return net
+
+if __name__ == "__main__":
+ # Train two different models
+ net_a = train_model("Network A")
+ net_b = train_model("Network B")
- print("\nTraining Finished. Predictions:")
- with torch.no_grad():
- print(net(X).numpy())
+ print("\nExporting both to xor.in...")
+
+ rules = get_rules()
+ wiring_a = export_to_inpla_wiring(net_a, (2,))
+ wiring_b = export_to_inpla_wiring(net_b, (2,))
- print("\nExporting XOR to Inpla...")
- net.eval()
- inpla_script = export_to_inpla(net, (2,))
with open("xor.in", "w") as f:
- f.write(inpla_script)
- print("Exported to xor.in")
+ f.write(rules)
+ f.write("\n\n// Network A\n")
+ f.write(wiring_a)
+ f.write("\nfree ifce;\n")
+ f.write("\n\n// Network B\n")
+ f.write(wiring_b)
+
+ print("Done. Now run: inpla -f xor.in | python3 prover.py")