aboutsummaryrefslogtreecommitdiff
path: root/nneq.py
diff options
context:
space:
mode:
authorericmarin <maarin.eric@gmail.com>2026-03-21 11:47:40 +0100
committerericmarin <maarin.eric@gmail.com>2026-03-21 12:00:16 +0100
commite2abe9d9ec649b849cc39b516c1db1b4fa592003 (patch)
treed74dcc2e0691bb587d2a9a695639517d3aec9256 /nneq.py
parentaf4335cf47984576e7493a0eb6569d3f6ecc31c8 (diff)
downloadvein-e2abe9d9ec649b849cc39b516c1db1b4fa592003.tar.gz
vein-e2abe9d9ec649b849cc39b516c1db1b4fa592003.zip
created class
Diffstat (limited to 'nneq.py')
-rw-r--r--nneq.py274
1 files changed, 274 insertions, 0 deletions
diff --git a/nneq.py b/nneq.py
new file mode 100644
index 0000000..660bcce
--- /dev/null
+++ b/nneq.py
@@ -0,0 +1,274 @@
+import z3
+import re
+import numpy as np
+import subprocess
+import onnx
+from onnx import numpy_helper
+from typing import List, Dict, Optional, Tuple
+import os
+import tempfile
+import hashlib
+
# Inpla interaction-net reduction rules prepended to every exported network
# (see inpla_run).  Agent vocabulary:
#   Linear(x, q, r)   -- symbolic affine term q*x + r over wire x
#   Concrete(k)       -- a known numeric constant
#   Add/Mul/ReLU      -- operator agents emitted by inpla_export
#   Materialize(out)  -- forces a Linear/Concrete into an explicit term tree
#   TermAdd/TermMul/TermReLU -- residual constructors later parsed by z3_evaluate
#   Eraser            -- discards an unused wire
# The AddCheck*/MulCheck* agents implement two-step dispatch: the first
# operand re-packages itself so the rule for the second operand can constant
# fold (k == 0, k == 1, ...) or apply identity/annihilator shortcuts before
# falling back to materialized TermAdd/TermMul construction.
rules = """
    Linear(x, float q, float r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r);
    Concrete(float k) >< Add(out, b)
        | k == 0 => out ~ b
        | _ => b ~ AddCheckConcrete(out, k);
    Linear(y, float s, float t) >< AddCheckLinear(out, x, float q, float r)
        | (q == 0) && (r == 0) && (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
        | (s == 0) && (t == 0) => out ~ Linear(x, q, r), y ~ Eraser
        | (q == 0) && (r == 0) => out ~ (*L)Linear(y, s, t), x ~ Eraser
        | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermAdd(out_x, out_y), 1, 0);
    Concrete(float j) >< AddCheckLinear(out, x, float q, float r) => out ~ Linear(x, q, r + j);
    Linear(y, float s, float t) >< AddCheckConcrete(out, float k) => out ~ Linear(y, s, t + k);
    Concrete(float j) >< AddCheckConcrete(out, float k)
        | j == 0 => out ~ Concrete(k)
        | _ => out ~ Concrete(k + j);
    Linear(x, float q, float r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r);
    Concrete(float k) >< Mul(out, b)
        | k == 0 => b ~ Eraser, out ~ (*L)Concrete(0)
        | k == 1 => out ~ b
        | _ => b ~ MulCheckConcrete(out, k);
    Linear(y, float s, float t) >< MulCheckLinear(out, x, float q, float r)
        | ((q == 0) && (r == 0)) || ((s == 0) && (t == 0)) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
        | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermMul(out_x, out_y), 1, 0);
    Concrete(float j) >< MulCheckLinear(out, x, float q, float r) => out ~ Linear(x, q * j, r * j);
    Linear(y, float s, float t) >< MulCheckConcrete(out, float k) => out ~ Linear(y, s * k, t * k);
    Concrete(float j) >< MulCheckConcrete(out, float k)
        | j == 0 => out ~ Concrete(0)
        | j == 1 => out ~ Concrete(k)
        | _ => out ~ Concrete(k * j);
    Linear(x, float q, float r) >< ReLU(out) => (*L)Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0);
    Concrete(float k) >< ReLU(out)
        | k > 0 => out ~ (*L)Concrete(k)
        | _ => out ~ Concrete(0);
    Linear(x, float q, float r) >< Materialize(out)
        | (q == 0) => out ~ Concrete(r), x ~ Eraser
        | (q == 1) && (r == 0) => out ~ x
        | (q == 1) && (r != 0) => out ~ TermAdd(x, Concrete(r))
        | (q != 0) && (r == 0) => out ~ TermMul(Concrete(q), x)
        | _ => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r));
    Concrete(float k) >< Materialize(out) => out ~ (*L)Concrete(k);
"""

