diff options
| author | ericmarin <maarin.eric@gmail.com> | 2026-03-10 17:57:36 +0100 |
|---|---|---|
| committer | ericmarin <maarin.eric@gmail.com> | 2026-03-10 18:05:46 +0100 |
| commit | 8619ee7a61bafb8c401087508b886e37779be07b (patch) | |
| tree | 36d318e99883aab10cc429a57c44c1c96ddbb1c1 | |
| parent | 0882fc5328127f68a7d79c06d0c7decdee770bb9 (diff) | |
| download | vein-8619ee7a61bafb8c401087508b886e37779be07b.tar.gz vein-8619ee7a61bafb8c401087508b886e37779be07b.zip | |
added scale
| -rw-r--r-- | prover.py | 15 | ||||
| -rw-r--r-- | xor.in | 68 | ||||
| -rw-r--r-- | xor.py | 2 |
3 files changed, 44 insertions, 41 deletions
@@ -1,6 +1,9 @@ import z3 import sys +# Scale used during export to Inpla (must match the 'scale' parameter in the exporter) +SCALE = 1000.0 + syms = {} def Symbolic(id): id = f"x_{id}" @@ -8,7 +11,8 @@ def Symbolic(id): syms[id] = z3.Real(id) return syms[id] -def Concrete(val): return z3.RealVal(val) +def Concrete(val): return z3.RealVal(val) / SCALE + def TermAdd(a, b): return a + b def TermMul(a, b): return a * b def TermReLU(x): return z3.If(x > 0, x, 0) @@ -56,20 +60,19 @@ def epsilon_equivalence(net_a, net_b, epsilon): def argmax_equivalence(net_a, net_b): solver = z3.Solver() - solver.add(z3.IsInt(net_a > 0.5) != z3.IsInt(net_b > 0.5)) + solver.add((net_a > 0.5) != (net_b > 0.5)) result = solver.check() if result == z3.unsat: - print(f"VERIFIED: The networks are argmax equivalent.") + print("VERIFIED: The networks are argmax equivalent (binary).") elif result == z3.sat: - print("FAILED: The networks are different.") + print("FAILED: The networks are classification-different.") print("Counter-example input:") print(solver.model()) else: print("UNKNOWN: Solver could not decide.") - if __name__ == "__main__": lines = [line.strip() for line in sys.stdin if line.strip() and not line.startswith("(")] @@ -81,7 +84,7 @@ if __name__ == "__main__": net_a_str = lines[-2] net_b_str = lines[-1] - print(f"Comparing:\nA: {net_a_str}\nB: {net_b_str}") + print(f"Comparing:\nA: {net_a_str}\n\nB: {net_b_str}") net_a = eval(net_a_str, context) net_b = eval(net_b_str, context) @@ -73,34 +73,34 @@ Linear(x, int q, int r) >< Materialize(out) // Network A Dup(v0, Dup(v1, Dup(v2, v3))) ~ Linear(Symbolic(0), 1, 0); -Mul(v4, Concrete(865)) ~ v0; -Add(v5, v4) ~ Concrete(0); -Mul(v6, Concrete(1029)) ~ v1; -Add(v7, v6) ~ Concrete(0); -Mul(v8, Concrete(1087)) ~ v2; -Add(v9, v8) ~ Concrete(1086); -Mul(v10, Concrete(676)) ~ v3; -Add(v11, v10) ~ Concrete(-693); +Mul(v4, Concrete(-729)) ~ v0; +Add(v5, v4) ~ Concrete(732); +Mul(v6, Concrete(707)) ~ v1; +Add(v7, v6) ~ Concrete(106); +Mul(v8, Concrete(-577)) ~ v2; +Add(v9, v8) ~ Concrete(-502); +Mul(v10, Concrete(1070)) ~ v3; +Add(v11, v10) ~ Concrete(-1068); Dup(v12, Dup(v13, Dup(v14, v15))) ~ Linear(Symbolic(1), 1, 0); -Mul(v16, Concrete(-865)) ~ v12; +Mul(v16, Concrete(-725)) ~ v12; Add(v17, v16) ~ v5; -Mul(v18, Concrete(-1029)) ~ v13; +Mul(v18, Concrete(708)) ~ v13; Add(v19, v18) ~ v7; -Mul(v20, Concrete(-1087)) ~ v14; +Mul(v20, Concrete(220)) ~ v14; Add(v21, v20) ~ v9; -Mul(v22, Concrete(-378)) ~ v15; +Mul(v22, Concrete(1066)) ~ v15; Add(v23, v22) ~ v11; ReLU(v24) ~ v17; ReLU(v25) ~ v19; ReLU(v26) ~ v21; ReLU(v27) ~ v23; -Mul(v28, Concrete(1153)) ~ v24; -Add(v29, v28) ~ Concrete(1000); -Mul(v30, Concrete(974)) ~ v25; +Mul(v28, Concrete(-642)) ~ v24; +Add(v29, v28) ~ Concrete(390); +Mul(v30, Concrete(753)) ~ v25; Add(v31, v30) ~ v29; -Mul(v32, Concrete(-920)) ~ v26; +Mul(v32, Concrete(235)) ~ v26; Add(v33, v32) ~ v31; -Mul(v34, Concrete(367)) ~ v27; +Mul(v34, Concrete(-1440)) ~ v27; Add(v35, v34) ~ v33; Materialize(result0) ~ v35; result0; @@ -109,34 +109,34 @@ free ifce; // Network B Dup(v0, Dup(v1, Dup(v2, v3))) ~ Linear(Symbolic(0), 1, 0); -Mul(v4, Concrete(-238)) ~ v0; -Add(v5, v4) ~ Concrete(-704); -Mul(v6, Concrete(-111)) ~ v1; -Add(v7, v6) ~ Concrete(-515); -Mul(v8, Concrete(-1232)) ~ v2; -Add(v9, v8) ~ Concrete(-8); -Mul(v10, Concrete(1113)) ~ v3; -Add(v11, v10) ~ Concrete(189); +Mul(v4, Concrete(-181)) ~ v0; +Add(v5, v4) ~ Concrete(-142); +Mul(v6, Concrete(-1061)) ~ v1; +Add(v7, v6) ~ Concrete(1050); +Mul(v8, Concrete(1181)) ~ v2; +Add(v9, v8) ~ Concrete(-568); +Mul(v10, Concrete(-627)) ~ v3; +Add(v11, v10) ~ Concrete(1236); Dup(v12, Dup(v13, Dup(v14, v15))) ~ Linear(Symbolic(1), 1, 0); -Mul(v16, Concrete(639)) ~ v12; +Mul(v16, Concrete(-609)) ~ v12; Add(v17, v16) ~ v5; -Mul(v18, Concrete(66)) ~ v13; +Mul(v18, Concrete(-1058)) ~ v13; Add(v19, v18) ~ v7; -Mul(v20, Concrete(1226)) ~ v14; +Mul(v20, Concrete(1404)) ~ v14; Add(v21, v20) ~ v9; -Mul(v22, Concrete(-1113)) ~ v15; +Mul(v22, Concrete(-311)) ~ v15; Add(v23, v22) ~ v11; ReLU(v24) ~ v17; ReLU(v25) ~ v19; ReLU(v26) ~ v21; ReLU(v27) ~ v23; -Mul(v28, Concrete(111)) ~ v24; -Add(v29, v28) ~ Concrete(-170); -Mul(v30, Concrete(239)) ~ v25; +Mul(v28, Concrete(-313)) ~ v24; +Add(v29, v28) ~ Concrete(1112); +Mul(v30, Concrete(-1571)) ~ v25; Add(v31, v30) ~ v29; -Mul(v32, Concrete(961)) ~ v26; +Mul(v32, Concrete(-615)) ~ v26; Add(v33, v32) ~ v31; -Mul(v34, Concrete(897)) ~ v27; +Mul(v34, Concrete(434)) ~ v27; Add(v35, v34) ~ v33; Materialize(result0) ~ v35; result0;
\ No newline at end of file @@ -144,7 +144,7 @@ def train_model(name: str): loss = loss_fn(out, Y) loss.backward() optimizer.step() - if (epoch+1) % 500 == 0: + if (epoch+1) % 100 == 0: print(f" Epoch {epoch+1}, Loss: {loss.item():.4f}") return net |
