diff options
| author | ericmarin <maarin.eric@gmail.com> | 2026-03-21 13:08:17 +0100 |
|---|---|---|
| committer | ericmarin <maarin.eric@gmail.com> | 2026-03-21 13:08:17 +0100 |
| commit | 4a9b66faae8bf362849b961ac2bf5dedc079c6ce (patch) | |
| tree | 1d8b745248e4a0a40b86df2802e570d4c4e14807 | |
| parent | e2abe9d9ec649b849cc39b516c1db1b4fa592003 (diff) | |
| download | vein-4a9b66faae8bf362849b961ac2bf5dedc079c6ce.tar.gz vein-4a9b66faae8bf362849b961ac2bf5dedc079c6ce.zip | |
Added MatMul and Flatten. Inpla now produces a balanced tree of TermAdd.
| -rw-r--r-- | nneq.py | 77 | ||||
| -rw-r--r-- | proof.norg | 6 |
2 files changed, 64 insertions, 19 deletions
@@ -54,7 +54,8 @@ rules = """ _INPLA_CACHE: Dict[Tuple[str, Optional[Tuple]], str] = {} def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]]] = None) -> str: - _ = bounds # TODO: Add Range agent + # TODO: Add Range agent + _ = bounds class NameGen: def __init__(self, prefix="v"): self.counter = 0 @@ -88,11 +89,34 @@ def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]] current = wire return current - def op_gemm(node): - attrs = get_attrs(node) + def balance_add(terms: List[str], sink: str): + if not terms: + script.append(f"{sink} ~ Eraser;") + return + if len(terms) == 1: + script.append(f"{sink} ~ {terms[0]};") + return + + nodes = terms + while len(nodes) > 1: + next_level = [] + for i in range(0, len(nodes), 2): + if i + 1 < len(nodes): + wire_out = wire_gen.next() + script.append(f"{nodes[i]} ~ Add({wire_out}, {nodes[i+1]});") + next_level.append(wire_out) + else: + next_level.append(nodes[i]) + nodes = next_level + script.append(f"{nodes[0]} ~ {sink};") + + def op_gemm(node, override_attrs=None): + attrs = override_attrs if override_attrs is not None else get_attrs(node) + W = initializers[node.input[1]] if not attrs.get("transB", 0): W = W.T out_dim, in_dim = W.shape + B = initializers[node.input[2]] if len(node.input) > 2 else np.zeros(out_dim) alpha, beta = attrs.get("alpha", 1.0), attrs.get("beta", 1.0) @@ -102,21 +126,25 @@ def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]] out_terms = interactions.get(node.output[0]) or [[f"Materialize(result{j})"] for j in range(out_dim)] for j in range(out_dim): - chain = flatten_nest("Dup", out_terms[j]) + sink = flatten_nest("Dup", out_terms[j]) + neuron_terms = [] for i in range(in_dim): weight = float(alpha * W[j, i]) - if weight == 0: - interactions[node.input[0]][i].append("Eraser") - else: + if weight != 0: v = var_gen.next() - wire = wire_gen.next() - script.append(f"{wire} ~ Add({chain}, {v});") - chain = wire 
interactions[node.input[0]][i].append(f"Mul({v}, Concrete({weight}))") + neuron_terms.append(v) + + bias_val = float(beta * B[j]) + if bias_val != 0 or not neuron_terms: + neuron_terms.append(f"Concrete({bias_val})") - script.append(f"{chain} ~ Concrete({float(beta * B[j])});") + balance_add(neuron_terms, sink) yield from [] + def op_matmul(node): + return op_gemm(node, override_attrs={"alpha": 1.0, "beta": 0.0, "transB": 0}) + def op_relu(node): out_name, in_name = node.output[0], node.input[0] if out_name in interactions: @@ -124,17 +152,28 @@ def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]] if in_name not in interactions: interactions[in_name] = [[] for _ in range(dim)] for i in range(dim): - dup_chain = flatten_nest("Dup", interactions[out_name][i]) + sink = flatten_nest("Dup", interactions[out_name][i]) v = var_gen.next() interactions[in_name][i].append(f"ReLU({v})") - script.append(f"{v} ~ {dup_chain};") + script.append(f"{v} ~ {sink};") + yield from [] + + def op_flatten(node): + out_name, in_name = node.output[0], node.input[0] + if out_name in interactions: + interactions[in_name] = interactions[out_name] yield from [] graph, initializers = model.graph, get_initializers(model.graph) var_gen, wire_gen = NameGen("v"), NameGen("w") interactions: Dict[str, List[List[str]]] = {} script = [] - ops = {"Gemm": op_gemm, "Relu": op_relu} + ops = { + "Gemm": op_gemm, + "Relu": op_relu, + "Flatten": op_flatten, + "MatMul": op_matmul + } if graph.output: out = graph.output[0].name @@ -150,8 +189,8 @@ def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]] if graph.input and graph.input[0].name in interactions: for i, terms in enumerate(interactions[graph.input[0].name]): - dup_chain = flatten_nest("Dup", terms) - script.append(f"{dup_chain} ~ Linear(Symbolic(X_{i}), 1.0, 0.0);") + sink = flatten_nest("Dup", terms) + script.append(f"{sink} ~ Linear(Symbolic(X_{i}), 1.0, 0.0);") result_lines = [f"result{i};" for i 
in range(len(interactions.get(graph.output[0].name, [])))] return "\n".join(script + result_lines) @@ -188,7 +227,7 @@ def z3_evaluate(model: str, X): 'TermReLU': TermReLU } lines = [line.strip() for line in model.splitlines() if line.strip()] - + def iterative_eval(expr_str): tokens = re.findall(r'\(|\)|\,|[^(),\s]+', expr_str) stack = [[]] @@ -224,7 +263,7 @@ def net(model: onnx.ModelProto, X, bounds: Optional[Dict[str, List[float]]] = No if cache_key not in _INPLA_CACHE: _INPLA_CACHE[cache_key] = inpla_run(inpla_export(model, bounds)) - + res = z3_evaluate(_INPLA_CACHE[cache_key], X) return res if res is not None else [] @@ -239,7 +278,7 @@ class Solver(z3.Solver): def load_vnnlib(self, file_path): with open(file_path, "r") as f: content = f.read() - + for match in re.finditer(r"\(assert\s+\((>=|<=)\s+(X_\d+)\s+([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)\)\)", content): op, var, val = match.groups() val = float(val) @@ -46,6 +46,12 @@ version: 1.1.1 By grouping the operations we get: /Y = alpha * A * B + beta * C/ +** Flatten + Just identity mapping + +** MatMul + Equal to Gemm with alpha=1 and beta=0 + * Proof for the Interaction Rules ** Materialize The Materialize agent transforms a Linear agent into a tree of explicit mathematical operations |
