aboutsummaryrefslogtreecommitdiff
path: root/vein.py
diff options
context:
space:
mode:
authorericmarin <maarin.eric@gmail.com>2026-04-13 19:42:39 +0200
committerericmarin <maarin.eric@gmail.com>2026-04-13 21:38:16 +0200
commitfcbbc960f43137aa170b78ba0be2d89aec3bc766 (patch)
tree15e0249bf429888d9b64f19eb0c6e2d9af0901e4 /vein.py
parent8f4f24523235965cfa2041ed00cc40fc0b4bd367 (diff)
downloadvein-fcbbc960f43137aa170b78ba0be2d89aec3bc766.tar.gz
vein-fcbbc960f43137aa170b78ba0be2d89aec3bc766.zip
New ONNX ops and testsHEADmaster
New ops: Slice, Squeeze, Unsqueeze New tests based on papers: - Wide-to-Deep, Deep-to-Wide Transformation - Pruining of stably inactive (always negative) and active (always positive) ReLUs
Diffstat (limited to 'vein.py')
-rw-r--r--vein.py45
1 files changed, 39 insertions, 6 deletions
diff --git a/vein.py b/vein.py
index c017428..baa9da5 100644
--- a/vein.py
+++ b/vein.py
@@ -178,12 +178,10 @@ def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]]
script.append(f"{v} ~ {sink};")
def op_flatten(node):
- out_name, in_name = node.output[0], node.input[0]
- if out_name in interactions:
- interactions[in_name] = interactions[out_name]
+ op_identity(node)
def op_reshape(node):
- op_flatten(node)
+ op_identity(node)
def op_add(node):
out_name = node.output[0]
@@ -242,6 +240,36 @@ def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]]
interactions[in_a][i].append(f"Add({sink}, Mul({v_b}, Concrete(-1.0)))")
interactions[in_b][i].append(f"{v_b}")
+ def op_slice(node):
+ in_name, out_name = node.input[0], node.output[0]
+ if out_name in interactions:
+ starts = initializers.get(node.input[1])
+ steps = initializers.get(node.input[4]) if len(node.input) > 4 else None
+
+ start = int(starts.flatten()[0]) if starts is not None else 0
+ step = int(steps.flatten()[0]) if steps is not None else 1
+
+ in_dim = get_dim(in_name) or 1
+ if in_name not in interactions:
+ interactions[in_name] = [[] for _ in range(in_dim)]
+
+ for i, terms in enumerate(interactions[out_name]):
+ input_index = start + (i * step)
+ if input_index < in_dim:
+ interactions[in_name][input_index].extend(terms)
+
+ def op_squeeze(node):
+ op_identity(node)
+
+ def op_unsqueeze(node):
+ op_identity(node)
+
+ def op_identity(node):
+ in_name, out_name = node.input[0], node.output[0]
+ if out_name in interactions:
+ interactions[in_name] = interactions[out_name]
+
+
graph, initializers = model.graph, get_initializers(model.graph)
wire_gen = NameGen("w")
interactions: Dict[str, List[List[str]]] = {}
@@ -253,7 +281,11 @@ def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]]
"Reshape": op_reshape,
"MatMul": op_matmul,
"Add": op_add,
- "Sub": op_sub
+ "Sub": op_sub,
+ "Slice": op_slice,
+ "Squeeze": op_squeeze,
+ "Unsqueeze": op_unsqueeze,
+ "Identity": op_identity
}
if graph.output:
@@ -284,7 +316,8 @@ def inpla_run(model: str) -> str:
temp_path = f.name
try:
res = subprocess.run(["./inpla", "-f", temp_path, "-foptimise-tail-calls"], capture_output=True, text=True)
- if res.stderr: print(res.stderr)
+ if res.stderr:
+ raise RuntimeError(res.stderr)
return res.stdout
finally:
if os.path.exists(temp_path):