# Memoizes inpla stdout keyed on (sha256 of the serialized model, frozen
# bounds) so each (model, bounds) pair is exported and reduced only once.
_INPLA_CACHE: Dict[Tuple[str, Optional[Tuple]], str] = {}
+
def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]]] = None) -> str:
    """Compile a feed-forward ONNX model into an Inpla net script.

    Only ``Gemm`` and ``Relu`` operators are supported; anything else raises
    ``RuntimeError``.  The graph is walked output-to-input (nodes in reverse
    order) while ``interactions`` accumulates, per tensor name and per neuron
    index, the list of agent terms that *consume* that neuron's value.  At the
    end the graph inputs are bound to symbolic ``Linear(Symbolic(X_i), 1, 0)``
    agents and one ``result<i>;`` query line is emitted per output neuron.

    :param model:  the ONNX model to translate.
    :param bounds: per-input [lower, upper] bounds; currently ignored.
    :return:       the Inpla script (net statements + result queries).
    """
    _ = bounds # TODO: Add Range agent
    class NameGen:
        # Fresh-name generator: "v0", "v1", ... (or "w0", "w1", ... etc.).
        def __init__(self, prefix="v"):
            self.counter = 0
            self.prefix = prefix
        def next(self) -> str:
            name = f"{self.prefix}{self.counter}"
            self.counter += 1
            return name

    def get_initializers(graph) -> Dict[str, np.ndarray]:
        # Map initializer name -> numpy array (weights / biases).
        initializers = {}
        for init in graph.initializer:
            initializers[init.name] = numpy_helper.to_array(init)
        return initializers

    def get_attrs(node: onnx.NodeProto) -> Dict:
        # Node attributes as a plain {name: value} dict.
        return {attr.name: onnx.helper.get_attribute_value(attr) for attr in node.attribute}

    def get_dim(name):
        # Last-axis dimension of a declared tensor, or None if not declared.
        for i in list(graph.input) + list(graph.output) + list(graph.value_info):
            if i.name == name: return i.type.tensor_type.shape.dim[-1].dim_value
        return None

    def flatten_nest(agent_name: str, terms: List[str]) -> str:
        # Fold N consumer terms into a chain of binary agents (e.g. Dup),
        # returning the single wire feeding the chain.  Zero consumers means
        # the value is discarded (Eraser); one consumer needs no chain.
        if not terms: return "Eraser"
        if len(terms) == 1: return terms[0]
        current = terms[0]
        for i in range(1, len(terms)):
            wire = wire_gen.next()
            script.append(f"{wire} ~ {agent_name}({current}, {terms[i]});")
            current = wire
        return current

    def op_gemm(node):
        # Encode out = alpha * (X @ W^T) + beta * B, one Add/Mul chain per
        # output neuron j, appending each neuron-i contribution to the
        # consumer lists of the node's input tensor.
        attrs = get_attrs(node)
        W = initializers[node.input[1]]
        # ONNX Gemm: transB=1 means W is stored (out_dim, in_dim) already.
        if not attrs.get("transB", 0): W = W.T
        out_dim, in_dim = W.shape
        B = initializers[node.input[2]] if len(node.input) > 2 else np.zeros(out_dim)
        alpha, beta = attrs.get("alpha", 1.0), attrs.get("beta", 1.0)

        if node.input[0] not in interactions:
            interactions[node.input[0]] = [[] for _ in range(in_dim)]

        # Existing downstream consumers of this node's output; falls back to
        # fresh Materialize(result<j>) terms when none were registered
        # (NOTE(review): presumably only when the graph output had no
        # declared dim -- confirm; also triggers on an empty consumer list).
        out_terms = interactions.get(node.output[0]) or [[f"Materialize(result{j})"] for j in range(out_dim)]

        for j in range(out_dim):
            # Duplicate neuron j's value to all of its consumers ...
            chain = flatten_nest("Dup", out_terms[j])
            for i in range(in_dim):
                weight = float(alpha * W[j, i])
                if weight == 0:
                    # Zero weight: input i contributes nothing to neuron j.
                    interactions[node.input[0]][i].append("Eraser")
                else:
                    v = var_gen.next()
                    wire = wire_gen.next()
                    script.append(f"{wire} ~ Add({chain}, {v});")
                    chain = wire
                    interactions[node.input[0]][i].append(f"Mul({v}, Concrete({weight}))")

            # Terminate the Add chain with the (scaled) bias.
            script.append(f"{chain} ~ Concrete({float(beta * B[j])});")
        # Empty generator: lets the dispatch loop below drive all ops uniformly.
        yield from []

    def op_relu(node):
        # Interpose a ReLU agent between this node's input and every
        # downstream consumer of its output, element-wise.
        out_name, in_name = node.output[0], node.input[0]
        if out_name in interactions:
            dim = len(interactions[out_name])
            if in_name not in interactions:
                interactions[in_name] = [[] for _ in range(dim)]
            for i in range(dim):
                dup_chain = flatten_nest("Dup", interactions[out_name][i])
                v = var_gen.next()
                interactions[in_name][i].append(f"ReLU({v})")
                script.append(f"{v} ~ {dup_chain};")
        yield from []

    graph, initializers = model.graph, get_initializers(model.graph)
    var_gen, wire_gen = NameGen("v"), NameGen("w")
    # tensor name -> per-neuron list of agent terms consuming that neuron.
    interactions: Dict[str, List[List[str]]] = {}
    script = []
    ops = {"Gemm": op_gemm, "Relu": op_relu}

    # Seed the output tensor: each output neuron is materialized into result<i>.
    if graph.output:
        out = graph.output[0].name
        dim = get_dim(out)
        if dim:
            interactions[out] = [[f"Materialize(result{i})"] for i in range(dim)]

    # Reverse traversal: a node's consumers must be registered before the
    # node itself is encoded.
    for node in reversed(graph.node):
        if node.op_type in ops:
            for _ in ops[node.op_type](node): pass
        else:
            raise RuntimeError(f"Unsupported ONNX operator: {node.op_type}")

    # Bind the (first) graph input to symbolic variables X_0..X_{n-1}.
    if graph.input and graph.input[0].name in interactions:
        for i, terms in enumerate(interactions[graph.input[0].name]):
            dup_chain = flatten_nest("Dup", terms)
            script.append(f"{dup_chain} ~ Linear(Symbolic(X_{i}), 1.0, 0.0);")

    # One query line per output neuron so inpla prints the reduced terms.
    result_lines = [f"result{i};" for i in range(len(interactions.get(graph.output[0].name, [])))]
    return "\n".join(script + result_lines)
