diff options
Diffstat (limited to 'nneq')
| -rw-r--r-- | nneq/__init__.py | 3 | ||||
| -rw-r--r-- | nneq/nneq.py | 260 |
2 files changed, 0 insertions, 263 deletions
diff --git a/nneq/__init__.py b/nneq/__init__.py deleted file mode 100644 index 8dc7048..0000000 --- a/nneq/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .nneq import * - -__all__ = nneq.__all__ diff --git a/nneq/nneq.py b/nneq/nneq.py deleted file mode 100644 index 3f2106b..0000000 --- a/nneq/nneq.py +++ /dev/null @@ -1,260 +0,0 @@ -import z3 -import re -import numpy as np -import subprocess -import onnx -import ast -from onnx import numpy_helper -from typing import List, Dict - -__all__ = ["net", "strict_equivalence", "epsilon_equivalence", "argmax_equivalence"] - -type inpla_str = str -type z3_str = str - -rules: inpla_str = """ - Linear(x, float q, float r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r); - Concrete(float k) >< Add(out, b) - | k == 0 => out ~ b - | _ => b ~ AddCheckConcrete(out, k); - Linear(y, float s, float t) >< AddCheckLinear(out, x, float q, float r) - | (q == 0) && (r == 0) && (s == 0) && (t == 0) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser - | (s == 0) && (t == 0) => out ~ Linear(x, q, r), y ~ Eraser - | (q == 0) && (r == 0) => out ~ (*L)Linear(y, s, t), x ~ Eraser - | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermAdd(out_x, out_y), 1, 0); - Concrete(float j) >< AddCheckLinear(out, x, float q, float r) => out ~ Linear(x, q, r + j); - Linear(y, float s, float t) >< AddCheckConcrete(out, float k) => out ~ Linear(y, s, t + k); - Concrete(float j) >< AddCheckConcrete(out, float k) - | j == 0 => out ~ Concrete(k) - | _ => out ~ Concrete(k + j); - Linear(x, float q, float r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r); - Concrete(float k) >< Mul(out, b) - | k == 0 => b ~ Eraser, out ~ (*L)Concrete(0) - | k == 1 => out ~ b - | _ => b ~ MulCheckConcrete(out, k); - Linear(y, float s, float t) >< MulCheckLinear(out, x, float q, float r) - | ((q == 0) && (r == 0)) || ((s == 0) && (t == 0)) => out ~ Concrete(0), x ~ Eraser, y ~ Eraser - | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermMul(out_x, out_y), 1, 0); - Concrete(float j) >< MulCheckLinear(out, x, float q, float r) => out ~ Linear(x, q * j, r * j); - Linear(y, float s, float t) >< MulCheckConcrete(out, float k) => out ~ Linear(y, s * k, t * k); - Concrete(float j) >< MulCheckConcrete(out, float k) - | j == 0 => out ~ Concrete(0) - | j == 1 => out ~ Concrete(k) - | _ => out ~ Concrete(k * j); - Linear(x, float q, float r) >< ReLU(out) => (*L)Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0); - Concrete(float k) >< ReLU(out) - | k > 0 => out ~ (*L)Concrete(k) - | _ => out ~ Concrete(0); - Linear(x, float q, float r) >< Materialize(out) - | (q == 0) => out ~ Concrete(r), x ~ Eraser - | (q == 1) && (r == 0) => out ~ x - | (q == 1) && (r != 0) => out ~ TermAdd(x, Concrete(r)) - | (q != 0) && (r == 0) => out ~ TermMul(Concrete(q), x) - | _ => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r)); - Concrete(float k) >< Materialize(out) => out ~ (*L)Concrete(k); -""" - -def inpla_export(model: onnx.ModelProto) -> inpla_str: - class NameGen: - def __init__(self): - self.counter = 0 - def next(self) -> str: - name = f"v{self.counter}" - self.counter += 1 - return name - - def get_initializers(graph) -> Dict[str, np.ndarray]: - initializers = {} - for init in graph.initializer: - initializers[init.name] = numpy_helper.to_array(init) - return initializers - - def get_attrs(node: onnx.NodeProto) -> Dict: - return {attr.name: onnx.helper.get_attribute_value(attr) for attr in node.attribute} - - def get_dim(name): - for i in list(graph.input) + list(graph.output) + list(graph.value_info): - if i.name == name: return i.type.tensor_type.shape.dim[-1].dim_value - return None - - def nest_dups(terms: List[str]) -> str: - if not terms: return "Eraser" - if len(terms) == 1: return terms[0] - return f"Dup({nest_dups(terms[:len(terms)//2])}, {nest_dups(terms[len(terms)//2:])})" - - def op_gemm(node): - attrs = get_attrs(node) - W = initializers[node.input[1]] - if not attrs.get("transB", 0): W = W.T - out_dim, in_dim = W.shape - B = initializers[node.input[2]] if len(node.input) > 2 else np.zeros(out_dim) - alpha, beta = attrs.get("alpha", 1.0), attrs.get("beta", 1.0) - - if node.input[0] not in interactions: interactions[node.input[0]] = [[] for _ in range(in_dim)] - - out_terms = interactions.get(node.output[0]) or [[] for _ in range(out_dim)] - - for j in range(out_dim): - chain = nest_dups(out_terms[j]) - for i in range(in_dim): - weight = float(alpha * W[j, i]) - if weight == 0: interactions[node.input[0]][i].append("Eraser") - else: - v = name_gen.next() - chain, term = f"Add({chain}, {v})", f"Mul({v}, Concrete({weight}))" - interactions[node.input[0]][i].append(term) - yield f"{chain} ~ Concrete({float(beta * B[j])});" - - def op_relu(node): - out_name, in_name = node.output[0], node.input[0] - if out_name in interactions: - dim = len(interactions[out_name]) - if in_name not in interactions: interactions[in_name] = [[] for _ in range(dim)] - for i in range(dim): - interactions[in_name][i].append(f"ReLU({nest_dups(interactions[out_name][i])})") - yield from [] - - def op_flatten(node): - out_name, in_name = node.output[0], node.input[0] - if out_name in interactions: - interactions[in_name] = interactions[out_name] - yield from [] - - graph, initializers, name_gen = model.graph, get_initializers(model.graph), NameGen() - interactions: Dict[str, List[List[str]]] = {} - ops = {"Gemm": op_gemm, "Relu": op_relu, "Flatten": op_flatten} - - if graph.output: - out = graph.output[0].name - dim = get_dim(out) - if dim: interactions[out] = [[f"Materialize(result{i})"] for i in range(dim)] - - node_script = [] - for node in reversed(graph.node): - if node.op_type in ops: node_script.extend(ops[node.op_type](node)) - - input_script = [] - if graph.input and graph.input[0].name in interactions: - for i, terms in enumerate(interactions[graph.input[0].name]): - input_script.append(f"{nest_dups(terms)} ~ Linear(Symbolic(X_{i}), 1.0, 0.0);") - - result_lines = [f"result{i};" for i in range(len(interactions.get(graph.output[0].name, [])))] - return "\n".join(input_script + list(reversed(node_script)) + result_lines) - -def inpla_run(model: inpla_str) -> z3_str: - return subprocess.run(["./inpla"], input=f"{rules}\n{model}", capture_output=True, text=True).stdout - -syms = {} -def Symbolic(id): - if id not in syms: - syms[id] = z3.Real(id) - return syms[id] - -def Concrete(val): return z3.RealVal(val) - -def TermAdd(a, b): return a + b -def TermMul(a, b): return a * b -def TermReLU(x): return z3.If(x > 0, x, 0) - -context = { - 'Concrete': Concrete, - 'Symbolic': Symbolic, - 'TermAdd': TermAdd, - 'TermMul': TermMul, - 'TermReLU': TermReLU -} - -wrap = re.compile(r"Symbolic\((.*?)\)") - - -def z3_evaluate(model: z3_str): - model = wrap.sub(r'Symbolic("\1")', model) - - def evaluate_node(node: ast.AST): - if isinstance(node, ast.Expression): - return evaluate_node(node.body) - if isinstance(node, ast.Call): - if not isinstance(node.func, ast.Name): - raise ValueError(f"Unsupported function call type: {type(node.func)}") - func_name = node.func.id - func = context.get(func_name) - if not func: - raise ValueError(f"Unknown function: {func_name}") - return func(*[evaluate_node(arg) for arg in node.args]) - if isinstance(node, ast.Constant): - return node.value - if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub): - val = evaluate_node(node.operand) - if hasattr(val, "__neg__"): - return -val - raise ValueError(f"Value does not support negation: {type(val)}") - raise ValueError(f"Unsupported AST node: {type(node)}") - - lines = [line.strip() for line in model.splitlines() if line.strip()] - exprs = [evaluate_node(ast.parse(line, mode='eval')) for line in lines] - - if not exprs: return None - return exprs[0] if len(exprs) == 1 else exprs - -def net(model: onnx.ModelProto): - return z3_evaluate(inpla_run(inpla_export(model))) - -def strict_equivalence(net_a, net_b): - solver = z3.Solver() - - for sym in syms.values(): - solver.add(z3.Or(sym == 0, sym == 1)) - - solver.add(net_a != net_b) - - result = solver.check() - - print("Strict Equivalence") - if result == z3.unsat: - print("VERIFIED: The networks are strictly equivalent.") - elif result == z3.sat: - print("FAILED: The networks are different.") - print("Counter-example input:") - print(solver.model()) - else: - print("UNKNOWN: Solver could not decide.") - -def epsilon_equivalence(net_a, net_b, epsilon): - solver = z3.Solver() - - for sym in syms.values(): - solver.add(z3.Or(sym == 0, sym == 1)) - - solver.add(z3.Abs(net_a - net_b) > epsilon) - - result = solver.check() - - print(f"Epsilon-Equivalence | Epsilon={epsilon}.") - if result == z3.unsat: - print("VERIFIED: The networks are epsilon-equivalent.") - elif result == z3.sat: - print("FAILED: The networks are different.") - print("Counter-example input:") - print(solver.model()) - else: - print("UNKNOWN: Solver could not decide.") - -def argmax_equivalence(net_a, net_b): - solver = z3.Solver() - - for sym in syms.values(): - solver.add(z3.Or(sym == 0, sym == 1)) - - solver.add((net_a > 0.5) != (net_b > 0.5)) - - result = solver.check() - - print("ARGMAX Equivalence") - if result == z3.unsat: - print("VERIFIED: The networks are ARGMAX equivalent.") - elif result == z3.sat: - print("FAILED: The networks are different.") - print("Counter-example input:") - print(solver.model()) - else: - print("UNKNOWN: Solver could not decide.") |
