aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorericmarin <maarin.eric@gmail.com>2026-03-21 13:08:17 +0100
committerericmarin <maarin.eric@gmail.com>2026-03-21 13:08:17 +0100
commit4a9b66faae8bf362849b961ac2bf5dedc079c6ce (patch)
tree1d8b745248e4a0a40b86df2802e570d4c4e14807
parente2abe9d9ec649b849cc39b516c1db1b4fa592003 (diff)
downloadvein-4a9b66faae8bf362849b961ac2bf5dedc079c6ce.tar.gz
vein-4a9b66faae8bf362849b961ac2bf5dedc079c6ce.zip
added MatMul and Flatten. Now Inpla produces a balanced tree of TermAdd
-rw-r--r--nneq.py77
-rw-r--r--proof.norg6
2 files changed, 64 insertions, 19 deletions
diff --git a/nneq.py b/nneq.py
index 660bcce..1549878 100644
--- a/nneq.py
+++ b/nneq.py
@@ -54,7 +54,8 @@ rules = """
_INPLA_CACHE: Dict[Tuple[str, Optional[Tuple]], str] = {}
def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]]] = None) -> str:
- _ = bounds # TODO: Add Range agent
+ # TODO: Add Range agent
+ _ = bounds
class NameGen:
def __init__(self, prefix="v"):
self.counter = 0
@@ -88,11 +89,34 @@ def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]]
current = wire
return current
- def op_gemm(node):
- attrs = get_attrs(node)
+ def balance_add(terms: List[str], sink: str):
+ if not terms:
+ script.append(f"{sink} ~ Eraser;")
+ return
+ if len(terms) == 1:
+ script.append(f"{sink} ~ {terms[0]};")
+ return
+
+ nodes = terms
+ while len(nodes) > 1:
+ next_level = []
+ for i in range(0, len(nodes), 2):
+ if i + 1 < len(nodes):
+ wire_out = wire_gen.next()
+ script.append(f"{nodes[i]} ~ Add({wire_out}, {nodes[i+1]});")
+ next_level.append(wire_out)
+ else:
+ next_level.append(nodes[i])
+ nodes = next_level
+ script.append(f"{nodes[0]} ~ {sink};")
+
+ def op_gemm(node, override_attrs=None):
+ attrs = override_attrs if override_attrs is not None else get_attrs(node)
+
W = initializers[node.input[1]]
if not attrs.get("transB", 0): W = W.T
out_dim, in_dim = W.shape
+
B = initializers[node.input[2]] if len(node.input) > 2 else np.zeros(out_dim)
alpha, beta = attrs.get("alpha", 1.0), attrs.get("beta", 1.0)
@@ -102,21 +126,25 @@ def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]]
out_terms = interactions.get(node.output[0]) or [[f"Materialize(result{j})"] for j in range(out_dim)]
for j in range(out_dim):
- chain = flatten_nest("Dup", out_terms[j])
+ sink = flatten_nest("Dup", out_terms[j])
+ neuron_terms = []
for i in range(in_dim):
weight = float(alpha * W[j, i])
- if weight == 0:
- interactions[node.input[0]][i].append("Eraser")
- else:
+ if weight != 0:
v = var_gen.next()
- wire = wire_gen.next()
- script.append(f"{wire} ~ Add({chain}, {v});")
- chain = wire
interactions[node.input[0]][i].append(f"Mul({v}, Concrete({weight}))")
+ neuron_terms.append(v)
+
+ bias_val = float(beta * B[j])
+ if bias_val != 0 or not neuron_terms:
+ neuron_terms.append(f"Concrete({bias_val})")
- script.append(f"{chain} ~ Concrete({float(beta * B[j])});")
+ balance_add(neuron_terms, sink)
yield from []
+ def op_matmul(node):
+ return op_gemm(node, override_attrs={"alpha": 1.0, "beta": 0.0, "transB": 0})
+
def op_relu(node):
out_name, in_name = node.output[0], node.input[0]
if out_name in interactions:
@@ -124,17 +152,28 @@ def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]]
if in_name not in interactions:
interactions[in_name] = [[] for _ in range(dim)]
for i in range(dim):
- dup_chain = flatten_nest("Dup", interactions[out_name][i])
+ sink = flatten_nest("Dup", interactions[out_name][i])
v = var_gen.next()
interactions[in_name][i].append(f"ReLU({v})")
- script.append(f"{v} ~ {dup_chain};")
+ script.append(f"{v} ~ {sink};")
+ yield from []
+
+ def op_flatten(node):
+ out_name, in_name = node.output[0], node.input[0]
+ if out_name in interactions:
+ interactions[in_name] = interactions[out_name]
yield from []
graph, initializers = model.graph, get_initializers(model.graph)
var_gen, wire_gen = NameGen("v"), NameGen("w")
interactions: Dict[str, List[List[str]]] = {}
script = []
- ops = {"Gemm": op_gemm, "Relu": op_relu}
+ ops = {
+ "Gemm": op_gemm,
+ "Relu": op_relu,
+ "Flatten": op_flatten,
+ "MatMul": op_matmul
+ }
if graph.output:
out = graph.output[0].name
@@ -150,8 +189,8 @@ def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]]
if graph.input and graph.input[0].name in interactions:
for i, terms in enumerate(interactions[graph.input[0].name]):
- dup_chain = flatten_nest("Dup", terms)
- script.append(f"{dup_chain} ~ Linear(Symbolic(X_{i}), 1.0, 0.0);")
+ sink = flatten_nest("Dup", terms)
+ script.append(f"{sink} ~ Linear(Symbolic(X_{i}), 1.0, 0.0);")
result_lines = [f"result{i};" for i in range(len(interactions.get(graph.output[0].name, [])))]
return "\n".join(script + result_lines)
@@ -188,7 +227,7 @@ def z3_evaluate(model: str, X):
'TermReLU': TermReLU
}
lines = [line.strip() for line in model.splitlines() if line.strip()]
-
+
def iterative_eval(expr_str):
tokens = re.findall(r'\(|\)|\,|[^(),\s]+', expr_str)
stack = [[]]
@@ -224,7 +263,7 @@ def net(model: onnx.ModelProto, X, bounds: Optional[Dict[str, List[float]]] = No
if cache_key not in _INPLA_CACHE:
_INPLA_CACHE[cache_key] = inpla_run(inpla_export(model, bounds))
-
+
res = z3_evaluate(_INPLA_CACHE[cache_key], X)
return res if res is not None else []
@@ -239,7 +278,7 @@ class Solver(z3.Solver):
def load_vnnlib(self, file_path):
with open(file_path, "r") as f:
content = f.read()
-
+
for match in re.finditer(r"\(assert\s+\((>=|<=)\s+(X_\d+)\s+([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)\)\)", content):
op, var, val = match.groups()
val = float(val)
diff --git a/proof.norg b/proof.norg
index 2a0f768..7f35c1e 100644
--- a/proof.norg
+++ b/proof.norg
@@ -46,6 +46,12 @@ version: 1.1.1
By grouping the operations we get:
/Y = alpha * A * B + beta * C/
+** Flatten
+ Just identity mapping
+
+** MatMul
+ Equal to Gemm with alpha=1 and beta=0
+
* Proof for the Interaction Rules
** Materialize
The Materialize agent transforms a Linear agent into a tree of explicit mathematical operations