diff options
| author | ericmarin <maarin.eric@gmail.com> | 2026-03-18 16:43:01 +0100 |
|---|---|---|
| committer | ericmarin <maarin.eric@gmail.com> | 2026-03-18 17:47:15 +0100 |
| commit | af4335cf47984576e7493a0eb6569d3f6ecc31c8 (patch) | |
| tree | 31f6876cf3b4a14ea2bcb167fa301fcec69c02a5 /nneq/nneq.py | |
| parent | 0fca69965786db7deee2e976551b5156531e8ed5 (diff) | |
| download | vein-af4335cf47984576e7493a0eb6569d3f6ecc31c8.tar.gz vein-af4335cf47984576e7493a0eb6569d3f6ecc31c8.zip | |
changed eval() for a proper parser. Now it also handles multiple outputs
Diffstat (limited to '')
| -rw-r--r-- | nneq/nneq.py | 40 |
1 files changed, 36 insertions, 4 deletions
diff --git a/nneq/nneq.py b/nneq/nneq.py index 4f46cbf..3f2106b 100644 --- a/nneq/nneq.py +++ b/nneq/nneq.py @@ -3,6 +3,7 @@ import re import numpy as np import subprocess import onnx +import ast from onnx import numpy_helper from typing import List, Dict @@ -113,9 +114,15 @@ def inpla_export(model: onnx.ModelProto) -> inpla_str: interactions[in_name][i].append(f"ReLU({nest_dups(interactions[out_name][i])})") yield from [] + def op_flatten(node): + out_name, in_name = node.output[0], node.input[0] + if out_name in interactions: + interactions[in_name] = interactions[out_name] + yield from [] + graph, initializers, name_gen = model.graph, get_initializers(model.graph), NameGen() interactions: Dict[str, List[List[str]]] = {} - ops = {"Gemm": op_gemm, "Relu": op_relu} + ops = {"Gemm": op_gemm, "Relu": op_relu, "Flatten": op_flatten} if graph.output: out = graph.output[0].name @@ -159,14 +166,39 @@ context = { wrap = re.compile(r"Symbolic\((.*?)\)") + def z3_evaluate(model: z3_str): - model = wrap.sub(r'Symbolic("\1")', model); - return eval(model, context) + model = wrap.sub(r'Symbolic("\1")', model) + + def evaluate_node(node: ast.AST): + if isinstance(node, ast.Expression): + return evaluate_node(node.body) + if isinstance(node, ast.Call): + if not isinstance(node.func, ast.Name): + raise ValueError(f"Unsupported function call type: {type(node.func)}") + func_name = node.func.id + func = context.get(func_name) + if not func: + raise ValueError(f"Unknown function: {func_name}") + return func(*[evaluate_node(arg) for arg in node.args]) + if isinstance(node, ast.Constant): + return node.value + if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub): + val = evaluate_node(node.operand) + if hasattr(val, "__neg__"): + return -val + raise ValueError(f"Value does not support negation: {type(val)}") + raise ValueError(f"Unsupported AST node: {type(node)}") + + lines = [line.strip() for line in model.splitlines() if line.strip()] + exprs = [evaluate_node(ast.parse(line, mode='eval')) for line in lines] + + if not exprs: return None + return exprs[0] if len(exprs) == 1 else exprs def net(model: onnx.ModelProto): return z3_evaluate(inpla_run(inpla_export(model))) - def strict_equivalence(net_a, net_b): solver = z3.Solver() |
