diff options
Diffstat (limited to '')
| -rw-r--r-- | vein.py | 45 |
1 files changed, 39 insertions, 6 deletions
@@ -178,12 +178,10 @@ def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]] script.append(f"{v} ~ {sink};") def op_flatten(node): - out_name, in_name = node.output[0], node.input[0] - if out_name in interactions: - interactions[in_name] = interactions[out_name] + op_identity(node) def op_reshape(node): - op_flatten(node) + op_identity(node) def op_add(node): out_name = node.output[0] @@ -242,6 +240,36 @@ def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]] interactions[in_a][i].append(f"Add({sink}, Mul({v_b}, Concrete(-1.0)))") interactions[in_b][i].append(f"{v_b}") + def op_slice(node): + in_name, out_name = node.input[0], node.output[0] + if out_name in interactions: + starts = initializers.get(node.input[1]) + steps = initializers.get(node.input[4]) if len(node.input) > 4 else None + + start = int(starts.flatten()[0]) if starts is not None else 0 + step = int(steps.flatten()[0]) if steps is not None else 1 + + in_dim = get_dim(in_name) or 1 + if in_name not in interactions: + interactions[in_name] = [[] for _ in range(in_dim)] + + for i, terms in enumerate(interactions[out_name]): + input_index = start + (i * step) + if input_index < in_dim: + interactions[in_name][input_index].extend(terms) + + def op_squeeze(node): + op_identity(node) + + def op_unsqueeze(node): + op_identity(node) + + def op_identity(node): + in_name, out_name = node.input[0], node.output[0] + if out_name in interactions: + interactions[in_name] = interactions[out_name] + + graph, initializers = model.graph, get_initializers(model.graph) wire_gen = NameGen("w") interactions: Dict[str, List[List[str]]] = {} @@ -253,7 +281,11 @@ def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]] "Reshape": op_reshape, "MatMul": op_matmul, "Add": op_add, - "Sub": op_sub + "Sub": op_sub, + "Slice": op_slice, + "Squeeze": op_squeeze, + "Unsqueeze": op_unsqueeze, + "Identity": op_identity } if graph.output: @@ -284,7 +316,8 @@ def inpla_run(model: str) -> str: temp_path = f.name try: res = subprocess.run(["./inpla", "-f", temp_path, "-foptimise-tail-calls"], capture_output=True, text=True) - if res.stderr: print(res.stderr) + if res.stderr: + raise RuntimeError(res.stderr) return res.stdout finally: if os.path.exists(temp_path): |
