aboutsummaryrefslogtreecommitdiff
path: root/nneq
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--nneq/nneq.py40
1 files changed, 36 insertions, 4 deletions
diff --git a/nneq/nneq.py b/nneq/nneq.py
index 4f46cbf..3f2106b 100644
--- a/nneq/nneq.py
+++ b/nneq/nneq.py
@@ -3,6 +3,7 @@ import re
import numpy as np
import subprocess
import onnx
+import ast
from onnx import numpy_helper
from typing import List, Dict
@@ -113,9 +114,15 @@ def inpla_export(model: onnx.ModelProto) -> inpla_str:
interactions[in_name][i].append(f"ReLU({nest_dups(interactions[out_name][i])})")
yield from []
+ def op_flatten(node):
+ out_name, in_name = node.output[0], node.input[0]
+ if out_name in interactions:
+ interactions[in_name] = interactions[out_name]
+ yield from []
+
graph, initializers, name_gen = model.graph, get_initializers(model.graph), NameGen()
interactions: Dict[str, List[List[str]]] = {}
- ops = {"Gemm": op_gemm, "Relu": op_relu}
+ ops = {"Gemm": op_gemm, "Relu": op_relu, "Flatten": op_flatten}
if graph.output:
out = graph.output[0].name
@@ -159,14 +166,39 @@ context = {
wrap = re.compile(r"Symbolic\((.*?)\)")
+
def z3_evaluate(model: z3_str):
- model = wrap.sub(r'Symbolic("\1")', model);
- return eval(model, context)
+ model = wrap.sub(r'Symbolic("\1")', model)
+
+ def evaluate_node(node: ast.AST):
+ if isinstance(node, ast.Expression):
+ return evaluate_node(node.body)
+ if isinstance(node, ast.Call):
+ if not isinstance(node.func, ast.Name):
+ raise ValueError(f"Unsupported function call type: {type(node.func)}")
+ func_name = node.func.id
+ func = context.get(func_name)
+ if not func:
+ raise ValueError(f"Unknown function: {func_name}")
+ return func(*[evaluate_node(arg) for arg in node.args])
+ if isinstance(node, ast.Constant):
+ return node.value
+ if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
+ val = evaluate_node(node.operand)
+ if hasattr(val, "__neg__"):
+ return -val
+ raise ValueError(f"Value does not support negation: {type(val)}")
+ raise ValueError(f"Unsupported AST node: {type(node)}")
+
+ lines = [line.strip() for line in model.splitlines() if line.strip()]
+ exprs = [evaluate_node(ast.parse(line, mode='eval')) for line in lines]
+
+ if not exprs: return None
+ return exprs[0] if len(exprs) == 1 else exprs
def net(model: onnx.ModelProto):
return z3_evaluate(inpla_run(inpla_export(model)))
-
def strict_equivalence(net_a, net_b):
solver = z3.Solver()