diff options
| author | ericmarin <maarin.eric@gmail.com> | 2026-03-16 19:36:31 +0100 |
|---|---|---|
| committer | ericmarin <maarin.eric@gmail.com> | 2026-03-17 17:27:47 +0100 |
| commit | 5ff90e94c9bb411a0262a8130a6f0ce4125ca11b (patch) | |
| tree | 80103130dae1d4bfa4cee6537a72c30777ed6a2d | |
| parent | a0b1e7f6a8c11ed98ae20ac484e2fe9f75b9b85f (diff) | |
| download | vein-5ff90e94c9bb411a0262a8130a6f0ce4125ca11b.tar.gz vein-5ff90e94c9bb411a0262a8130a6f0ce4125ca11b.zip | |
changed torch.fx to ONNX
| -rw-r--r-- | nneq/nneq.py | 279 | ||||
| -rw-r--r-- | notes.norg | 9 | ||||
| -rw-r--r-- | proof.norg | 367 | ||||
| -rw-r--r-- | xor.py | 23 |
4 files changed, 514 insertions, 164 deletions
diff --git a/nneq/nneq.py b/nneq/nneq.py index 22cf171..d9d7d30 100644 --- a/nneq/nneq.py +++ b/nneq/nneq.py @@ -1,8 +1,9 @@ import z3 import re -import torch.fx as fx, torch.nn as nn import numpy as np import subprocess +import onnx +from onnx import numpy_helper from typing import List, Dict __all__ = ["net", "strict_equivalence", "epsilon_equivalence", "argmax_equivalence"] @@ -11,148 +12,130 @@ type inpla_str = str type z3_str = str rules: inpla_str = """ -Linear(x, float q, float r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r); -Concrete(float k) >< Add(out, b) - | k == 0 => out ~ b - | _ => b ~ AddCheckConcrete(out, k); -Linear(y, float s, float t) >< AddCheckLinear(out, x, float q, float r) - | (q == 0) && (r == 0) && (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser - | (s == 0) && (t == 0) => out ~ Linear(x, q, r), y ~ Eraser - | (q == 0) && (r == 0) => out ~ (*L)Linear(y, s, t), x ~ Eraser - | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermAdd(out_x, out_y), 1, 0); -Concrete(float j) >< AddCheckLinear(out, x, float q, float r) => out ~ Linear(x, q, r + j); -Linear(y, float s, float t) >< AddCheckConcrete(out, float k) => out ~ Linear(y, s, t + k); -Concrete(float j) >< AddCheckConcrete(out, float k) - | j == 0 => out ~ Concrete(k) - | _ => out ~ Concrete(k + j); -Linear(x, float q, float r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r); -Concrete(float k) >< Mul(out, b) - | k == 0 => b ~ Eraser, out ~ (*L)Concrete(0) - | k == 1 => out ~ b - | _ => b ~ MulCheckConcrete(out, k); -Linear(y, float s, float t) >< MulCheckLinear(out, x, float q, float r) - | ((q == 0) && (r == 0)) || ((s == 0) && (t == 0)) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser - | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermMul(out_x, out_y), 1, 0); -Concrete(float j) >< MulCheckLinear(out, x, float q, float r) => out ~ Linear(x, q * j, r * j); 
-Linear(y, float s, float t) >< MulCheckConcrete(out, float k) => out ~ Linear(y, s * k, t * k); -Concrete(float j) >< MulCheckConcrete(out, float k) - | j == 0 => out ~ Concrete(0) - | j == 1 => out ~ Concrete(k) - | _ => out ~ Concrete(k * j); -Linear(x, float q, float r) >< ReLU(out) => (*L)Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0); -Concrete(float k) >< ReLU(out) - | k > 0 => out ~ (*L)Concrete(k) - | _ => out ~ Concrete(0); -Linear(x, float q, float r) >< Materialize(out) - | (q == 0) => out ~ Concrete(r), x ~ Eraser - | (q == 1) && (r == 0) => out ~ x - | (q == 1) && (r != 0) => out ~ TermAdd(x, Concrete(r)) - | (q != 0) && (r == 0) => out ~ TermMul(Concrete(q), x) - | _ => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r)); -Concrete(float k) >< Materialize(out) => out ~ (*L)Concrete(k); + Linear(x, float q, float r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r); + Concrete(float k) >< Add(out, b) + | k == 0 => out ~ b + | _ => b ~ AddCheckConcrete(out, k); + Linear(y, float s, float t) >< AddCheckLinear(out, x, float q, float r) + | (q == 0) && (r == 0) && (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser + | (s == 0) && (t == 0) => out ~ Linear(x, q, r), y ~ Eraser + | (q == 0) && (r == 0) => out ~ (*L)Linear(y, s, t), x ~ Eraser + | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermAdd(out_x, out_y), 1, 0); + Concrete(float j) >< AddCheckLinear(out, x, float q, float r) => out ~ Linear(x, q, r + j); + Linear(y, float s, float t) >< AddCheckConcrete(out, float k) => out ~ Linear(y, s, t + k); + Concrete(float j) >< AddCheckConcrete(out, float k) + | j == 0 => out ~ Concrete(k) + | _ => out ~ Concrete(k + j); + Linear(x, float q, float r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r); + Concrete(float k) >< Mul(out, b) + | k == 0 => b ~ Eraser, out ~ (*L)Concrete(0) + | k == 1 => out ~ b + | _ => b ~ MulCheckConcrete(out, k); + Linear(y, float s, 
float t) >< MulCheckLinear(out, x, float q, float r) + | ((q == 0) && (r == 0)) || ((s == 0) && (t == 0)) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser + | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermMul(out_x, out_y), 1, 0); + Concrete(float j) >< MulCheckLinear(out, x, float q, float r) => out ~ Linear(x, q * j, r * j); + Linear(y, float s, float t) >< MulCheckConcrete(out, float k) => out ~ Linear(y, s * k, t * k); + Concrete(float j) >< MulCheckConcrete(out, float k) + | j == 0 => out ~ Concrete(0) + | j == 1 => out ~ Concrete(k) + | _ => out ~ Concrete(k * j); + Linear(x, float q, float r) >< ReLU(out) => (*L)Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0); + Concrete(float k) >< ReLU(out) + | k > 0 => out ~ (*L)Concrete(k) + | _ => out ~ Concrete(0); + Linear(x, float q, float r) >< Materialize(out) + | (q == 0) => out ~ Concrete(r), x ~ Eraser + | (q == 1) && (r == 0) => out ~ x + | (q == 1) && (r != 0) => out ~ TermAdd(x, Concrete(r)) + | (q != 0) && (r == 0) => out ~ TermMul(Concrete(q), x) + | _ => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r)); + Concrete(float k) >< Materialize(out) => out ~ (*L)Concrete(k); """ -class NameGen: - def __init__(self): - self.counter = 0 - def next(self) -> str: - name = f"v{self.counter}" - self.counter += 1 - return name - -def inpla_export(model: nn.Module, input_shape: tuple) -> inpla_str: - traced = fx.symbolic_trace(model) - name_gen = NameGen() - script: List[str] = [] - wire_map: Dict[str, List[str]] = {} - - for node in traced.graph.nodes: - if node.op == 'placeholder': - num_inputs = int(np.prod(input_shape)) - wire_map[node.name] = [f"Linear(Symbolic(X_{i}), 1.0, 0.0)" for i in range(num_inputs)] - - elif node.op == 'call_module': - target_str = str(node.target) - module = dict(model.named_modules())[target_str] - - input_node = node.args[0] - if not isinstance(input_node, fx.Node): - continue - input_wires = 
wire_map[input_node.name] - - if isinstance(module, nn.Flatten): - wire_map[node.name] = input_wires - - elif isinstance(module, nn.Linear): - W = (module.weight.data.detach().cpu().numpy()).astype(float) - B = (module.bias.data.detach().cpu().numpy()).astype(float) - out_dim, in_dim = W.shape - - neuron_wires = [f"Concrete({B[j]})" for j in range(out_dim)] - - for i in range(in_dim): - in_term = input_wires[i] - if out_dim == 1: - weight = float(W[0, i]) - if weight == 0: - script.append(f"Eraser ~ {in_term};") - elif weight == 1: - new_s = name_gen.next() - script.append(f"Add({new_s}, {in_term}) ~ {neuron_wires[0]};") - neuron_wires[0] = new_s - else: - mul_out = name_gen.next() - new_s = name_gen.next() - script.append(f"Mul({mul_out}, Concrete({weight})) ~ {in_term};") - script.append(f"Add({new_s}, {mul_out}) ~ {neuron_wires[0]};") - neuron_wires[0] = new_s - else: - branch_wires = [name_gen.next() for _ in range(out_dim)] - - def nest_dups(names: List[str]) -> str: - if len(names) == 1: return names[0] - if len(names) == 2: return f"Dup({names[0]}, {names[1]})" - return f"Dup({names[0]}, {nest_dups(names[1:])})" - - script.append(f"{nest_dups(branch_wires)} ~ {in_term};") - - for j in range(out_dim): - weight = float(W[j, i]) - if weight == 0: - script.append(f"Eraser ~ {branch_wires[j]};") - elif weight == 1: - new_s = name_gen.next() - script.append(f"Add({new_s}, {branch_wires[j]}) ~ {neuron_wires[j]};") - neuron_wires[j] = new_s - else: - mul_out = name_gen.next() - new_s = name_gen.next() - script.append(f"Mul({mul_out}, Concrete({weight})) ~ {branch_wires[j]};") - script.append(f"Add({new_s}, {mul_out}) ~ {neuron_wires[j]};") - neuron_wires[j] = new_s - - wire_map[node.name] = neuron_wires - - elif isinstance(module, nn.ReLU): - output_wires = [] - for i, w in enumerate(input_wires): - r_out = name_gen.next() - script.append(f"ReLU({r_out}) ~ {w};") - output_wires.append(r_out) - wire_map[node.name] = output_wires - - elif node.op == 'output': - 
output_node = node.args[0] - if isinstance(output_node, fx.Node): - final_wires = wire_map[output_node.name] - for i, w in enumerate(final_wires): - res_name = f"result{i}" - script.append(f"Materialize({res_name}) ~ {w};") - script.append(f"{res_name};") - - return "\n".join(script) +def inpla_export(model: onnx.ModelProto) -> inpla_str: + class NameGen: + def __init__(self): + self.counter = 0 + def next(self) -> str: + name = f"v{self.counter}" + self.counter += 1 + return name + + def get_initializers(graph) -> Dict[str, np.ndarray]: + initializers = {} + for init in graph.initializer: + initializers[init.name] = numpy_helper.to_array(init) + return initializers + + def get_attrs(node: onnx.NodeProto) -> Dict: + return {attr.name: onnx.helper.get_attribute_value(attr) for attr in node.attribute} + + def get_dim(name): + for i in list(graph.input) + list(graph.output) + list(graph.value_info): + if i.name == name: return i.type.tensor_type.shape.dim[-1].dim_value + return None + + def nest_dups(terms: List[str]) -> str: + if not terms: return "Eraser" + if len(terms) == 1: return terms[0] + return f"Dup({nest_dups(terms[:len(terms)//2])}, {nest_dups(terms[len(terms)//2:])})" + + def op_gemm(node): + attrs = get_attrs(node) + W = initializers[node.input[1]] + if not attrs.get("transB", 0): W = W.T + out_dim, in_dim = W.shape + B = initializers[node.input[2]] if len(node.input) > 2 else np.zeros(out_dim) + alpha, beta = attrs.get("alpha", 1.0), attrs.get("beta", 1.0) + + if node.input[0] not in interactions: interactions[node.input[0]] = [[] for _ in range(in_dim)] + + out_terms = interactions.get(node.output[0]) or [[] for _ in range(out_dim)] + + for j in range(out_dim): + chain = nest_dups(out_terms[j]) + for i in range(in_dim): + weight = float(alpha * W[j, i]) + if weight == 0: interactions[node.input[0]][i].append("Eraser") + else: + v = name_gen.next() + chain, term = f"Add({chain}, {v})", f"Mul({v}, Concrete({weight}))" + 
interactions[node.input[0]][i].append(term) + yield f"{chain} ~ Concrete({float(beta * B[j])});" + + def op_relu(node): + out_name, in_name = node.output[0], node.input[0] + if out_name in interactions: + dim = len(interactions[out_name]) + if in_name not in interactions: interactions[in_name] = [[] for _ in range(dim)] + for i in range(dim): + interactions[in_name][i].append(f"ReLU({nest_dups(interactions[out_name][i])})") + yield from [] + + graph, initializers, name_gen = model.graph, get_initializers(model.graph), NameGen() + interactions: Dict[str, List[List[str]]] = {} + ops = {"Gemm": op_gemm, "Relu": op_relu} + + if graph.output: + out = graph.output[0].name + dim = get_dim(out) + if dim: interactions[out] = [[f"Materialize(result{i})"] for i in range(dim)] + + node_script = [] + for node in reversed(graph.node): + if node.op_type in ops: node_script.extend(ops[node.op_type](node)) + + input_script = [] + if graph.input and graph.input[0].name in interactions: + for i, terms in enumerate(interactions[graph.input[0].name]): + input_script.append(f"{nest_dups(terms)} ~ Linear(Symbolic(X_{i}), 1.0, 0.0);") + + result_lines = [f"result{i};" for i in range(len(interactions.get(graph.output[0].name, [])))] + return "\n".join(input_script + list(reversed(node_script)) + result_lines) def inpla_run(model: inpla_str) -> z3_str: + print(model) return subprocess.run(["./inpla"], input=f"{rules}\n{model}", capture_output=True, text=True).stdout syms = {} @@ -168,11 +151,11 @@ def TermMul(a, b): return a * b def TermReLU(x): return z3.If(x > 0, x, 0) context = { - 'Concrete': Concrete, - 'Symbolic': Symbolic, - 'TermAdd': TermAdd, - 'TermMul': TermMul, - 'TermReLU': TermReLU + 'Concrete': Concrete, + 'Symbolic': Symbolic, + 'TermAdd': TermAdd, + 'TermMul': TermMul, + 'TermReLU': TermReLU } wrap = re.compile(r"Symbolic\((.*?)\)") @@ -181,8 +164,8 @@ def z3_evaluate(model: z3_str): model = wrap.sub(r'Symbolic("\1")', model); return eval(model, context) -def net(model: 
nn.Module, input_shape: tuple): - return z3_evaluate(inpla_run(inpla_export(model, input_shape))) +def net(model: onnx.ModelProto): + return z3_evaluate(inpla_run(inpla_export(model))) def strict_equivalence(net_a, net_b): @@ -210,7 +193,7 @@ def epsilon_equivalence(net_a, net_b, epsilon): for sym in syms.values(): solver.add(z3.Or(sym == 0, sym == 1)) - + solver.add(z3.Abs(net_a - net_b) > epsilon) result = solver.check() @@ -230,7 +213,7 @@ def argmax_equivalence(net_a, net_b): for sym in syms.values(): solver.add(z3.Or(sym == 0, sym == 1)) - + solver.add((net_a > 0.5) != (net_b > 0.5)) result = solver.check() @@ -1,19 +1,16 @@ @document.meta title: Neural Network Equivalence -description: WIP tool to prove NNEQ using Interaction Nets as pre-processor +description: WIP tool to prove NNEQ using Interaction Nets as pre-processor for my Bachelor's Thesis authors: ericmarin categories: research created: 2026-03-14T09:21:24 -updated: 2026-03-14T18:34:04 +updated: 2026-03-17T11:18:11 version: 1.1.1 @end * TODO - (?) Scalability %Maybe done? 
I have increased the limits of Inpla, but I have yet to test% - - ( ) Soundness of translated NN - ~~ Define the semantic of the Agents (give a mathematical definition) - ~~ Prove that a Layer L and the Inpla translation represent the same function - ~~ Prove that each Interaction Rules preserve the mathematical semantic of the output + - (x) Soundness of translated NN - ( ) Compatibility with other types of NN - ( ) Comparison with other tool ({https://github.com/NeuralNetworkVerification/Marabou}[Marabou], {https://github.com/guykatzz/ReluplexCav2017}[Reluplex]) - ( ) Add Range agent to enable ReLU optimization diff --git a/proof.norg b/proof.norg new file mode 100644 index 0000000..4ec18ce --- /dev/null +++ b/proof.norg @@ -0,0 +1,367 @@ +@document.meta +title: proof +description: +authors: ericmarin +categories: +created: 2026-03-16T11:34:52 +updated: 2026-03-16T18:31:41 +version: 1.1.1 +@end + +* Proof for translation from Pytorch representation to Interaction Net graph + + +* Proof for the Interaction Rules +** Mathematical Definitions + - Linear(x, q, r) = q*x + r %with q,r Real% + - Concrete(k) = k %with k Real% + - Add(a, b) = a + b + - AddCheckLinear(x, q, r, b) = q*x + (r + b) %with q,r Real% + - AddCheckConcrete(k, b) = k + b %with k Real% + - Mul(a, b) = a * b + - MulCheckLinear(x, q, r, b) = q*b*x + r*b %with q,r Real% + - MulCheckConcrete(k, b) = k*b %with k Real% + - ReLU(x) = IF (x > 0) THEN x ELSE 0 + - Materialize(x) = x + +** Rules +*** Formatting + Agent1 >< Agent2 => Wiring + + LEFT SIDE MATHEMATICAL INTERPRETATION + + RIGHT SIDE MATHEMATICAL INTERPRETATION + + SHOWING EQUIVALENCE + +*** Materialize + The Materialize agent transforms a Linear agent into a tree of explicit mathematical operations + that are used as final representation for the solver. 
+ In the Python module the terms are defined as: + @code python + def TermAdd(a, b): + return a + b + def TermMul(a, b): + return a * b + def TermReLU(x): + return z3.If(x > 0, x, 0) + @end +**** Linear(x, q, r) >< Materialize(out) => (1), (2), (3), (4), (5) + + Linear(x, q, r) = term + Materialize(term) = out + out = q*x + r + + $$ Case 1: q = 0 => out ~ Concrete(r), x ~ Eraser + Concrete(r) = out + out = r + + 0*x + r = r => r = r + $$ + + $$ Case 2: q = 1, r = 0 => out ~ x + x = out + out = x + + 1*x + 0 = x => x = x + $$ + + $$ Case 3: q = 1 => out ~ TermAdd(x, Concrete(r)) + TermAdd(x, Concrete(r)) = out + out = x + r + + 1*x + r = x + r => x + r = x + r + $$ + + $$ Case 4: r = 0 => out ~ TermMul(Concrete(q), x) + TermMul(Concrete(q), x) = out + out = q*x + + q*x + 0 = q*x => q*x = q*x + $$ + + $$ Case 5: otherwise => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r)) + TermAdd(TermMul(Concrete(q), x), r) = out + out = q*x + r + + q*x + r = q*x + r + $$ + +**** Concrete(k) >< Materialize(out) => out ~ Concrete(k) + + Concrete(k) = term + Materialize(term) = out + out = k + + Concrete(k) = out + out = k + + k = k + +*** Add +**** Linear(x, q, r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r) + + Linear(x, q, r) = a + Add(a, b) = out + out = q*x + r + b + + AddCheckLinear(x, q, r, b) = out + out = q*x + (r + b) + + q*x + r + b = q*x + (r + b) => q*x + (r + b) = q*x + (r + b) + +**** Concrete(k) >< Add(out, b) => (1), (2) + + Concrete(k) = a + Add(a, b) = out + out = k + b + + $$ Case 1: k = 0 => out ~ b + b = out + out = b + + 0 + b = b => b = b + $$ + + $$ Case 2: otherwise => b ~ AddCheckConcrete(out, k) + AddCheckConcrete(k, b) = out + out = k + b + + k + b = k + b + $$ + +**** Linear(y, s, t) >< AddCheckLinear(out, x, q, r) => (1), (2), (3), (4) + + Linear(y, s, t) = b + AddCheckLinear(x, q, r, b) = out + out = q*x + (r + s*y + t) + + $$ Case 1: q,r,s,t = 0 => out ~ Concrete(0), x ~ Eraser, y ~ Eraser + Concrete(0) = out + out = 0 + + 0*x + (0 + 0*y + 
0) = 0 => 0 = 0 + $$ + + $$ Case 2: s,t = 0 => out ~ Linear(x, q, r), y ~ Eraser + Linear(x, q, r) = out + out = q*x + r + + q*x + (r + 0*y + 0) = q*x + r => q*x + r = q*x + r + $$ + + $$ Case 3: q, r = 0 => out ~ Linear(y, s, t), x ~ Eraser + Linear(y, s, t) = out + out = s*y + t + + 0*x + (0 + s*y + t) = s*y + t => s*y + t = s*y + t + $$ + + $$ Case 4: otherwise => Linear(x, q, r) ~ Materialize(out_x), Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermAdd(out_x, out_y), 1, 0) + Materialize(Linear(x, q, r)) = out_x + Materialize(Linear(y, s, t)) = out_y + Linear(TermAdd(out_x, out_y), 1, 0) = out + out_x = q*x + r + out_y = s*y + t + out = 1*TermAdd(q*x + r, s*y + t) + 0 + Because TermAdd(a, b) is defined as "a+b": + out = 1*(q*x + r + s*y + t) + 0 + + q*x + (r + s*y + t) = 1*(q*x + r + s*y + t) + 0 => q*x + r + s*y + t = q*x + r + s*y + t + $$ + +**** Concrete(j) >< AddCheckLinear(out, x, q, r) => out ~ Linear(x, q, r + j) + + Concrete(j) = b + AddCheckLinear(x, q, r, b) = out + out = q*x + (r + j) + + Linear(x, q, r + j) = out + out = q*x + (r + j) + + q*x + (r + j) = q*x + (r + j) + +**** Linear(y, s, t) >< AddCheckConcrete(out, k) => out ~ Linear(y, s, t + k) + + Linear(y, s, t) = b + AddCheckConcrete(k, b) = out + out = k + s*y + t + + Linear(y, s, t + k) + out = s*y + (t + k) + + k + s*y + t = s*y + (t + k) => s*y + (t + k) = s*y + (t + k) + +**** Concrete(j) >< AddCheckConcrete(out, k) => (1), (2) + + Concrete(j) = b + AddCheckConcrete(k, b) = out + out = k + j + + $$ Case 1: j = 0 => out ~ Concrete(k) + Concrete(k) = out + out = k + + k + 0 = k => k = k + $$ + + $$ Case 2: otherwise => out ~ Concrete(k + j) + Concrete(k + j) = out + out = k + j + + k + j = k + j + $$ + +*** Mul +**** Linear(x, q, r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r) + + Linear(x, q, r) = a + Mul(a, b) = out + out = (q*x + r) * b + + MulCheckLinear(x, q, r, b) = out + out = q*b*x + r*b + + (q*x + r) * b = q*b*x + r*b => q*b*x + r*b = q*b*x + r*b + +**** Concrete(k) >< 
Mul(out, b) => (1), (2), (3) + + Concrete(k) = a + Mul(a, b) = out + out = k * b + + $$ Case 1: k = 0 => out ~ Concrete(0), b ~ Eraser + Concrete(0) = out + out = 0 + + 0 * b = 0 => 0 = 0 + $$ + + $$ Case 2: k = 1 => out ~ b + b = out + out = b + + 1 * b = b => b = b + $$ + + $$ Case 3: otherwise => b ~ MulCheckConcrete(out, k) + MulCheckConcrete(k, b) = out + out = k * b + + k * b = k * b + $$ + +**** Linear(y, s, t) >< MulCheckLinear(out, x, q, r) => (1), (2) + + Linear(y, s, t) = b + MulCheckLinear(x, q, r, b) = out + out = q\*(s*y + t)\*x + r*(s*y + t) + + $$ Case 1: (q,r = 0) or (s,t = 0) => x ~ Eraser, y ~ Eraser, out ~ Concrete(0) + Concrete(0) = out + out = 0 + + 0\*(s*y + t)\*x + 0*(s*y + t) = 0 => 0 = 0 + or + q\*(0*y + 0)\*x + r*(0*y + 0) = 0 => 0 = 0 + $$ + + $$ Case 2: otherwise => Linear(x, q, r) ~ Materialize(out_x), Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermMul(out_x, out_y), 1, 0) + Materialize(Linear(x, q, r)) = out_x + Materialize(Linear(y, s, t)) = out_y + Linear(TermMul(out_x, out_y), 1, 0) = out + out_x = q*x + r + out_y = s*y + t + out = 1*TermMul(q*x + r, s*y + t) + 0 + Because TermMul(a, b) is defined as "a*b": + out = 1*(q*x + r)*(s*y + t) + 0 + + q*(s*y + t)\*x + r*(s*y + t) = 1*(q*x + r)\*(s*y + t) => + q\*(s*y + t)\*x + r*(s*y + t) = (q*x + r)\*(s*y + t) => + q\*(s*y + t)\*x + r*(s*y + t) = q\*(s*y + t)\*x + r*(s*y + t) + $$ + +**** Concrete(j) >< MulCheckLinear(out, x, q, r) => out ~ Linear(x, q * j, r * j) + + Concrete(j) = b + MulCheckLinear(x, q, r, b) = out + out = q*j*x + r*j + + Linear(x, q * j, r * j) = out + out = q*j*x + r*j + + q*j*x + r*j = q*j*x + r*j + +**** Linear(y, s, t) >< MulCheckConcrete(out, k) => out ~ Linear(y, s * k, t * k) + + Linear(y, s, t) = b + MulCheckConcrete(k, b) = out + out = k * (s*y + t) + + Linear(y, s * k, t * k) = out + out = s*k*y + t*k + + k * (s*y + t) = s*k*y + t*k => s*k*y + t*k = s*k*y + t*k + +**** Concrete(j) >< MulCheckConcrete(out, k) => (1), (2), (3) + + Concrete(j) = b + 
MulCheckConcrete(k, b) = out + out = k * j + + $$ Case 1: j = 0 => out ~ Concrete(0) + Concrete(0) = out + out = 0 + + k * 0 = 0 => 0 = 0 + $$ + + $$ Case 2: j = 1 => out ~ Concrete(k) + Concrete(k) = out + out = k + + k * 1 = k => k = k + $$ + + $$ Case 3: otherwise => out ~ Concrete(k * j) + Concrete(k * j) = out + out = k * j + + k * j = k * j + +*** ReLU +**** Linear(x, q, r) >< ReLU(out) => Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0) + + Linear(x, q, r) = x + ReLU(x) = out + out = IF (q*x + r) > 0 THEN (q*x + r) ELSE 0 + + Materialize(Linear(x, q, r)) = out_x + Linear(TermReLU(out_x), 1, 0) = out + out_x = q*x + r + out = 1*TermReLU(q*x + r) + 0 + Because TermReLU(x) is defined as "z3.If(x > 0, x, 0)": + out = 1*(IF (q*x + r) > 0 THEN (q*x + r) ELSE 0) + 0 + + IF (q*x + r) > 0 THEN (q*x + r) ELSE 0 = 1*(IF (q*x + r) > 0 THEN (q*x + r) ELSE 0) + 0 => + IF (q*x + r) > 0 THEN (q*x + r) ELSE 0 = IF (q*x + r) > 0 THEN (q*x + r) ELSE 0 + +**** Concrete(k) >< ReLU(out) => (1), (2) + + Concrete(k) = x + ReLU(x) = out + out = IF k > 0 THEN k ELSE 0 + + $$ Case 1: k > 0 => out ~ Concrete(k) + Concrete(k) = out + out = k + + IF true THEN k ELSE 0 = k => k = k + $$ + + $$ Case 2: k <= 0 => out ~ Concrete(0) + Concrete(0) = out + out = 0 + + IF false THEN k ELSE 0 = 0 => 0 = 0 + $$ @@ -1,9 +1,10 @@ import torch import torch.nn as nn +import torch.onnx import nneq class xor_mlp(nn.Module): - def __init__(self, hidden_dim=4): + def __init__(self, hidden_dim=8): super().__init__() self.layers = nn.Sequential( nn.Linear(2, hidden_dim), @@ -13,13 +14,13 @@ class xor_mlp(nn.Module): def forward(self, x): return self.layers(x) -def train_model(name: str): +def train_model(name: str, dim): X = torch.tensor([[0,0], [0,1], [1,0], [1,1]], dtype=torch.float32) Y = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32) - net = xor_mlp() + net = xor_mlp(hidden_dim=dim) loss_fn = nn.MSELoss() - optimizer = torch.optim.Adam(net.parameters(), lr=0.01) + 
optimizer = torch.optim.Adam(net.parameters(), lr=0.1) print(f"Training {name}...") for epoch in range(1000): @@ -33,16 +34,18 @@ def train_model(name: str): return net if __name__ == "__main__": - net_a = train_model("Network A") - net_b = train_model("Network B") + torch_net_a = train_model("Network A", 8).eval() + torch_net_b = train_model("Network B", 16).eval() + + onnx_net_a = torch.onnx.export(torch_net_a, (torch.randn(1, 2),), verbose=False, dynamo=True).model_proto # type: ignore + onnx_net_b = torch.onnx.export(torch_net_b, (torch.randn(1, 2),), verbose=False, dynamo=True).model_proto # type: ignore + + z3_net_a = nneq.net(onnx_net_a) + z3_net_b = nneq.net(onnx_net_b) - z3_net_a = nneq.net(net_a, (2,)) - z3_net_b = nneq.net(net_b, (2,)) - print("") nneq.strict_equivalence(z3_net_a, z3_net_b) print("") nneq.epsilon_equivalence(z3_net_a, z3_net_b, 0.1) print("") nneq.argmax_equivalence(z3_net_a, z3_net_b) - |
