aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorericmarin <maarin.eric@gmail.com>2026-03-16 19:36:31 +0100
committerericmarin <maarin.eric@gmail.com>2026-03-17 17:27:47 +0100
commit5ff90e94c9bb411a0262a8130a6f0ce4125ca11b (patch)
tree80103130dae1d4bfa4cee6537a72c30777ed6a2d
parenta0b1e7f6a8c11ed98ae20ac484e2fe9f75b9b85f (diff)
downloadvein-5ff90e94c9bb411a0262a8130a6f0ce4125ca11b.tar.gz
vein-5ff90e94c9bb411a0262a8130a6f0ce4125ca11b.zip
changed torch.fx to ONNX
-rw-r--r--nneq/nneq.py279
-rw-r--r--notes.norg9
-rw-r--r--proof.norg367
-rw-r--r--xor.py23
4 files changed, 514 insertions, 164 deletions
diff --git a/nneq/nneq.py b/nneq/nneq.py
index 22cf171..d9d7d30 100644
--- a/nneq/nneq.py
+++ b/nneq/nneq.py
@@ -1,8 +1,9 @@
import z3
import re
-import torch.fx as fx, torch.nn as nn
import numpy as np
import subprocess
+import onnx
+from onnx import numpy_helper
from typing import List, Dict
__all__ = ["net", "strict_equivalence", "epsilon_equivalence", "argmax_equivalence"]
@@ -11,148 +12,130 @@ type inpla_str = str
type z3_str = str
rules: inpla_str = """
-Linear(x, float q, float r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r);
-Concrete(float k) >< Add(out, b)
- | k == 0 => out ~ b
- | _ => b ~ AddCheckConcrete(out, k);
-Linear(y, float s, float t) >< AddCheckLinear(out, x, float q, float r)
- | (q == 0) && (r == 0) && (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
- | (s == 0) && (t == 0) => out ~ Linear(x, q, r), y ~ Eraser
- | (q == 0) && (r == 0) => out ~ (*L)Linear(y, s, t), x ~ Eraser
- | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermAdd(out_x, out_y), 1, 0);
-Concrete(float j) >< AddCheckLinear(out, x, float q, float r) => out ~ Linear(x, q, r + j);
-Linear(y, float s, float t) >< AddCheckConcrete(out, float k) => out ~ Linear(y, s, t + k);
-Concrete(float j) >< AddCheckConcrete(out, float k)
- | j == 0 => out ~ Concrete(k)
- | _ => out ~ Concrete(k + j);
-Linear(x, float q, float r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r);
-Concrete(float k) >< Mul(out, b)
- | k == 0 => b ~ Eraser, out ~ (*L)Concrete(0)
- | k == 1 => out ~ b
- | _ => b ~ MulCheckConcrete(out, k);
-Linear(y, float s, float t) >< MulCheckLinear(out, x, float q, float r)
- | ((q == 0) && (r == 0)) || ((s == 0) && (t == 0)) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
- | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermMul(out_x, out_y), 1, 0);
-Concrete(float j) >< MulCheckLinear(out, x, float q, float r) => out ~ Linear(x, q * j, r * j);
-Linear(y, float s, float t) >< MulCheckConcrete(out, float k) => out ~ Linear(y, s * k, t * k);
-Concrete(float j) >< MulCheckConcrete(out, float k)
- | j == 0 => out ~ Concrete(0)
- | j == 1 => out ~ Concrete(k)
- | _ => out ~ Concrete(k * j);
-Linear(x, float q, float r) >< ReLU(out) => (*L)Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0);
-Concrete(float k) >< ReLU(out)
- | k > 0 => out ~ (*L)Concrete(k)
- | _ => out ~ Concrete(0);
-Linear(x, float q, float r) >< Materialize(out)
- | (q == 0) => out ~ Concrete(r), x ~ Eraser
- | (q == 1) && (r == 0) => out ~ x
- | (q == 1) && (r != 0) => out ~ TermAdd(x, Concrete(r))
- | (q != 0) && (r == 0) => out ~ TermMul(Concrete(q), x)
- | _ => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r));
-Concrete(float k) >< Materialize(out) => out ~ (*L)Concrete(k);
+ Linear(x, float q, float r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r);
+ Concrete(float k) >< Add(out, b)
+ | k == 0 => out ~ b
+ | _ => b ~ AddCheckConcrete(out, k);
+ Linear(y, float s, float t) >< AddCheckLinear(out, x, float q, float r)
+ | (q == 0) && (r == 0) && (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
+ | (s == 0) && (t == 0) => out ~ Linear(x, q, r), y ~ Eraser
+ | (q == 0) && (r == 0) => out ~ (*L)Linear(y, s, t), x ~ Eraser
+ | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermAdd(out_x, out_y), 1, 0);
+ Concrete(float j) >< AddCheckLinear(out, x, float q, float r) => out ~ Linear(x, q, r + j);
+ Linear(y, float s, float t) >< AddCheckConcrete(out, float k) => out ~ Linear(y, s, t + k);
+ Concrete(float j) >< AddCheckConcrete(out, float k)
+ | j == 0 => out ~ Concrete(k)
+ | _ => out ~ Concrete(k + j);
+ Linear(x, float q, float r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r);
+ Concrete(float k) >< Mul(out, b)
+ | k == 0 => b ~ Eraser, out ~ (*L)Concrete(0)
+ | k == 1 => out ~ b
+ | _ => b ~ MulCheckConcrete(out, k);
+ Linear(y, float s, float t) >< MulCheckLinear(out, x, float q, float r)
+ | ((q == 0) && (r == 0)) || ((s == 0) && (t == 0)) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
+ | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermMul(out_x, out_y), 1, 0);
+ Concrete(float j) >< MulCheckLinear(out, x, float q, float r) => out ~ Linear(x, q * j, r * j);
+ Linear(y, float s, float t) >< MulCheckConcrete(out, float k) => out ~ Linear(y, s * k, t * k);
+ Concrete(float j) >< MulCheckConcrete(out, float k)
+ | j == 0 => out ~ Concrete(0)
+ | j == 1 => out ~ Concrete(k)
+ | _ => out ~ Concrete(k * j);
+ Linear(x, float q, float r) >< ReLU(out) => (*L)Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0);
+ Concrete(float k) >< ReLU(out)
+ | k > 0 => out ~ (*L)Concrete(k)
+ | _ => out ~ Concrete(0);
+ Linear(x, float q, float r) >< Materialize(out)
+ | (q == 0) => out ~ Concrete(r), x ~ Eraser
+ | (q == 1) && (r == 0) => out ~ x
+ | (q == 1) && (r != 0) => out ~ TermAdd(x, Concrete(r))
+ | (q != 0) && (r == 0) => out ~ TermMul(Concrete(q), x)
+ | _ => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r));
+ Concrete(float k) >< Materialize(out) => out ~ (*L)Concrete(k);
"""
-class NameGen:
- def __init__(self):
- self.counter = 0
- def next(self) -> str:
- name = f"v{self.counter}"
- self.counter += 1
- return name
-
-def inpla_export(model: nn.Module, input_shape: tuple) -> inpla_str:
- traced = fx.symbolic_trace(model)
- name_gen = NameGen()
- script: List[str] = []
- wire_map: Dict[str, List[str]] = {}
-
- for node in traced.graph.nodes:
- if node.op == 'placeholder':
- num_inputs = int(np.prod(input_shape))
- wire_map[node.name] = [f"Linear(Symbolic(X_{i}), 1.0, 0.0)" for i in range(num_inputs)]
-
- elif node.op == 'call_module':
- target_str = str(node.target)
- module = dict(model.named_modules())[target_str]
-
- input_node = node.args[0]
- if not isinstance(input_node, fx.Node):
- continue
- input_wires = wire_map[input_node.name]
-
- if isinstance(module, nn.Flatten):
- wire_map[node.name] = input_wires
-
- elif isinstance(module, nn.Linear):
- W = (module.weight.data.detach().cpu().numpy()).astype(float)
- B = (module.bias.data.detach().cpu().numpy()).astype(float)
- out_dim, in_dim = W.shape
-
- neuron_wires = [f"Concrete({B[j]})" for j in range(out_dim)]
-
- for i in range(in_dim):
- in_term = input_wires[i]
- if out_dim == 1:
- weight = float(W[0, i])
- if weight == 0:
- script.append(f"Eraser ~ {in_term};")
- elif weight == 1:
- new_s = name_gen.next()
- script.append(f"Add({new_s}, {in_term}) ~ {neuron_wires[0]};")
- neuron_wires[0] = new_s
- else:
- mul_out = name_gen.next()
- new_s = name_gen.next()
- script.append(f"Mul({mul_out}, Concrete({weight})) ~ {in_term};")
- script.append(f"Add({new_s}, {mul_out}) ~ {neuron_wires[0]};")
- neuron_wires[0] = new_s
- else:
- branch_wires = [name_gen.next() for _ in range(out_dim)]
-
- def nest_dups(names: List[str]) -> str:
- if len(names) == 1: return names[0]
- if len(names) == 2: return f"Dup({names[0]}, {names[1]})"
- return f"Dup({names[0]}, {nest_dups(names[1:])})"
-
- script.append(f"{nest_dups(branch_wires)} ~ {in_term};")
-
- for j in range(out_dim):
- weight = float(W[j, i])
- if weight == 0:
- script.append(f"Eraser ~ {branch_wires[j]};")
- elif weight == 1:
- new_s = name_gen.next()
- script.append(f"Add({new_s}, {branch_wires[j]}) ~ {neuron_wires[j]};")
- neuron_wires[j] = new_s
- else:
- mul_out = name_gen.next()
- new_s = name_gen.next()
- script.append(f"Mul({mul_out}, Concrete({weight})) ~ {branch_wires[j]};")
- script.append(f"Add({new_s}, {mul_out}) ~ {neuron_wires[j]};")
- neuron_wires[j] = new_s
-
- wire_map[node.name] = neuron_wires
-
- elif isinstance(module, nn.ReLU):
- output_wires = []
- for i, w in enumerate(input_wires):
- r_out = name_gen.next()
- script.append(f"ReLU({r_out}) ~ {w};")
- output_wires.append(r_out)
- wire_map[node.name] = output_wires
-
- elif node.op == 'output':
- output_node = node.args[0]
- if isinstance(output_node, fx.Node):
- final_wires = wire_map[output_node.name]
- for i, w in enumerate(final_wires):
- res_name = f"result{i}"
- script.append(f"Materialize({res_name}) ~ {w};")
- script.append(f"{res_name};")
-
- return "\n".join(script)
+def inpla_export(model: onnx.ModelProto) -> inpla_str:
+ class NameGen:
+ def __init__(self):
+ self.counter = 0
+ def next(self) -> str:
+ name = f"v{self.counter}"
+ self.counter += 1
+ return name
+
+ def get_initializers(graph) -> Dict[str, np.ndarray]:
+ initializers = {}
+ for init in graph.initializer:
+ initializers[init.name] = numpy_helper.to_array(init)
+ return initializers
+
+ def get_attrs(node: onnx.NodeProto) -> Dict:
+ return {attr.name: onnx.helper.get_attribute_value(attr) for attr in node.attribute}
+
+ def get_dim(name):
+ for i in list(graph.input) + list(graph.output) + list(graph.value_info):
+ if i.name == name: return i.type.tensor_type.shape.dim[-1].dim_value
+ return None
+
+ def nest_dups(terms: List[str]) -> str:
+ if not terms: return "Eraser"
+ if len(terms) == 1: return terms[0]
+ return f"Dup({nest_dups(terms[:len(terms)//2])}, {nest_dups(terms[len(terms)//2:])})"
+
+ def op_gemm(node):
+ attrs = get_attrs(node)
+ W = initializers[node.input[1]]
+ if not attrs.get("transB", 0): W = W.T
+ out_dim, in_dim = W.shape
+ B = initializers[node.input[2]] if len(node.input) > 2 else np.zeros(out_dim)
+ alpha, beta = attrs.get("alpha", 1.0), attrs.get("beta", 1.0)
+
+ if node.input[0] not in interactions: interactions[node.input[0]] = [[] for _ in range(in_dim)]
+
+ out_terms = interactions.get(node.output[0]) or [[] for _ in range(out_dim)]
+
+ for j in range(out_dim):
+ chain = nest_dups(out_terms[j])
+ for i in range(in_dim):
+ weight = float(alpha * W[j, i])
+ if weight == 0: interactions[node.input[0]][i].append("Eraser")
+ else:
+ v = name_gen.next()
+ chain, term = f"Add({chain}, {v})", f"Mul({v}, Concrete({weight}))"
+ interactions[node.input[0]][i].append(term)
+ yield f"{chain} ~ Concrete({float(beta * B[j])});"
+
+ def op_relu(node):
+ out_name, in_name = node.output[0], node.input[0]
+ if out_name in interactions:
+ dim = len(interactions[out_name])
+ if in_name not in interactions: interactions[in_name] = [[] for _ in range(dim)]
+ for i in range(dim):
+ interactions[in_name][i].append(f"ReLU({nest_dups(interactions[out_name][i])})")
+ yield from []
+
+ graph, initializers, name_gen = model.graph, get_initializers(model.graph), NameGen()
+ interactions: Dict[str, List[List[str]]] = {}
+ ops = {"Gemm": op_gemm, "Relu": op_relu}
+
+ if graph.output:
+ out = graph.output[0].name
+ dim = get_dim(out)
+ if dim: interactions[out] = [[f"Materialize(result{i})"] for i in range(dim)]
+
+ node_script = []
+ for node in reversed(graph.node):
+ if node.op_type in ops: node_script.extend(ops[node.op_type](node))
+
+ input_script = []
+ if graph.input and graph.input[0].name in interactions:
+ for i, terms in enumerate(interactions[graph.input[0].name]):
+ input_script.append(f"{nest_dups(terms)} ~ Linear(Symbolic(X_{i}), 1.0, 0.0);")
+
+ result_lines = [f"result{i};" for i in range(len(interactions.get(graph.output[0].name, [])))]
+ return "\n".join(input_script + list(reversed(node_script)) + result_lines)
def inpla_run(model: inpla_str) -> z3_str:
+ print(model)
return subprocess.run(["./inpla"], input=f"{rules}\n{model}", capture_output=True, text=True).stdout
syms = {}
@@ -168,11 +151,11 @@ def TermMul(a, b): return a * b
def TermReLU(x): return z3.If(x > 0, x, 0)
context = {
- 'Concrete': Concrete,
- 'Symbolic': Symbolic,
- 'TermAdd': TermAdd,
- 'TermMul': TermMul,
- 'TermReLU': TermReLU
+ 'Concrete': Concrete,
+ 'Symbolic': Symbolic,
+ 'TermAdd': TermAdd,
+ 'TermMul': TermMul,
+ 'TermReLU': TermReLU
}
wrap = re.compile(r"Symbolic\((.*?)\)")
@@ -181,8 +164,8 @@ def z3_evaluate(model: z3_str):
model = wrap.sub(r'Symbolic("\1")', model);
return eval(model, context)
-def net(model: nn.Module, input_shape: tuple):
- return z3_evaluate(inpla_run(inpla_export(model, input_shape)))
+def net(model: onnx.ModelProto):
+ return z3_evaluate(inpla_run(inpla_export(model)))
def strict_equivalence(net_a, net_b):
@@ -210,7 +193,7 @@ def epsilon_equivalence(net_a, net_b, epsilon):
for sym in syms.values():
solver.add(z3.Or(sym == 0, sym == 1))
-
+
solver.add(z3.Abs(net_a - net_b) > epsilon)
result = solver.check()
@@ -230,7 +213,7 @@ def argmax_equivalence(net_a, net_b):
for sym in syms.values():
solver.add(z3.Or(sym == 0, sym == 1))
-
+
solver.add((net_a > 0.5) != (net_b > 0.5))
result = solver.check()
diff --git a/notes.norg b/notes.norg
index 66dcbd9..4427354 100644
--- a/notes.norg
+++ b/notes.norg
@@ -1,19 +1,16 @@
@document.meta
title: Neural Network Equivalence
-description: WIP tool to prove NNEQ using Interaction Nets as pre-processor
+description: WIP tool to prove NNEQ using Interaction Nets as pre-processor for my Bachelor's Thesis
authors: ericmarin
categories: research
created: 2026-03-14T09:21:24
-updated: 2026-03-14T18:34:04
+updated: 2026-03-17T11:18:11
version: 1.1.1
@end
* TODO
- (?) Scalability %Maybe done? I have increased the limits of Inpla, but I have yet to test%
- - ( ) Soundness of translated NN
- ~~ Define the semantic of the Agents (give a mathematical definition)
- ~~ Prove that a Layer L and the Inpla translation represent the same function
- ~~ Prove that each Interaction Rules preserve the mathematical semantic of the output
+ - (x) Soundness of translated NN
- ( ) Compatibility with other types of NN
- ( ) Comparison with other tool ({https://github.com/NeuralNetworkVerification/Marabou}[Marabou], {https://github.com/guykatzz/ReluplexCav2017}[Reluplex])
- ( ) Add Range agent to enable ReLU optimization
diff --git a/proof.norg b/proof.norg
new file mode 100644
index 0000000..4ec18ce
--- /dev/null
+++ b/proof.norg
@@ -0,0 +1,367 @@
+@document.meta
+title: proof
+description:
+authors: ericmarin
+categories:
+created: 2026-03-16T11:34:52
+updated: 2026-03-16T18:31:41
+version: 1.1.1
+@end
+
+* Proof for translation from Pytorch representation to Interaction Net graph
+
+
+* Proof for the Interaction Rules
+** Mathematical Definitions
+ - Linear(x, q, r) = q*x + r %with q,r Real%
+ - Concrete(k) = k %with k Real%
+ - Add(a, b) = a + b
+ - AddCheckLinear(x, q, r, b) = q*x + (r + b) %with q,r Real%
+ - AddCheckConcrete(k, b) = k + b %with k Real%
+ - Mul(a, b) = a * b
+ - MulCheckLinear(x, q, r, b) = q*b*x + r*b %with q,r Real%
+ - MulCheckConcrete(k, b) = k*b %with k Real%
+ - ReLU(x) = IF (x > 0) THEN x ELSE 0
+ - Materialize(x) = x
+
+** Rules
+*** Formatting
+ Agent1 >< Agent2 => Wiring
+
+ LEFT SIDE MATHEMATICAL INTERPRETATION
+
+ RIGHT SIDE MATHEMATICAL INTERPRETATION
+
+ SHOWING EQUIVALENCE
+
+*** Materialize
+ The Materialize agent transforms a Linear agent into a tree of explicit mathematical operations
+ that are used as final representation for the solver.
+ In the Python module the terms are defined as:
+ @code python
+ def TermAdd(a, b):
+ return a + b
+ def TermMul(a, b):
+ return a * b
+ def TermReLU(x):
+ return z3.If(x > 0, x, 0)
+ @end
+**** Linear(x, q, r) >< Materialize(out) => (1), (2), (3), (4), (5)
+
+ Linear(x, q, r) = term
+ Materialize(term) = out
+ out = q*x + r
+
+ $$ Case 1: q = 0 => out ~ Concrete(r), x ~ Eraser
+ Concrete(r) = out
+ out = r
+
+ 0*x + r = r => r = r
+ $$
+
+ $$ Case 2: q = 1, r = 0 => out ~ x
+ x = out
+ out = x
+
+ 1*x + 0 = x => x = x
+ $$
+
+ $$ Case 3: q = 1 => out ~ TermAdd(x, Concrete(r))
+ TermAdd(x, Concrete(r)) = out
+ out = x + r
+
+ 1*x + r = x + r => x + r = x + r
+ $$
+
+ $$ Case 4: r = 0 => out ~ TermMul(Concrete(q), x)
+ TermMul(Concrete(q), x) = out
+ out = q*x
+
+ q*x + 0 = q*x => q*x = q*x
+ $$
+
+ $$ Case 5: otherwise => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r))
+ TermAdd(TermMul(Concrete(q), x), r) = out
+ out = q*x + r
+
+ q*x + r = q*x + r
+ $$
+
+**** Concrete(k) >< Materialize(out) => out ~ Concrete(k)
+
+ Concrete(k) = term
+ Materialize(term) = out
+ out = k
+
+ Concrete(k) = out
+ out = k
+
+ k = k
+
+*** Add
+**** Linear(x, q, r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r)
+
+ Linear(x, q, r) = a
+ Add(a, b) = out
+ out = q*x + r + b
+
+ AddCheckLinear(x, q, r, b) = out
+ out = q*x + (r + b)
+
+ q*x + r + b = q*x + (r + b) => q*x + (r + b) = q*x + (r + b)
+
+**** Concrete(k) >< Add(out, b) => (1), (2)
+
+ Concrete(k) = a
+ Add(a, b) = out
+ out = k + b
+
+ $$ Case 1: k = 0 => out ~ b
+ b = out
+ out = b
+
+ 0 + b = b => b = b
+ $$
+
+ $$ Case 2: otherwise => b ~ AddCheckConcrete(out, k)
+ AddCheckConcrete(k, b) = out
+ out = k + b
+
+ k + b = k + b
+ $$
+
+**** Linear(y, s, t) >< AddCheckLinear(out, x, q, r) => (1), (2), (3), (4)
+
+ Linear(y, s, t) = b
+ AddCheckLinear(x, q, r, b) = out
+ out = q*x + (r + s*y + t)
+
+ $$ Case 1: q,r,s,t = 0 => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
+ Concrete(0) = out
+ out = 0
+
+ 0*x + (0 + 0*y + 0) = 0 => 0 = 0
+ $$
+
+ $$ Case 2: s,t = 0 => out ~ Linear(x, q, r), y ~ Eraser
+ Linear(x, q, r) = out
+ out = q*x + r
+
+ q*x + (r + 0*y + 0) = q*x + r => q*x + r = q*x + r
+ $$
+
+ $$ Case 3: q, r = 0 => out ~ Linear(y, s, t), x ~ Eraser
+ Linear(y, s, t) = out
+ out = s*y + t
+
+ 0*x + (0 + s*y + t) = s*y + t => s*y + t = s*y + t
+ $$
+
+ $$ Case 4: otherwise => Linear(x, q, r) ~ Materialize(out_x), Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermAdd(out_x, out_y), 1, 0)
+ Materialize(Linear(x, q, r)) = out_x
+ Materialize(Linear(y, s, t)) = out_y
+ Linear(TermAdd(out_x, out_y), 1, 0) = out
+ out_x = q*x + r
+ out_y = s*y + t
+ out = 1*TermAdd(q*x + r, s*y + t) + 0
+ Because TermAdd(a, b) is defined as "a+b":
+ out = 1*(q*x + r + s*y + t) + 0
+
+ q*x + (r + s*y + t) = 1*(q*x + r + s*y + t) + 0 => q*x + r + s*y + t = q*x + r + s*y + t
+ $$
+
+**** Concrete(j) >< AddCheckLinear(out, x, q, r) => out ~ Linear(x, q, r + j)
+
+ Concrete(j) = b
+ AddCheckLinear(x, q, r, b) = out
+ out = q*x + (r + j)
+
+ Linear(x, q, r + j) = out
+ out = q*x + (r + j)
+
+ q*x + (r + j) = q*x + (r + j)
+
+**** Linear(y, s, t) >< AddCheckConcrete(out, k) => out ~ Linear(y, s, t + k)
+
+ Linear(y, s, t) = b
+ AddCheckConcrete(k, b) = out
+ out = k + s*y + t
+
+ Linear(y, s, t + k) = out
+ out = s*y + (t + k)
+
+ k + s*y + t = s*y + (t + k) => s*y + (t + k) = s*y + (t + k)
+
+**** Concrete(j) >< AddCheckConcrete(out, k) => (1), (2)
+
+ Concrete(j) = b
+ AddCheckConcrete(k, b) = out
+ out = k + j
+
+ $$ Case 1: j = 0 => out ~ Concrete(k)
+ Concrete(k) = out
+ out = k
+
+ k + 0 = k => k = k
+ $$
+
+ $$ Case 2: otherwise => out ~ Concrete(k + j)
+ Concrete(k + j) = out
+ out = k + j
+
+ k + j = k + j
+ $$
+
+*** Mul
+**** Linear(x, q, r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r)
+
+ Linear(x, q, r) = a
+ Mul(a, b) = out
+ out = (q*x + r) * b
+
+ MulCheckLinear(x, q, r, b) = out
+ out = q*b*x + r*b
+
+ (q*x + r) * b = q*b*x + r*b => q*b*x + r*b = q*b*x + r*b
+
+**** Concrete(k) >< Mul(out, b) => (1), (2), (3)
+
+ Concrete(k) = a
+ Mul(a, b) = out
+ out = k * b
+
+ $$ Case 1: k = 0 => out ~ Concrete(0), b ~ Eraser
+ Concrete(0) = out
+ out = 0
+
+ 0 * b = 0 => 0 = 0
+ $$
+
+ $$ Case 2: k = 1 => out ~ b
+ b = out
+ out = b
+
+ 1 * b = b => b = b
+ $$
+
+ $$ Case 3: otherwise => b ~ MulCheckConcrete(out, k)
+ MulCheckConcrete(k, b) = out
+ out = k * b
+
+ k * b = k * b
+ $$
+
+**** Linear(y, s, t) >< MulCheckLinear(out, x, q, r) => (1), (2)
+
+ Linear(y, s, t) = b
+ MulCheckLinear(x, q, r, b) = out
+ out = q\*(s*y + t)\*x + r*(s*y + t)
+
+ $$ Case 1: (q,r = 0) or (s,t = 0) => x ~ Eraser, y ~ Eraser, out ~ Concrete(0)
+ Concrete(0) = out
+ out = 0
+
+ 0\*(s*y + t)\*x + 0*(s*y + t) = 0 => 0 = 0
+ or
+ q\*(0*y + 0)\*x + r*(0*y + 0) = 0 => 0 = 0
+ $$
+
+ $$ Case 2: otherwise => Linear(x, q, r) ~ Materialize(out_x), Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermMul(out_x, out_y), 1, 0)
+ Materialize(Linear(x, q, r)) = out_x
+ Materialize(Linear(y, s, t)) = out_y
+ Linear(TermMul(out_x, out_y), 1, 0) = out
+ out_x = q*x + r
+ out_y = s*y + t
+ out = 1*TermMul(q*x + r, s*y + t) + 0
+ Because TermMul(a, b) is defined as "a*b":
+ out = 1*(q*x + r)*(s*y + t) + 0
+
+ q\*(s*y + t)\*x + r*(s*y + t) = 1*(q*x + r)\*(s*y + t) =>
+ q\*(s*y + t)\*x + r*(s*y + t) = (q*x + r)\*(s*y + t) =>
+ q\*(s*y + t)\*x + r*(s*y + t) = q\*(s*y + t)\*x + r*(s*y + t)
+ $$
+
+**** Concrete(j) >< MulCheckLinear(out, x, q, r) => out ~ Linear(x, q * j, r * j)
+
+ Concrete(j) = b
+ MulCheckLinear(x, q, r, b) = out
+ out = q*j*x + r*j
+
+ Linear(x, q * j, r * j) = out
+ out = q*j*x + r*j
+
+ q*j*x + r*j = q*j*x + r*j
+
+**** Linear(y, s, t) >< MulCheckConcrete(out, k) => out ~ Linear(y, s * k, t * k)
+
+ Linear(y, s, t) = b
+ MulCheckConcrete(k, b) = out
+ out = k * (s*y + t)
+
+ Linear(y, s * k, t * k) = out
+ out = s*k*y + t*k
+
+ k * (s*y + t) = s*k*y + t*k => s*k*y + t*k = s*k*y + t*k
+
+**** Concrete(j) >< MulCheckConcrete(out, k) => (1), (2), (3)
+
+ Concrete(j) = b
+ MulCheckConcrete(k, b) = out
+ out = k * j
+
+ $$ Case 1: j = 0 => out ~ Concrete(0)
+ Concrete(0) = out
+ out = 0
+
+ k * 0 = 0 => 0 = 0
+ $$
+
+ $$ Case 2: j = 1 => out ~ Concrete(k)
+ Concrete(k) = out
+ out = k
+
+ k * 1 = k => k = k
+ $$
+
+ $$ Case 3: otherwise => out ~ Concrete(k * j)
+ Concrete(k * j) = out
+ out = k * j
+
+ k * j = k * j
+ $$
+
+*** ReLU
+**** Linear(x, q, r) >< ReLU(out) => Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0)
+
+ Linear(x, q, r) = x
+ ReLU(x) = out
+ out = IF (q*x + r) > 0 THEN (q*x + r) ELSE 0
+
+ Materialize(Linear(x, q, r)) = out_x
+ Linear(TermReLU(out_x), 1, 0) = out
+ out_x = q*x + r
+ out = 1*TermReLU(q*x + r) + 0
+ Because TermReLU(x) is defined as "z3.If(x > 0, x, 0)":
+ out = 1*(IF (q*x + r) > 0 THEN (q*x + r) ELSE 0) + 0
+
+ IF (q*x + r) > 0 THEN (q*x + r) ELSE 0 = 1*(IF (q*x + r) > 0 THEN (q*x + r) ELSE 0) + 0 =>
+ IF (q*x + r) > 0 THEN (q*x + r) ELSE 0 = IF (q*x + r) > 0 THEN (q*x + r) ELSE 0
+
+**** Concrete(k) >< ReLU(out) => (1), (2)
+
+ Concrete(k) = x
+ ReLU(x) = out
+ out = IF k > 0 THEN k ELSE 0
+
+ $$ Case 1: k > 0 => out ~ Concrete(k)
+ Concrete(k) = out
+ out = k
+
+ IF true THEN k ELSE 0 = k => k = k
+ $$
+
+ $$ Case 2: k <= 0 => out ~ Concrete(0)
+ Concrete(0) = out
+ out = 0
+
+ IF false THEN k ELSE 0 = 0 => 0 = 0
+ $$
diff --git a/xor.py b/xor.py
index 9ab7be7..0f8390d 100644
--- a/xor.py
+++ b/xor.py
@@ -1,9 +1,10 @@
import torch
import torch.nn as nn
+import torch.onnx
import nneq
class xor_mlp(nn.Module):
- def __init__(self, hidden_dim=4):
+ def __init__(self, hidden_dim=8):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(2, hidden_dim),
@@ -13,13 +14,13 @@ class xor_mlp(nn.Module):
def forward(self, x):
return self.layers(x)
-def train_model(name: str):
+def train_model(name: str, dim):
X = torch.tensor([[0,0], [0,1], [1,0], [1,1]], dtype=torch.float32)
Y = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)
- net = xor_mlp()
+ net = xor_mlp(hidden_dim=dim)
loss_fn = nn.MSELoss()
- optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
+ optimizer = torch.optim.Adam(net.parameters(), lr=0.1)
print(f"Training {name}...")
for epoch in range(1000):
@@ -33,16 +34,18 @@ def train_model(name: str):
return net
if __name__ == "__main__":
- net_a = train_model("Network A")
- net_b = train_model("Network B")
+ torch_net_a = train_model("Network A", 8).eval()
+ torch_net_b = train_model("Network B", 16).eval()
+
+ onnx_net_a = torch.onnx.export(torch_net_a, (torch.randn(1, 2),), verbose=False, dynamo=True).model_proto # type: ignore
+ onnx_net_b = torch.onnx.export(torch_net_b, (torch.randn(1, 2),), verbose=False, dynamo=True).model_proto # type: ignore
+
+ z3_net_a = nneq.net(onnx_net_a)
+ z3_net_b = nneq.net(onnx_net_b)
- z3_net_a = nneq.net(net_a, (2,))
- z3_net_b = nneq.net(net_b, (2,))
-
print("")
nneq.strict_equivalence(z3_net_a, z3_net_b)
print("")
nneq.epsilon_equivalence(z3_net_a, z3_net_b, 0.1)
print("")
nneq.argmax_equivalence(z3_net_a, z3_net_b)
-