+
+
def inpla_run(model: str) -> str:
    """Reduce an Inpla net script with the external ``./inpla`` engine.

    The module-level ``rules`` are prepended to *model*, the combined script
    is written to a temporary ``.inpla`` file, and the engine's stdout (the
    reduced result terms, later parsed by ``z3_evaluate``) is returned.

    Engine diagnostics are forwarded to this process's **stderr** (the
    original printed them to stdout, mixing diagnostics into the stream a
    caller might capture alongside the results).  The temporary file is
    always removed, even if the engine fails to launch.

    :param model: the net script produced by ``inpla_export``.
    :return:      raw stdout of the inpla engine.
    """
    import sys  # function-scope: only needed for diagnostics forwarding

    with tempfile.NamedTemporaryFile(mode="w", suffix=".inpla", delete=False) as f:
        f.write(f"{rules}\n{model}")
        temp_path = f.name
    # The file is closed here, so the subprocess sees the full contents.
    try:
        res = subprocess.run(["./inpla", "-f", temp_path], capture_output=True, text=True)
        if res.stderr:
            print(res.stderr, file=sys.stderr)
        return res.stdout
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)
+
+
def z3_evaluate(model: str, X):
    """Parse Inpla result terms into Z3 real-arithmetic expressions.

    Every non-empty line of *model* is expected to be a nested term built
    from Concrete / Symbolic / TermAdd / TermMul / TermReLU and is converted
    into the corresponding Z3 expression.  Symbolic variables are created on
    demand and memoized in the caller-supplied dict *X* (name -> z3.Real),
    so repeated occurrences share one variable.

    :return: one Z3 expression per non-empty input line, in order.
    """
    def _symbolic(name):
        # Memoized lookup: one z3.Real per distinct symbolic name.
        if name not in X:
            X[name] = z3.Real(name)
        return X[name]

    # Dispatch table: term-constructor name -> builder.
    handlers = {
        'Concrete': lambda val: z3.RealVal(val),
        'Symbolic': _symbolic,
        'TermAdd': lambda a, b: a + b,
        'TermMul': lambda a, b: a * b,
        'TermReLU': lambda x: z3.If(x > 0, x, 0),
    }

    def _parse(text):
        # Shift-reduce over a flat token stream: '(' opens an argument frame,
        # ')' pops it and applies the constructor name pushed just before.
        frames = [[]]
        for tok in re.findall(r'\(|\)|\,|[^(),\s]+', text):
            if tok == ',':
                continue
            if tok == '(':
                frames.append([])
            elif tok == ')':
                args = frames.pop()
                name = frames[-1].pop()
                fn = handlers.get(name)
                if not fn:
                    raise ValueError(f"Unknown function: {name}")
                frames[-1].append(fn(*args))
            elif tok in handlers:
                # Constructor name: left on the frame until its ')' reduces it.
                frames[-1].append(tok)
            else:
                # Leaf token: numeric literal if possible, else a bare name
                # (e.g. a Symbolic identifier).
                try:
                    frames[-1].append(float(tok))
                except ValueError:
                    frames[-1].append(tok)
        return frames[0][0]

    return [_parse(raw.strip()) for raw in model.splitlines() if raw.strip()]
