aboutsummaryrefslogtreecommitdiff
path: root/xor
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--xor.in142
-rw-r--r--xor.py144
2 files changed, 11 insertions, 275 deletions
diff --git a/xor.in b/xor.in
deleted file mode 100644
index 6a71e7b..0000000
--- a/xor.in
+++ /dev/null
@@ -1,142 +0,0 @@
-// Agents
-// Built-in
-// Eraser: delete other agents recursively
-// Dup: duplicates other agents recursively
-
-// Implemented
-// Linear(x, float q, float r): represent "q*x + r"
-// Concrete(float k): represent a concrete value k
-// Symbolic(id): represent the variable id
-// Add(out, b): represent the addition (has various steps AddCheckLinear/AddCheckConcrete)
-// Mul(out, b): represent the multiplication (has various steps MulCheckLinear/MulCheckConcrete)
-// ReLU(out): represent "if x > 0 ? x ; 0"
-// Materialize(out): transforms a Linear packet into a final representation of TermAdd/TermMul/TermReLU
-
-// TODO: add range information to enable ReLU elimination
-
-// Rules
-Linear(x, float q, float r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r);
-
-Concrete(float k) >< Add(out, b)
- | k == 0 => out ~ b
- | _ => b ~ AddCheckConcrete(out, k);
-
-Linear(y, float s, float t) >< AddCheckLinear(out, x, float q, float r)
- | (q == 0) && (r == 0) && (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
- | (s == 0) && (t == 0) => Linear(x, q, r) ~ Materialize(out), y ~ Eraser
- | (q == 0) && (r == 0) => (*L)Linear(y, s, t) ~ Materialize(out), x ~ Eraser
- | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermAdd(out_x, out_y), 1, 0);
-
-Concrete(float j) >< AddCheckLinear(out, x, float q, float r) => out ~ Linear(x, q, r + j);
-
-Linear(y, float s, float t) >< AddCheckConcrete(out, float k) => out ~ Linear(y, s, t + k);
-
-Concrete(float j) >< AddCheckConcrete(out, float k)
- | j == 0 => out ~ Concrete(k)
- | _ => out ~ Concrete(k + j);
-
-Linear(x, float q, float r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r);
-
-Concrete(float k) >< Mul(out, b)
- | k == 0 => b ~ Eraser, out ~ (*L)Concrete(0)
- | k == 1 => out ~ b
- | _ => b ~ MulCheckConcrete(out, k);
-
-Linear(y, float s, float t) >< MulCheckLinear(out, x, float q, float r)
- | (q == 0) && (r == 0) && (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
- | (s == 0) && (t == 0) => Linear(x, q, r) ~ Materialize(out), y ~ Eraser
- | (q == 0) && (r == 0) => (*L)Linear(y, s, t) ~ Materialize(out), x ~ Eraser
- | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermMul(out_x, out_y), 1, 0);
-
-Concrete(float j) >< MulCheckLinear(out, x, float q, float r) => out ~ Linear(x, q * j, r * j);
-
-Linear(y, float s, float t) >< MulCheckConcrete(out, float k) => out ~ Linear(y, s * k, t * k);
-
-Concrete(float j) >< MulCheckConcrete(out, float k)
- | j == 0 => out ~ Concrete(0)
- | j == 1 => out ~ Concrete(k)
- | _ => out ~ Concrete(k * j);
-
-Linear(x, float q, float r) >< ReLU(out) => (*L)Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0);
-
-Concrete(float k) >< ReLU(out)
- | k > 0 => out ~ (*L)Concrete(k)
- | _ => out ~ Concrete(0);
-
-Linear(x, float q, float r) >< Materialize(out)
- | (q == 0) => out ~ Concrete(r), x ~ Eraser
- | (q == 1) && (r == 0) => out ~ x
- | (q == 1) && (r != 0) => out ~ TermAdd(x, Concrete(r))
- | (q != 0) && (r == 0) => out ~ TermMul(Concrete(q), x)
- | _ => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r));
-
-
-// Network A
-Dup(v0, Dup(v1, Dup(v2, v3))) ~ Linear(Symbolic(X_0), 1.0, 0.0);
-Mul(v4, Concrete(1.249051570892334)) ~ v0;
-Add(v5, v4) ~ Concrete(-2.076689270325005e-05);
-Mul(v6, Concrete(0.8312496542930603)) ~ v1;
-Add(v7, v6) ~ Concrete(-0.8312351703643799);
-Mul(v8, Concrete(0.9251033663749695)) ~ v2;
-Add(v9, v8) ~ Concrete(-0.9250767230987549);
-Mul(v10, Concrete(0.3333963453769684)) ~ v3;
-Add(v11, v10) ~ Concrete(0.05585573986172676);
-Dup(v12, Dup(v13, Dup(v14, v15))) ~ Linear(Symbolic(X_1), 1.0, 0.0);
-Mul(v16, Concrete(0.8467237949371338)) ~ v12;
-Add(v17, v16) ~ v5;
-Mul(v18, Concrete(0.8312491774559021)) ~ v13;
-Add(v19, v18) ~ v7;
-Mul(v20, Concrete(0.9251176118850708)) ~ v14;
-Add(v21, v20) ~ v9;
-Mul(v22, Concrete(1.084873080253601)) ~ v15;
-Add(v23, v22) ~ v11;
-ReLU(v24) ~ v17;
-ReLU(v25) ~ v19;
-ReLU(v26) ~ v21;
-ReLU(v27) ~ v23;
-Mul(v28, Concrete(0.7005411982536316)) ~ v24;
-Add(v29, v28) ~ Concrete(-0.02095046266913414);
-Mul(v30, Concrete(-0.9663007259368896)) ~ v25;
-Add(v31, v30) ~ v29;
-Mul(v32, Concrete(-1.293721079826355)) ~ v26;
-Add(v33, v32) ~ v31;
-Mul(v34, Concrete(0.3750816583633423)) ~ v27;
-Add(v35, v34) ~ v33;
-Materialize(result0) ~ v35;
-result0;
-free ifce;
-
-
-// Network B
-Dup(v0, Dup(v1, Dup(v2, v3))) ~ Linear(Symbolic(X_0), 1.0, 0.0);
-Mul(v4, Concrete(1.1727254390716553)) ~ v0;
-Add(v5, v4) ~ Concrete(-0.005158121697604656);
-Mul(v6, Concrete(1.1684346199035645)) ~ v1;
-Add(v7, v6) ~ Concrete(-1.1664382219314575);
-Mul(v8, Concrete(-0.2502972185611725)) ~ v2;
-Add(v9, v8) ~ Concrete(-0.10056735575199127);
-Mul(v10, Concrete(-0.6796815395355225)) ~ v3;
-Add(v11, v10) ~ Concrete(-0.32640340924263);
-Dup(v12, Dup(v13, Dup(v14, v15))) ~ Linear(Symbolic(X_1), 1.0, 0.0);
-Mul(v16, Concrete(1.1758666038513184)) ~ v12;
-Add(v17, v16) ~ v5;
-Mul(v18, Concrete(1.1700055599212646)) ~ v13;
-Add(v19, v18) ~ v7;
-Mul(v20, Concrete(0.02409248612821102)) ~ v14;
-Add(v21, v20) ~ v9;
-Mul(v22, Concrete(-0.43328654766082764)) ~ v15;
-Add(v23, v22) ~ v11;
-ReLU(v24) ~ v17;
-ReLU(v25) ~ v19;
-ReLU(v26) ~ v21;
-ReLU(v27) ~ v23;
-Mul(v28, Concrete(0.8594199419021606)) ~ v24;
-Add(v29, v28) ~ Concrete(7.867255291671427e-09);
-Mul(v30, Concrete(-1.7184218168258667)) ~ v25;
-Add(v31, v30) ~ v29;
-Mul(v32, Concrete(-0.207244873046875)) ~ v26;
-Add(v33, v32) ~ v31;
-Mul(v34, Concrete(-0.14912307262420654)) ~ v27;
-Add(v35, v34) ~ v33;
-Materialize(result0) ~ v35;
-result0; \ No newline at end of file
diff --git a/xor.py b/xor.py
index 493eaef..9ab7be7 100644
--- a/xor.py
+++ b/xor.py
@@ -1,11 +1,8 @@
import torch
import torch.nn as nn
-import torch.fx as fx
-import numpy as np
-import os
-from typing import List, Dict
+import nneq
-class XOR_MLP(nn.Module):
+class xor_mlp(nn.Module):
def __init__(self, hidden_dim=4):
super().__init__()
self.layers = nn.Sequential(
@@ -16,124 +13,11 @@ class XOR_MLP(nn.Module):
def forward(self, x):
return self.layers(x)
-class NameGen:
- def __init__(self):
- self.counter = 0
- def next(self) -> str:
- name = f"v{self.counter}"
- self.counter += 1
- return name
-
-def get_rules() -> str:
- rules_path = os.path.join(os.path.dirname(__file__), "rules.in")
- if not os.path.exists(rules_path):
- return "// Rules not found in rules.in\n"
-
- rules_lines = []
- with open(rules_path, "r") as f:
- for line in f:
- if "// Net testing" in line:
- break
- rules_lines.append(line)
- return "".join(rules_lines)
-
-def export_to_inpla_wiring(model: nn.Module, input_shape: tuple) -> str:
- traced = fx.symbolic_trace(model)
- name_gen = NameGen()
- script: List[str] = []
- wire_map: Dict[str, List[str]] = {}
-
- for node in traced.graph.nodes:
- if node.op == 'placeholder':
- num_inputs = int(np.prod(input_shape))
- wire_map[node.name] = [f"Linear(Symbolic(X_{i}), 1.0, 0.0)" for i in range(num_inputs)]
-
- elif node.op == 'call_module':
- target_str = str(node.target)
- module = dict(model.named_modules())[target_str]
-
- input_node = node.args[0]
- if not isinstance(input_node, fx.Node):
- continue
- input_wires = wire_map[input_node.name]
-
- if isinstance(module, nn.Flatten):
- wire_map[node.name] = input_wires
-
- elif isinstance(module, nn.Linear):
- W = (module.weight.data.detach().cpu().numpy()).astype(float)
- B = (module.bias.data.detach().cpu().numpy()).astype(float)
- out_dim, in_dim = W.shape
-
- neuron_wires = [f"Concrete({B[j]})" for j in range(out_dim)]
-
- for i in range(in_dim):
- in_term = input_wires[i]
- if out_dim == 1:
- weight = float(W[0, i])
- if weight == 0:
- script.append(f"Eraser ~ {in_term};")
- elif weight == 1:
- new_s = name_gen.next()
- script.append(f"Add({new_s}, {in_term}) ~ {neuron_wires[0]};")
- neuron_wires[0] = new_s
- else:
- mul_out = name_gen.next()
- new_s = name_gen.next()
- script.append(f"Mul({mul_out}, Concrete({weight})) ~ {in_term};")
- script.append(f"Add({new_s}, {mul_out}) ~ {neuron_wires[0]};")
- neuron_wires[0] = new_s
- else:
- branch_wires = [name_gen.next() for _ in range(out_dim)]
-
- def nest_dups(names: List[str]) -> str:
- if len(names) == 1: return names[0]
- if len(names) == 2: return f"Dup({names[0]}, {names[1]})"
- return f"Dup({names[0]}, {nest_dups(names[1:])})"
-
- script.append(f"{nest_dups(branch_wires)} ~ {in_term};")
-
- for j in range(out_dim):
- weight = float(W[j, i])
- if weight == 0:
- script.append(f"Eraser ~ {branch_wires[j]};")
- elif weight == 1:
- new_s = name_gen.next()
- script.append(f"Add({new_s}, {branch_wires[j]}) ~ {neuron_wires[j]};")
- neuron_wires[j] = new_s
- else:
- mul_out = name_gen.next()
- new_s = name_gen.next()
- script.append(f"Mul({mul_out}, Concrete({weight})) ~ {branch_wires[j]};")
- script.append(f"Add({new_s}, {mul_out}) ~ {neuron_wires[j]};")
- neuron_wires[j] = new_s
-
- wire_map[node.name] = neuron_wires
-
- elif isinstance(module, nn.ReLU):
- output_wires = []
- for i, w in enumerate(input_wires):
- r_out = name_gen.next()
- script.append(f"ReLU({r_out}) ~ {w};")
- output_wires.append(r_out)
- wire_map[node.name] = output_wires
-
- elif node.op == 'output':
- output_node = node.args[0]
- if isinstance(output_node, fx.Node):
- final_wires = wire_map[output_node.name]
- for i, w in enumerate(final_wires):
- res_name = f"result{i}"
- script.append(f"Materialize({res_name}) ~ {w};")
- script.append(f"{res_name};")
-
- return "\n".join(script)
-
def train_model(name: str):
X = torch.tensor([[0,0], [0,1], [1,0], [1,1]], dtype=torch.float32)
Y = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)
- net = XOR_MLP()
+ net = xor_mlp()
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
@@ -149,22 +33,16 @@ def train_model(name: str):
return net
if __name__ == "__main__":
- # Train two different models
net_a = train_model("Network A")
net_b = train_model("Network B")
- print("\nExporting both to xor.in...")
+ z3_net_a = nneq.net(net_a, (2,))
+ z3_net_b = nneq.net(net_b, (2,))
- rules = get_rules()
- wiring_a = export_to_inpla_wiring(net_a, (2,))
- wiring_b = export_to_inpla_wiring(net_b, (2,))
+ print("")
+ nneq.strict_equivalence(z3_net_a, z3_net_b)
+ print("")
+ nneq.epsilon_equivalence(z3_net_a, z3_net_b, 0.1)
+ print("")
+ nneq.argmax_equivalence(z3_net_a, z3_net_b)
- with open("xor.in", "w") as f:
- f.write(rules)
- f.write("\n\n// Network A\n")
- f.write(wiring_a)
- f.write("\nfree ifce;\n")
- f.write("\n\n// Network B\n")
- f.write(wiring_b)
-
- print("Done. Now run: inpla -f xor.in | python3 prover.py")