aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorericmarin <maarin.eric@gmail.com>2026-03-10 17:57:36 +0100
committerericmarin <maarin.eric@gmail.com>2026-03-10 18:05:46 +0100
commit8619ee7a61bafb8c401087508b886e37779be07b (patch)
tree36d318e99883aab10cc429a57c44c1c96ddbb1c1
parent0882fc5328127f68a7d79c06d0c7decdee770bb9 (diff)
downloadvein-8619ee7a61bafb8c401087508b886e37779be07b.tar.gz
vein-8619ee7a61bafb8c401087508b886e37779be07b.zip
added scale
-rw-r--r--prover.py15
-rw-r--r--xor.in68
-rw-r--r--xor.py2
3 files changed, 44 insertions, 41 deletions
diff --git a/prover.py b/prover.py
index d4a398f..fb9658a 100644
--- a/prover.py
+++ b/prover.py
@@ -1,6 +1,9 @@
import z3
import sys
+# Scale used during export to Inpla (must match the 'scale' parameter in the exporter)
+SCALE = 1000.0
+
syms = {}
def Symbolic(id):
id = f"x_{id}"
@@ -8,7 +11,8 @@ def Symbolic(id):
syms[id] = z3.Real(id)
return syms[id]
-def Concrete(val): return z3.RealVal(val)
+def Concrete(val): return z3.RealVal(val) / SCALE
+
def TermAdd(a, b): return a + b
def TermMul(a, b): return a * b
def TermReLU(x): return z3.If(x > 0, x, 0)
@@ -56,20 +60,19 @@ def epsilon_equivalence(net_a, net_b, epsilon):
def argmax_equivalence(net_a, net_b):
solver = z3.Solver()
- solver.add(z3.IsInt(net_a > 0.5) != z3.IsInt(net_b > 0.5))
+ solver.add((net_a > 0.5) != (net_b > 0.5))
result = solver.check()
if result == z3.unsat:
- print(f"VERIFIED: The networks are argmax equivalent.")
+ print("VERIFIED: The networks are argmax equivalent (binary).")
elif result == z3.sat:
- print("FAILED: The networks are different.")
+ print("FAILED: The networks are classification-different.")
print("Counter-example input:")
print(solver.model())
else:
print("UNKNOWN: Solver could not decide.")
-
if __name__ == "__main__":
lines = [line.strip() for line in sys.stdin if line.strip() and not line.startswith("(")]
@@ -81,7 +84,7 @@ if __name__ == "__main__":
net_a_str = lines[-2]
net_b_str = lines[-1]
- print(f"Comparing:\nA: {net_a_str}\nB: {net_b_str}")
+ print(f"Comparing:\nA: {net_a_str}\n\nB: {net_b_str}")
net_a = eval(net_a_str, context)
net_b = eval(net_b_str, context)
diff --git a/xor.in b/xor.in
index b3fb968..a382883 100644
--- a/xor.in
+++ b/xor.in
@@ -73,34 +73,34 @@ Linear(x, int q, int r) >< Materialize(out)
// Network A
Dup(v0, Dup(v1, Dup(v2, v3))) ~ Linear(Symbolic(0), 1, 0);
-Mul(v4, Concrete(865)) ~ v0;
-Add(v5, v4) ~ Concrete(0);
-Mul(v6, Concrete(1029)) ~ v1;
-Add(v7, v6) ~ Concrete(0);
-Mul(v8, Concrete(1087)) ~ v2;
-Add(v9, v8) ~ Concrete(1086);
-Mul(v10, Concrete(676)) ~ v3;
-Add(v11, v10) ~ Concrete(-693);
+Mul(v4, Concrete(-729)) ~ v0;
+Add(v5, v4) ~ Concrete(732);
+Mul(v6, Concrete(707)) ~ v1;
+Add(v7, v6) ~ Concrete(106);
+Mul(v8, Concrete(-577)) ~ v2;
+Add(v9, v8) ~ Concrete(-502);
+Mul(v10, Concrete(1070)) ~ v3;
+Add(v11, v10) ~ Concrete(-1068);
Dup(v12, Dup(v13, Dup(v14, v15))) ~ Linear(Symbolic(1), 1, 0);
-Mul(v16, Concrete(-865)) ~ v12;
+Mul(v16, Concrete(-725)) ~ v12;
Add(v17, v16) ~ v5;
-Mul(v18, Concrete(-1029)) ~ v13;
+Mul(v18, Concrete(708)) ~ v13;
Add(v19, v18) ~ v7;
-Mul(v20, Concrete(-1087)) ~ v14;
+Mul(v20, Concrete(220)) ~ v14;
Add(v21, v20) ~ v9;
-Mul(v22, Concrete(-378)) ~ v15;
+Mul(v22, Concrete(1066)) ~ v15;
Add(v23, v22) ~ v11;
ReLU(v24) ~ v17;
ReLU(v25) ~ v19;
ReLU(v26) ~ v21;
ReLU(v27) ~ v23;
-Mul(v28, Concrete(1153)) ~ v24;
-Add(v29, v28) ~ Concrete(1000);
-Mul(v30, Concrete(974)) ~ v25;
+Mul(v28, Concrete(-642)) ~ v24;
+Add(v29, v28) ~ Concrete(390);
+Mul(v30, Concrete(753)) ~ v25;
Add(v31, v30) ~ v29;
-Mul(v32, Concrete(-920)) ~ v26;
+Mul(v32, Concrete(235)) ~ v26;
Add(v33, v32) ~ v31;
-Mul(v34, Concrete(367)) ~ v27;
+Mul(v34, Concrete(-1440)) ~ v27;
Add(v35, v34) ~ v33;
Materialize(result0) ~ v35;
result0;
@@ -109,34 +109,34 @@ free ifce;
// Network B
Dup(v0, Dup(v1, Dup(v2, v3))) ~ Linear(Symbolic(0), 1, 0);
-Mul(v4, Concrete(-238)) ~ v0;
-Add(v5, v4) ~ Concrete(-704);
-Mul(v6, Concrete(-111)) ~ v1;
-Add(v7, v6) ~ Concrete(-515);
-Mul(v8, Concrete(-1232)) ~ v2;
-Add(v9, v8) ~ Concrete(-8);
-Mul(v10, Concrete(1113)) ~ v3;
-Add(v11, v10) ~ Concrete(189);
+Mul(v4, Concrete(-181)) ~ v0;
+Add(v5, v4) ~ Concrete(-142);
+Mul(v6, Concrete(-1061)) ~ v1;
+Add(v7, v6) ~ Concrete(1050);
+Mul(v8, Concrete(1181)) ~ v2;
+Add(v9, v8) ~ Concrete(-568);
+Mul(v10, Concrete(-627)) ~ v3;
+Add(v11, v10) ~ Concrete(1236);
Dup(v12, Dup(v13, Dup(v14, v15))) ~ Linear(Symbolic(1), 1, 0);
-Mul(v16, Concrete(639)) ~ v12;
+Mul(v16, Concrete(-609)) ~ v12;
Add(v17, v16) ~ v5;
-Mul(v18, Concrete(66)) ~ v13;
+Mul(v18, Concrete(-1058)) ~ v13;
Add(v19, v18) ~ v7;
-Mul(v20, Concrete(1226)) ~ v14;
+Mul(v20, Concrete(1404)) ~ v14;
Add(v21, v20) ~ v9;
-Mul(v22, Concrete(-1113)) ~ v15;
+Mul(v22, Concrete(-311)) ~ v15;
Add(v23, v22) ~ v11;
ReLU(v24) ~ v17;
ReLU(v25) ~ v19;
ReLU(v26) ~ v21;
ReLU(v27) ~ v23;
-Mul(v28, Concrete(111)) ~ v24;
-Add(v29, v28) ~ Concrete(-170);
-Mul(v30, Concrete(239)) ~ v25;
+Mul(v28, Concrete(-313)) ~ v24;
+Add(v29, v28) ~ Concrete(1112);
+Mul(v30, Concrete(-1571)) ~ v25;
Add(v31, v30) ~ v29;
-Mul(v32, Concrete(961)) ~ v26;
+Mul(v32, Concrete(-615)) ~ v26;
Add(v33, v32) ~ v31;
-Mul(v34, Concrete(897)) ~ v27;
+Mul(v34, Concrete(434)) ~ v27;
Add(v35, v34) ~ v33;
Materialize(result0) ~ v35;
result0; \ No newline at end of file
diff --git a/xor.py b/xor.py
index f7e247f..905ddfb 100644
--- a/xor.py
+++ b/xor.py
@@ -144,7 +144,7 @@ def train_model(name: str):
loss = loss_fn(out, Y)
loss.backward()
optimizer.step()
- if (epoch+1) % 500 == 0:
+ if (epoch+1) % 100 == 0:
print(f" Epoch {epoch+1}, Loss: {loss.item():.4f}")
return net