aboutsummaryrefslogtreecommitdiff
path: root/vein.py
diff options
context:
space:
mode:
authorericmarin <maarin.eric@gmail.com>2026-04-10 15:06:56 +0200
committerericmarin <maarin.eric@gmail.com>2026-04-13 10:55:06 +0200
commit8f4f24523235965cfa2041ed00cc40fc0b4bd367 (patch)
tree716862b4c898431861c27ab69165edfb245467dc /vein.py
parent9fb816496d392638fa6981e71800466d71434680 (diff)
downloadvein-8f4f24523235965cfa2041ed00cc40fc0b4bd367.tar.gz
vein-8f4f24523235965cfa2041ed00cc40fc0b4bd367.zip
added MNIST, changed cache and parser
Diffstat (limited to 'vein.py')
-rw-r--r--vein.py73
1 files changed, 44 insertions, 29 deletions
diff --git a/vein.py b/vein.py
index 1aa86ae..c017428 100644
--- a/vein.py
+++ b/vein.py
@@ -19,11 +19,14 @@ import subprocess
import onnx
import onnx.shape_inference
from onnx import numpy_helper
-from typing import List, Dict, Optional, Tuple
+from typing import List, Dict, Optional
import os
import tempfile
import hashlib
+sat = z3.sat
+unsat = z3.unsat
+
rules = """
Linear(x, float q, float r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r);
Concrete(float k) >< Add(out, b)
@@ -66,7 +69,7 @@ rules = """
Concrete(float k) >< Materialize(out) => out ~ (*L)Concrete(k);
"""
-_INPLA_CACHE: Dict[Tuple[str, Optional[Tuple]], str] = {}
+_CACHE = {}
def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]]] = None) -> str:
# TODO: Add Range agent
@@ -86,7 +89,7 @@ def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]]
initializers[init.name] = numpy_helper.to_array(init)
return initializers
- def get_attrs(node: onnx.NodeProto) -> Dict:
+ def get_attrs(node) -> Dict:
return {attr.name: onnx.helper.get_attribute_value(attr) for attr in node.attribute}
def get_dim(name):
@@ -288,7 +291,7 @@ def inpla_run(model: str) -> str:
os.remove(temp_path)
-def z3_evaluate(model: str, X):
+def z3_evaluate(model: str, X: dict):
def Symbolic(id):
if id not in X:
X[id] = z3.Real(id)
@@ -305,20 +308,33 @@ def z3_evaluate(model: str, X):
'TermMul': TermMul,
'TermReLU': TermReLU
}
- lines = [line.strip() for line in model.splitlines() if line.strip()]
- def iterative_eval(expr_str):
- tokens = re.findall(r'\(|\)|\,|[^(),\s]+', expr_str)
+ def tokenize(s):
+ i = 0
+ n = len(s)
+ while i < n:
+ c = s[i]
+ if c in '(),':
+ yield c
+ i += 1
+ elif c.isspace():
+ i += 1
+ else:
+ start = i
+ while i < n and s[i] not in '(), ' and not s[i].isspace():
+ i += 1
+ yield s[start:i]
+
+ def iterative_eval(tokens_gen):
stack = [[]]
- for token in tokens:
+ for token in tokens_gen:
if token == '(':
stack.append([])
elif token == ')':
args = stack.pop()
func_name = stack[-1].pop()
func = context.get(func_name)
- if not func:
- raise ValueError(f"Unknown function: {func_name}")
+ if not func: raise ValueError(f"Unknown: {func_name}")
stack[-1].append(func(*args))
elif token == ',':
continue
@@ -332,31 +348,34 @@ def z3_evaluate(model: str, X):
stack[-1].append(token)
return stack[0][0]
- exprs = [iterative_eval(line) for line in lines]
+ exprs = []
+ for line in model.splitlines():
+ line = line.strip()
+ exprs.append(iterative_eval(tokenize(line)))
return exprs
-def net(model: onnx.ModelProto, X, bounds: Optional[Dict[str, List[float]]] = None):
+def net(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]]] = None):
model_hash = hashlib.sha256(model.SerializeToString()).hexdigest()
bounds_key = tuple(sorted((k, tuple(v)) for k, v in bounds.items())) if bounds else None
cache_key = (model_hash, bounds_key)
- if cache_key not in _INPLA_CACHE:
+ if cache_key not in _CACHE:
exported = inpla_export(model, bounds)
reduced = inpla_run(exported)
- _INPLA_CACHE[cache_key] = reduced
+ X = {}
+ evaluated = z3_evaluate(reduced, X)
+ _CACHE[cache_key] = evaluated
- res = z3_evaluate(_INPLA_CACHE[cache_key], X)
- return res if res is not None else []
+ exprs = _CACHE[cache_key]
+ return exprs if exprs is not None else []
class Solver(z3.Solver):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.X = {}
- self.Y = {}
self.bounds: Dict[str, List[float]] = {}
- self.pending_nets: List[Tuple[onnx.ModelProto, Optional[str]]] = []
+ self.pending_nets: List[onnx.ModelProto] = []
- def load_vnnlib(self, file_path):
+ def load_vnnlib(self, file_path: str):
with open(file_path, "r") as f:
content = f.read()
@@ -367,26 +386,22 @@ class Solver(z3.Solver):
if op == ">=": self.bounds[var][0] = val
else: self.bounds[var][1] = val
- content = re.sub(r"\(vnnlib-version.*?\)", "", content)
assertions = z3.parse_smt2_string(content)
self.add(assertions)
- def load_onnx(self, file, name=None):
- model = onnx.load(file)
+ def load_onnx(self, file_path: str):
+ model = onnx.load(file_path)
model = onnx.shape_inference.infer_shapes(model)
- self.pending_nets.append((model, name))
+ self.pending_nets.append(model)
def _process_nets(self):
y_count = 0
- for model, name in self.pending_nets:
- z3_outputs = net(model, self.X, bounds=self.bounds)
+ for model in self.pending_nets:
+ z3_outputs = net(model, bounds=self.bounds)
if z3_outputs:
for _, out_expr in enumerate(z3_outputs):
y_var = z3.Real(f"Y_{y_count}")
self.add(y_var == out_expr)
- if name:
- if name not in self.Y: self.Y[name] = []
- self.Y[name].append(out_expr)
y_count += 1
self.pending_nets = []