+
def net(model: onnx.ModelProto, X, bounds: Optional[Dict[str, List[float]]] = None):
    """Symbolically evaluate *model*, returning its outputs as Z3 expressions.

    The expensive export + Inpla reduction step is memoized in
    ``_INPLA_CACHE`` keyed on (sha256 of the serialized model, frozen bounds);
    only the cheap ``z3_evaluate`` parse is redone per call so that symbolic
    variables land in the caller's *X* dict.

    :param model:  the ONNX model.
    :param X:      dict used by ``z3_evaluate`` to memoize symbolic inputs.
    :param bounds: optional per-input [lower, upper] bounds.
    :return:       list of Z3 output expressions (possibly empty).
    """
    model_hash = hashlib.sha256(model.SerializeToString()).hexdigest()
    # BUGFIX: freeze the [lo, hi] bound lists into tuples.  The previous
    # tuple(sorted(bounds.items())) kept the (unhashable) lists inside the
    # cache key, so any non-empty bounds dict raised TypeError on the dict
    # lookup below.
    bounds_key = tuple(sorted((k, tuple(v)) for k, v in bounds.items())) if bounds else None
    cache_key = (model_hash, bounds_key)

    if cache_key not in _INPLA_CACHE:
        _INPLA_CACHE[cache_key] = inpla_run(inpla_export(model, bounds))

    res = z3_evaluate(_INPLA_CACHE[cache_key], X)
    # z3_evaluate always returns a list; the guard is kept for safety.
    return res if res is not None else []
+
class Solver(z3.Solver):
    """``z3.Solver`` extended with ONNX network and VNN-LIB property loading.

    Networks registered via :meth:`load_onnx` are encoded lazily: the first
    :meth:`check` call runs them through the Inpla pipeline (``net``), binds
    each network output to a fresh real variable ``Y_<i>``, and asserts the
    equalities on this solver.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.X = {}  # symbolic input variables, shared across all loaded nets
        self.Y = {}  # optional network name -> list of its output expressions
        self.bounds: Dict[str, List[float]] = {}  # "X_i" -> [lower, upper] gathered from VNN-LIB
        self.pending_nets: List[Tuple[onnx.ModelProto, Optional[str]]] = []  # loaded but not yet encoded

    def load_vnnlib(self, file_path):
        """Load a VNN-LIB property file.

        Simple per-variable bounds (``(assert (>= X_i c))`` /
        ``(assert (<= X_i c))``) are recorded in ``self.bounds``; the whole
        file is then handed to Z3's SMT-LIB2 parser and its assertions added
        to this solver.
        """
        with open(file_path, "r") as f:
            content = f.read()

        # Harvest axis-aligned input bounds before parsing; these feed the
        # (future) Range agent via `net(..., bounds=...)`.
        for match in re.finditer(r"\(assert\s+\((>=|<=)\s+(X_\d+)\s+([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)\)\)", content):
            op, var, val = match.groups()
            val = float(val)
            if var not in self.bounds: self.bounds[var] = [float('-inf'), float('inf')]
            if op == ">=": self.bounds[var][0] = val
            else: self.bounds[var][1] = val

        # Strip the non-SMT-LIB vnnlib-version directive so Z3 can parse.
        content = re.sub(r"\(vnnlib-version.*?\)", "", content)
        assertions = z3.parse_smt2_string(content)
        self.add(assertions)

    def load_onnx(self, file, name=None):
        """Register an ONNX model; encoding is deferred until check()."""
        model = onnx.load(file)
        self.pending_nets.append((model, name))

    def _process_nets(self):
        """Encode every pending network as Y_<i> == <output expression>.

        ``y_count`` is global across all pending networks, so a second
        network's outputs continue the Y_ numbering of the first.
        """
        y_count = 0
        for model, name in self.pending_nets:
            z3_outputs = net(model, self.X, bounds=self.bounds)
            if z3_outputs:
                for _, out_expr in enumerate(z3_outputs):  # index unused; plain iteration would do
                    y_var = z3.Real(f"Y_{y_count}")
                    self.add(y_var == out_expr)
                    if name:
                        if name not in self.Y: self.Y[name] = []
                        self.Y[name].append(out_expr)
                    y_count += 1
        self.pending_nets = []

    def check(self, *args):
        """Encode any pending networks, then delegate to z3.Solver.check."""
        self._process_nets()
        return super().check(*args)