diff options
Diffstat (limited to '')
| -rw-r--r-- | vein.py | 73 |
1 files changed, 44 insertions, 29 deletions
@@ -19,11 +19,14 @@ import subprocess import onnx import onnx.shape_inference from onnx import numpy_helper -from typing import List, Dict, Optional, Tuple +from typing import List, Dict, Optional import os import tempfile import hashlib +sat = z3.sat +unsat = z3.unsat + rules = """ Linear(x, float q, float r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r); Concrete(float k) >< Add(out, b) @@ -66,7 +69,7 @@ rules = """ Concrete(float k) >< Materialize(out) => out ~ (*L)Concrete(k); """ -_INPLA_CACHE: Dict[Tuple[str, Optional[Tuple]], str] = {} +_CACHE = {} def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]]] = None) -> str: # TODO: Add Range agent @@ -86,7 +89,7 @@ def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]] initializers[init.name] = numpy_helper.to_array(init) return initializers - def get_attrs(node: onnx.NodeProto) -> Dict: + def get_attrs(node) -> Dict: return {attr.name: onnx.helper.get_attribute_value(attr) for attr in node.attribute} def get_dim(name): @@ -288,7 +291,7 @@ def inpla_run(model: str) -> str: os.remove(temp_path) -def z3_evaluate(model: str, X): +def z3_evaluate(model: str, X: dict): def Symbolic(id): if id not in X: X[id] = z3.Real(id) @@ -305,20 +308,33 @@ def z3_evaluate(model: str, X): 'TermMul': TermMul, 'TermReLU': TermReLU } - lines = [line.strip() for line in model.splitlines() if line.strip()] - def iterative_eval(expr_str): - tokens = re.findall(r'\(|\)|\,|[^(),\s]+', expr_str) + def tokenize(s): + i = 0 + n = len(s) + while i < n: + c = s[i] + if c in '(),': + yield c + i += 1 + elif c.isspace(): + i += 1 + else: + start = i + while i < n and s[i] not in '(), ' and not s[i].isspace(): + i += 1 + yield s[start:i] + + def iterative_eval(tokens_gen): stack = [[]] - for token in tokens: + for token in tokens_gen: if token == '(': stack.append([]) elif token == ')': args = stack.pop() func_name = stack[-1].pop() func = context.get(func_name) - if not func: - raise ValueError(f"Unknown function: {func_name}") + if not func: raise ValueError(f"Unknown: {func_name}") stack[-1].append(func(*args)) elif token == ',': continue @@ -332,31 +348,34 @@ def z3_evaluate(model: str, X): stack[-1].append(token) return stack[0][0] - exprs = [iterative_eval(line) for line in lines] + exprs = [] + for line in model.splitlines(): + line = line.strip() + exprs.append(iterative_eval(tokenize(line))) return exprs -def net(model: onnx.ModelProto, X, bounds: Optional[Dict[str, List[float]]] = None): +def net(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]]] = None): model_hash = hashlib.sha256(model.SerializeToString()).hexdigest() bounds_key = tuple(sorted((k, tuple(v)) for k, v in bounds.items())) if bounds else None cache_key = (model_hash, bounds_key) - if cache_key not in _INPLA_CACHE: + if cache_key not in _CACHE: exported = inpla_export(model, bounds) reduced = inpla_run(exported) - _INPLA_CACHE[cache_key] = reduced + X = {} + evaluated = z3_evaluate(reduced, X) + _CACHE[cache_key] = evaluated - res = z3_evaluate(_INPLA_CACHE[cache_key], X) - return res if res is not None else [] + exprs = _CACHE[cache_key] + return exprs if exprs is not None else [] class Solver(z3.Solver): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.X = {} - self.Y = {} self.bounds: Dict[str, List[float]] = {} - self.pending_nets: List[Tuple[onnx.ModelProto, Optional[str]]] = [] + self.pending_nets: List[onnx.ModelProto] = [] - def load_vnnlib(self, file_path): + def load_vnnlib(self, file_path: str): with open(file_path, "r") as f: content = f.read() @@ -367,26 +386,22 @@ class Solver(z3.Solver): if op == ">=": self.bounds[var][0] = val else: self.bounds[var][1] = val - content = re.sub(r"\(vnnlib-version.*?\)", "", content) assertions = z3.parse_smt2_string(content) self.add(assertions) - def load_onnx(self, file, name=None): - model = onnx.load(file) + def load_onnx(self, file_path: str): + model = onnx.load(file_path) model = onnx.shape_inference.infer_shapes(model) - self.pending_nets.append((model, name)) + self.pending_nets.append(model) def _process_nets(self): y_count = 0 - for model, name in self.pending_nets: - z3_outputs = net(model, self.X, bounds=self.bounds) + for model in self.pending_nets: + z3_outputs = net(model, bounds=self.bounds) if z3_outputs: for _, out_expr in enumerate(z3_outputs): y_var = z3.Real(f"Y_{y_count}") self.add(y_var == out_expr) - if name: - if name not in self.Y: self.Y[name] = [] - self.Y[name].append(out_expr) y_count += 1 self.pending_nets = [] |
