From fcbbc960f43137aa170b78ba0be2d89aec3bc766 Mon Sep 17 00:00:00 2001 From: ericmarin Date: Mon, 13 Apr 2026 19:42:39 +0200 Subject: New ONNX ops and tests New ops: Slice, Squeeze, Unsqueeze New tests based on papers: - Wide-to-Deep, Deep-to-Wide Transformation - Pruining of stably inactive (always negative) and active (always positive) ReLUs --- vein.py | 45 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 6 deletions(-) (limited to 'vein.py') diff --git a/vein.py b/vein.py index c017428..baa9da5 100644 --- a/vein.py +++ b/vein.py @@ -178,12 +178,10 @@ def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]] script.append(f"{v} ~ {sink};") def op_flatten(node): - out_name, in_name = node.output[0], node.input[0] - if out_name in interactions: - interactions[in_name] = interactions[out_name] + op_identity(node) def op_reshape(node): - op_flatten(node) + op_identity(node) def op_add(node): out_name = node.output[0] @@ -242,6 +240,36 @@ def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]] interactions[in_a][i].append(f"Add({sink}, Mul({v_b}, Concrete(-1.0)))") interactions[in_b][i].append(f"{v_b}") + def op_slice(node): + in_name, out_name = node.input[0], node.output[0] + if out_name in interactions: + starts = initializers.get(node.input[1]) + steps = initializers.get(node.input[4]) if len(node.input) > 4 else None + + start = int(starts.flatten()[0]) if starts is not None else 0 + step = int(steps.flatten()[0]) if steps is not None else 1 + + in_dim = get_dim(in_name) or 1 + if in_name not in interactions: + interactions[in_name] = [[] for _ in range(in_dim)] + + for i, terms in enumerate(interactions[out_name]): + input_index = start + (i * step) + if input_index < in_dim: + interactions[in_name][input_index].extend(terms) + + def op_squeeze(node): + op_identity(node) + + def op_unsqueeze(node): + op_identity(node) + + def op_identity(node): + in_name, out_name = node.input[0], node.output[0] + if out_name in interactions: + interactions[in_name] = interactions[out_name] + + graph, initializers = model.graph, get_initializers(model.graph) wire_gen = NameGen("w") interactions: Dict[str, List[List[str]]] = {} @@ -253,7 +281,11 @@ def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]] "Reshape": op_reshape, "MatMul": op_matmul, "Add": op_add, - "Sub": op_sub + "Sub": op_sub, + "Slice": op_slice, + "Squeeze": op_squeeze, + "Unsqueeze": op_unsqueeze, + "Identity": op_identity } if graph.output: @@ -284,7 +316,8 @@ def inpla_run(model: str) -> str: temp_path = f.name try: res = subprocess.run(["./inpla", "-f", temp_path, "-foptimise-tail-calls"], capture_output=True, text=True) - if res.stderr: print(res.stderr) + if res.stderr: + raise RuntimeError(res.stderr) return res.stdout finally: if os.path.exists(temp_path): -- cgit v1.2.3