aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorericmarin <maarin.eric@gmail.com>2026-03-12 15:37:53 +0100
committerericmarin <maarin.eric@gmail.com>2026-03-12 15:56:24 +0100
commit19652ec48be4c6faf3f7815a9281b611aed94727 (patch)
tree94ab6abc2c50835d2d9a51242025f8ad13b7f7c5
parentfb544e2089e0c52bd83ffe56f2f4e8d7176564ee (diff)
downloadvein-19652ec48be4c6faf3f7815a9281b611aed94727.tar.gz
vein-19652ec48be4c6faf3f7815a9281b611aed94727.zip
changed to float
-rw-r--r--prover.py14
-rw-r--r--rules.in34
-rw-r--r--xor.in110
-rw-r--r--xor.py12
4 files changed, 84 insertions, 86 deletions
diff --git a/prover.py b/prover.py
index d6d1fcc..614cdb6 100644
--- a/prover.py
+++ b/prover.py
@@ -1,17 +1,14 @@
import z3
+import re
import sys
-# Scale used during export to Inpla (must match the 'scale' parameter in the exporter)
-SCALE = 1000.0
-
syms = {}
def Symbolic(id):
- id = f"x_{id}"
if id not in syms:
syms[id] = z3.Real(id)
return syms[id]
-def Concrete(val): return z3.RealVal(val) / SCALE
+def Concrete(val): return z3.RealVal(val)
def TermAdd(a, b): return a + b
def TermMul(a, b): return a * b
@@ -90,8 +87,9 @@ if __name__ == "__main__":
sys.exit(1)
try:
- net_a_str = lines[-2]
- net_b_str = lines[-1]
+ wrap = re.compile(r"Symbolic\((.*?)\)")
+ net_a_str = wrap.sub(r'Symbolic("\1")', lines[-2]);
+ net_b_str = wrap.sub(r'Symbolic("\1")', lines[-1]);
print(f"Comparing:\nA: {net_a_str}\n\nB: {net_b_str}")
@@ -101,7 +99,7 @@ if __name__ == "__main__":
print("\nStrict Equivalence")
equivalence(net_a, net_b)
print("\nEpsilon-Equivalence")
- epsilon_equivalence(net_a, net_b, 1e-2)
+ epsilon_equivalence(net_a, net_b, 1e-1)
print("\nARGMAX Equivalence")
argmax_equivalence(net_a, net_b)
diff --git a/rules.in b/rules.in
index 2ff92dd..22698bd 100644
--- a/rules.in
+++ b/rules.in
@@ -4,8 +4,8 @@
// Dup: duplicates other agents recursively
// Implemented
-// Linear(x, int q, int r): represent "q*x + r"
-// Concrete(int k): represent a concrete value k
+// Linear(x, float q, float r): represent "q*x + r"
+// Concrete(float k): represent a concrete value k
// Symbolic(id): represent the variable id
// Add(out, b): represent the addition (has various steps AddCheckLinear/AddCheckConcrete)
// Mul(out, b): represent the multiplication (has various steps MulCheckLinear/MulCheckConcrete)
@@ -15,55 +15,55 @@
// TODO: add range information to enable ReLU elimination
// Rules
-Linear(x, int q, int r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r);
+Linear(x, float q, float r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r);
-Concrete(int k) >< Add(out, b)
+Concrete(float k) >< Add(out, b)
| k == 0 => out ~ b
| _ => b ~ AddCheckConcrete(out, k);
-Linear(y, int s, int t) >< AddCheckLinear(out, x, int q, int r)
+Linear(y, float s, float t) >< AddCheckLinear(out, x, float q, float r)
| (q == 0) && (r == 0) && (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
| (s == 0) && (t == 0) => Linear(x, q, r) ~ Materialize(out), y ~ Eraser
| (q == 0) && (r == 0) => (*L)Linear(y, s, t) ~ Materialize(out), x ~ Eraser
| _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermAdd(out_x, out_y), 1, 0);
-Concrete(int j) >< AddCheckLinear(out, x, int q, int r) => out ~ Linear(x, q, r + j);
+Concrete(float j) >< AddCheckLinear(out, x, float q, float r) => out ~ Linear(x, q, r + j);
-Linear(y, int s, int t) >< AddCheckConcrete(out, int k) => out ~ Linear(y, s, t + k);
+Linear(y, float s, float t) >< AddCheckConcrete(out, float k) => out ~ Linear(y, s, t + k);
-Concrete(int j) >< AddCheckConcrete(out, int k)
+Concrete(float j) >< AddCheckConcrete(out, float k)
| j == 0 => out ~ Concrete(k)
| _ => out ~ Concrete(k + j);
-Linear(x, int q, int r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r);
+Linear(x, float q, float r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r);
-Concrete(int k) >< Mul(out, b)
+Concrete(float k) >< Mul(out, b)
| k == 0 => b ~ Eraser, out ~ (*L)Concrete(0)
| k == 1 => out ~ b
| _ => b ~ MulCheckConcrete(out, k);
-Linear(y, int s, int t) >< MulCheckLinear(out, x, int q, int r)
+Linear(y, float s, float t) >< MulCheckLinear(out, x, float q, float r)
| (q == 0) && (r == 0) && (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
| (s == 0) && (t == 0) => Linear(x, q, r) ~ Materialize(out), y ~ Eraser
| (q == 0) && (r == 0) => (*L)Linear(y, s, t) ~ Materialize(out), x ~ Eraser
| _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermMul(out_x, out_y), 1, 0);
-Concrete(int j) >< MulCheckLinear(out, x, int q, int r) => out ~ Linear(x, q * j, r * j);
+Concrete(float j) >< MulCheckLinear(out, x, float q, float r) => out ~ Linear(x, q * j, r * j);
-Linear(y, int s, int t) >< MulCheckConcrete(out, int k) => out ~ Linear(y, s * k, t * k);
+Linear(y, float s, float t) >< MulCheckConcrete(out, float k) => out ~ Linear(y, s * k, t * k);
-Concrete(int j) >< MulCheckConcrete(out, int k)
+Concrete(float j) >< MulCheckConcrete(out, float k)
| j == 0 => out ~ Concrete(0)
| j == 1 => out ~ Concrete(k)
| _ => out ~ Concrete(k * j);
-Linear(x, int q, int r) >< ReLU(out) => (*L)Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0);
+Linear(x, float q, float r) >< ReLU(out) => (*L)Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0);
-Concrete(int k) >< ReLU(out)
+Concrete(float k) >< ReLU(out)
| k > 0 => out ~ (*L)Concrete(k)
| _ => out ~ Concrete(0);
-Linear(x, int q, int r) >< Materialize(out)
+Linear(x, float q, float r) >< Materialize(out)
| (q == 0) => out ~ Concrete(r), x ~ Eraser
| (q == 1) && (r == 0) => out ~ x
| (q == 1) && (r != 0) => out ~ TermAdd(x, Concrete(r))
diff --git a/xor.in b/xor.in
index a382883..6a71e7b 100644
--- a/xor.in
+++ b/xor.in
@@ -4,8 +4,8 @@
// Dup: duplicates other agents recursively
// Implemented
-// Linear(x, int q, int r): represent "q*x + r"
-// Concrete(int k): represent a concrete value k
+// Linear(x, float q, float r): represent "q*x + r"
+// Concrete(float k): represent a concrete value k
// Symbolic(id): represent the variable id
// Add(out, b): represent the addition (has various steps AddCheckLinear/AddCheckConcrete)
// Mul(out, b): represent the multiplication (has various steps MulCheckLinear/MulCheckConcrete)
@@ -15,55 +15,55 @@
// TODO: add range information to enable ReLU elimination
// Rules
-Linear(x, int q, int r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r);
+Linear(x, float q, float r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r);
-Concrete(int k) >< Add(out, b)
+Concrete(float k) >< Add(out, b)
| k == 0 => out ~ b
| _ => b ~ AddCheckConcrete(out, k);
-Linear(y, int s, int t) >< AddCheckLinear(out, x, int q, int r)
+Linear(y, float s, float t) >< AddCheckLinear(out, x, float q, float r)
| (q == 0) && (r == 0) && (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
| (s == 0) && (t == 0) => Linear(x, q, r) ~ Materialize(out), y ~ Eraser
| (q == 0) && (r == 0) => (*L)Linear(y, s, t) ~ Materialize(out), x ~ Eraser
| _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermAdd(out_x, out_y), 1, 0);
-Concrete(int j) >< AddCheckLinear(out, x, int q, int r) => out ~ Linear(x, q, r + j);
+Concrete(float j) >< AddCheckLinear(out, x, float q, float r) => out ~ Linear(x, q, r + j);
-Linear(y, int s, int t) >< AddCheckConcrete(out, int k) => out ~ Linear(y, s, t + k);
+Linear(y, float s, float t) >< AddCheckConcrete(out, float k) => out ~ Linear(y, s, t + k);
-Concrete(int j) >< AddCheckConcrete(out, int k)
+Concrete(float j) >< AddCheckConcrete(out, float k)
| j == 0 => out ~ Concrete(k)
| _ => out ~ Concrete(k + j);
-Linear(x, int q, int r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r);
+Linear(x, float q, float r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r);
-Concrete(int k) >< Mul(out, b)
+Concrete(float k) >< Mul(out, b)
| k == 0 => b ~ Eraser, out ~ (*L)Concrete(0)
| k == 1 => out ~ b
| _ => b ~ MulCheckConcrete(out, k);
-Linear(y, int s, int t) >< MulCheckLinear(out, x, int q, int r)
+Linear(y, float s, float t) >< MulCheckLinear(out, x, float q, float r)
| (q == 0) && (r == 0) && (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
| (s == 0) && (t == 0) => Linear(x, q, r) ~ Materialize(out), y ~ Eraser
| (q == 0) && (r == 0) => (*L)Linear(y, s, t) ~ Materialize(out), x ~ Eraser
| _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermMul(out_x, out_y), 1, 0);
-Concrete(int j) >< MulCheckLinear(out, x, int q, int r) => out ~ Linear(x, q * j, r * j);
+Concrete(float j) >< MulCheckLinear(out, x, float q, float r) => out ~ Linear(x, q * j, r * j);
-Linear(y, int s, int t) >< MulCheckConcrete(out, int k) => out ~ Linear(y, s * k, t * k);
+Linear(y, float s, float t) >< MulCheckConcrete(out, float k) => out ~ Linear(y, s * k, t * k);
-Concrete(int j) >< MulCheckConcrete(out, int k)
+Concrete(float j) >< MulCheckConcrete(out, float k)
| j == 0 => out ~ Concrete(0)
| j == 1 => out ~ Concrete(k)
| _ => out ~ Concrete(k * j);
-Linear(x, int q, int r) >< ReLU(out) => (*L)Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0);
+Linear(x, float q, float r) >< ReLU(out) => (*L)Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0);
-Concrete(int k) >< ReLU(out)
+Concrete(float k) >< ReLU(out)
| k > 0 => out ~ (*L)Concrete(k)
| _ => out ~ Concrete(0);
-Linear(x, int q, int r) >< Materialize(out)
+Linear(x, float q, float r) >< Materialize(out)
| (q == 0) => out ~ Concrete(r), x ~ Eraser
| (q == 1) && (r == 0) => out ~ x
| (q == 1) && (r != 0) => out ~ TermAdd(x, Concrete(r))
@@ -72,35 +72,35 @@ Linear(x, int q, int r) >< Materialize(out)
// Network A
-Dup(v0, Dup(v1, Dup(v2, v3))) ~ Linear(Symbolic(0), 1, 0);
-Mul(v4, Concrete(-729)) ~ v0;
-Add(v5, v4) ~ Concrete(732);
-Mul(v6, Concrete(707)) ~ v1;
-Add(v7, v6) ~ Concrete(106);
-Mul(v8, Concrete(-577)) ~ v2;
-Add(v9, v8) ~ Concrete(-502);
-Mul(v10, Concrete(1070)) ~ v3;
-Add(v11, v10) ~ Concrete(-1068);
-Dup(v12, Dup(v13, Dup(v14, v15))) ~ Linear(Symbolic(1), 1, 0);
-Mul(v16, Concrete(-725)) ~ v12;
+Dup(v0, Dup(v1, Dup(v2, v3))) ~ Linear(Symbolic(X_0), 1.0, 0.0);
+Mul(v4, Concrete(1.249051570892334)) ~ v0;
+Add(v5, v4) ~ Concrete(-2.076689270325005e-05);
+Mul(v6, Concrete(0.8312496542930603)) ~ v1;
+Add(v7, v6) ~ Concrete(-0.8312351703643799);
+Mul(v8, Concrete(0.9251033663749695)) ~ v2;
+Add(v9, v8) ~ Concrete(-0.9250767230987549);
+Mul(v10, Concrete(0.3333963453769684)) ~ v3;
+Add(v11, v10) ~ Concrete(0.05585573986172676);
+Dup(v12, Dup(v13, Dup(v14, v15))) ~ Linear(Symbolic(X_1), 1.0, 0.0);
+Mul(v16, Concrete(0.8467237949371338)) ~ v12;
Add(v17, v16) ~ v5;
-Mul(v18, Concrete(708)) ~ v13;
+Mul(v18, Concrete(0.8312491774559021)) ~ v13;
Add(v19, v18) ~ v7;
-Mul(v20, Concrete(220)) ~ v14;
+Mul(v20, Concrete(0.9251176118850708)) ~ v14;
Add(v21, v20) ~ v9;
-Mul(v22, Concrete(1066)) ~ v15;
+Mul(v22, Concrete(1.084873080253601)) ~ v15;
Add(v23, v22) ~ v11;
ReLU(v24) ~ v17;
ReLU(v25) ~ v19;
ReLU(v26) ~ v21;
ReLU(v27) ~ v23;
-Mul(v28, Concrete(-642)) ~ v24;
-Add(v29, v28) ~ Concrete(390);
-Mul(v30, Concrete(753)) ~ v25;
+Mul(v28, Concrete(0.7005411982536316)) ~ v24;
+Add(v29, v28) ~ Concrete(-0.02095046266913414);
+Mul(v30, Concrete(-0.9663007259368896)) ~ v25;
Add(v31, v30) ~ v29;
-Mul(v32, Concrete(235)) ~ v26;
+Mul(v32, Concrete(-1.293721079826355)) ~ v26;
Add(v33, v32) ~ v31;
-Mul(v34, Concrete(-1440)) ~ v27;
+Mul(v34, Concrete(0.3750816583633423)) ~ v27;
Add(v35, v34) ~ v33;
Materialize(result0) ~ v35;
result0;
@@ -108,35 +108,35 @@ free ifce;
// Network B
-Dup(v0, Dup(v1, Dup(v2, v3))) ~ Linear(Symbolic(0), 1, 0);
-Mul(v4, Concrete(-181)) ~ v0;
-Add(v5, v4) ~ Concrete(-142);
-Mul(v6, Concrete(-1061)) ~ v1;
-Add(v7, v6) ~ Concrete(1050);
-Mul(v8, Concrete(1181)) ~ v2;
-Add(v9, v8) ~ Concrete(-568);
-Mul(v10, Concrete(-627)) ~ v3;
-Add(v11, v10) ~ Concrete(1236);
-Dup(v12, Dup(v13, Dup(v14, v15))) ~ Linear(Symbolic(1), 1, 0);
-Mul(v16, Concrete(-609)) ~ v12;
+Dup(v0, Dup(v1, Dup(v2, v3))) ~ Linear(Symbolic(X_0), 1.0, 0.0);
+Mul(v4, Concrete(1.1727254390716553)) ~ v0;
+Add(v5, v4) ~ Concrete(-0.005158121697604656);
+Mul(v6, Concrete(1.1684346199035645)) ~ v1;
+Add(v7, v6) ~ Concrete(-1.1664382219314575);
+Mul(v8, Concrete(-0.2502972185611725)) ~ v2;
+Add(v9, v8) ~ Concrete(-0.10056735575199127);
+Mul(v10, Concrete(-0.6796815395355225)) ~ v3;
+Add(v11, v10) ~ Concrete(-0.32640340924263);
+Dup(v12, Dup(v13, Dup(v14, v15))) ~ Linear(Symbolic(X_1), 1.0, 0.0);
+Mul(v16, Concrete(1.1758666038513184)) ~ v12;
Add(v17, v16) ~ v5;
-Mul(v18, Concrete(-1058)) ~ v13;
+Mul(v18, Concrete(1.1700055599212646)) ~ v13;
Add(v19, v18) ~ v7;
-Mul(v20, Concrete(1404)) ~ v14;
+Mul(v20, Concrete(0.02409248612821102)) ~ v14;
Add(v21, v20) ~ v9;
-Mul(v22, Concrete(-311)) ~ v15;
+Mul(v22, Concrete(-0.43328654766082764)) ~ v15;
Add(v23, v22) ~ v11;
ReLU(v24) ~ v17;
ReLU(v25) ~ v19;
ReLU(v26) ~ v21;
ReLU(v27) ~ v23;
-Mul(v28, Concrete(-313)) ~ v24;
-Add(v29, v28) ~ Concrete(1112);
-Mul(v30, Concrete(-1571)) ~ v25;
+Mul(v28, Concrete(0.8594199419021606)) ~ v24;
+Add(v29, v28) ~ Concrete(7.867255291671427e-09);
+Mul(v30, Concrete(-1.7184218168258667)) ~ v25;
Add(v31, v30) ~ v29;
-Mul(v32, Concrete(-615)) ~ v26;
+Mul(v32, Concrete(-0.207244873046875)) ~ v26;
Add(v33, v32) ~ v31;
-Mul(v34, Concrete(434)) ~ v27;
+Mul(v34, Concrete(-0.14912307262420654)) ~ v27;
Add(v35, v34) ~ v33;
Materialize(result0) ~ v35;
result0; \ No newline at end of file
diff --git a/xor.py b/xor.py
index 905ddfb..493eaef 100644
--- a/xor.py
+++ b/xor.py
@@ -37,7 +37,7 @@ def get_rules() -> str:
rules_lines.append(line)
return "".join(rules_lines)
-def export_to_inpla_wiring(model: nn.Module, input_shape: tuple, scale: int = 1000) -> str:
+def export_to_inpla_wiring(model: nn.Module, input_shape: tuple) -> str:
traced = fx.symbolic_trace(model)
name_gen = NameGen()
script: List[str] = []
@@ -46,7 +46,7 @@ def export_to_inpla_wiring(model: nn.Module, input_shape: tuple, scale: int = 10
for node in traced.graph.nodes:
if node.op == 'placeholder':
num_inputs = int(np.prod(input_shape))
- wire_map[node.name] = [f"Linear(Symbolic({i}), 1, 0)" for i in range(num_inputs)]
+ wire_map[node.name] = [f"Linear(Symbolic(X_{i}), 1.0, 0.0)" for i in range(num_inputs)]
elif node.op == 'call_module':
target_str = str(node.target)
@@ -61,8 +61,8 @@ def export_to_inpla_wiring(model: nn.Module, input_shape: tuple, scale: int = 10
wire_map[node.name] = input_wires
elif isinstance(module, nn.Linear):
- W = (module.weight.data.detach().cpu().numpy() * scale).astype(int)
- B = (module.bias.data.detach().cpu().numpy() * scale).astype(int)
+ W = (module.weight.data.detach().cpu().numpy()).astype(float)
+ B = (module.bias.data.detach().cpu().numpy()).astype(float)
out_dim, in_dim = W.shape
neuron_wires = [f"Concrete({B[j]})" for j in range(out_dim)]
@@ -70,7 +70,7 @@ def export_to_inpla_wiring(model: nn.Module, input_shape: tuple, scale: int = 10
for i in range(in_dim):
in_term = input_wires[i]
if out_dim == 1:
- weight = int(W[0, i])
+ weight = float(W[0, i])
if weight == 0:
script.append(f"Eraser ~ {in_term};")
elif weight == 1:
@@ -94,7 +94,7 @@ def export_to_inpla_wiring(model: nn.Module, input_shape: tuple, scale: int = 10
script.append(f"{nest_dups(branch_wires)} ~ {in_term};")
for j in range(out_dim):
- weight = int(W[j, i])
+ weight = float(W[j, i])
if weight == 0:
script.append(f"Eraser ~ {branch_wires[j]};")
elif weight == 1: