aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorericmarin <maarin.eric@gmail.com>2026-04-10 15:06:56 +0200
committerericmarin <maarin.eric@gmail.com>2026-04-13 10:55:06 +0200
commit8f4f24523235965cfa2041ed00cc40fc0b4bd367 (patch)
tree716862b4c898431861c27ab69165edfb245467dc
parent9fb816496d392638fa6981e71800466d71434680 (diff)
downloadvein-8f4f24523235965cfa2041ed00cc40fc0b4bd367.tar.gz
vein-8f4f24523235965cfa2041ed00cc40fc0b4bd367.zip
added MNIST, changed cache and parser
-rw-r--r--.gitignore10
-rw-r--r--LICENSE42
-rw-r--r--README.md2
-rw-r--r--examples/ACASXU/ACASXU_epsilon.vnnlib2
-rw-r--r--examples/double_integrator/double_integrator_epsilon.vnnlib2
-rw-r--r--examples/iris/iris_epsilon.vnnlib2
-rw-r--r--examples/mnist/mnist.py (renamed from examples/fashion_mnist/fashion_mnist.py)22
-rw-r--r--examples/mnist/mnist_argmax.vnnlib (renamed from examples/fashion_mnist/fashion_mnist_argmax.vnnlib)2
-rw-r--r--examples/mnist/mnist_epsilon.vnnlib (renamed from examples/fashion_mnist/fashion_mnist_epsilon.vnnlib)2
-rw-r--r--examples/mnist/mnist_strict.vnnlib (renamed from examples/fashion_mnist/fashion_mnist_strict.vnnlib)2
-rw-r--r--examples/pendulum/pendulum_epsilon.vnnlib2
-rw-r--r--examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_0.onnxbin74221 -> 0 bytes
-rw-r--r--examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_2.onnxbin74533 -> 0 bytes
-rw-r--r--examples/tll/tll_argmax.vnnlib16
-rw-r--r--examples/tll/tll_epsilon.vnnlib17
-rw-r--r--examples/tll/tll_strict.vnnlib16
-rw-r--r--examples/verify.py86
-rw-r--r--vein.py73
-rw-r--r--verify_example.py79
19 files changed, 147 insertions, 230 deletions
diff --git a/.gitignore b/.gitignore
index 6d0d093..18f6a2b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,10 +1,10 @@
/inpla
/__pycache__/
-/examples/fashion_mnist/FashionMNIST/raw
-/examples/fashion_mnist/fashion_mnist_a.onnx
-/examples/fashion_mnist/fashion_mnist_a.onnx.data
-/examples/fashion_mnist/fashion_mnist_b.onnx
-/examples/fashion_mnist/fashion_mnist_b.onnx.data
+/examples/mnist/MNIST/raw
+/examples/mnist/mnist_a.onnx
+/examples/mnist/mnist_a.onnx.data
+/examples/mnist/mnist_b.onnx
+/examples/mnist/mnist_b.onnx.data
/examples/xor/xor_a.onnx
/examples/xor/xor_a.onnx.data
/examples/xor/xor_b.onnx
diff --git a/LICENSE b/LICENSE
index 16d9c93..ca9b055 100644
--- a/LICENSE
+++ b/LICENSE
@@ -617,45 +617,3 @@ Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
-
- How to Apply These Terms to Your New Programs
-
- If you develop a new program, and you want it to be of the greatest
-possible use to the public, the best way to achieve this is to make it
-free software which everyone can redistribute and change under these terms.
-
- To do so, attach the following notices to the program. It is safest
-to attach them to the start of each source file to most effectively
-state the exclusion of warranty; and each file should have at least
-the "copyright" line and a pointer to where the full notice is found.
-
- VErification via Interaction Nets (VEIN)
- Copyright (C) 2026 Eric Marin
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see <https://www.gnu.org/licenses/>.
-
-Also add information on how to contact you by electronic and paper mail.
-
- If your software can interact with users remotely through a computer
-network, you should also make sure that it provides a way for users to
-get its source. For example, if your program is a web application, its
-interface could display a "Source" link that leads users to an archive
-of the code. There are many ways you could offer source, and different
-solutions will be better for different programs; see section 13 for the
-specific requirements.
-
- You should also get your employer (if you work as a programmer) or school,
-if any, to sign a "copyright disclaimer" for the program, if necessary.
-For more information on this, and how to apply and follow the GNU AGPL, see
-<https://www.gnu.org/licenses/>.
diff --git a/README.md b/README.md
index be2b25e..7989b1c 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-# VEIN: VErification via Interaction Nets
+# VEIN: VErification via Interaction Nets</p>
Requires my [fork of Inpla](https://github.com/eric-marin/inpla).
diff --git a/examples/ACASXU/ACASXU_epsilon.vnnlib b/examples/ACASXU/ACASXU_epsilon.vnnlib
index 0ca04f8..2cbcd36 100644
--- a/examples/ACASXU/ACASXU_epsilon.vnnlib
+++ b/examples/ACASXU/ACASXU_epsilon.vnnlib
@@ -1,4 +1,4 @@
-; Strict Equivalence for ACASXU
+; Epsilon Equivalence for ACASXU
; Constant declaration
(declare-const X_0 Real)
diff --git a/examples/double_integrator/double_integrator_epsilon.vnnlib b/examples/double_integrator/double_integrator_epsilon.vnnlib
index 3af9079..f5c4ee6 100644
--- a/examples/double_integrator/double_integrator_epsilon.vnnlib
+++ b/examples/double_integrator/double_integrator_epsilon.vnnlib
@@ -1,4 +1,4 @@
-; Strict Equivalence for Double Integrator
+; Epsilon Equivalence for Double Integrator
; Constant declaration
(declare-const X_0 Real)
diff --git a/examples/iris/iris_epsilon.vnnlib b/examples/iris/iris_epsilon.vnnlib
index 9c8e825..df691c4 100644
--- a/examples/iris/iris_epsilon.vnnlib
+++ b/examples/iris/iris_epsilon.vnnlib
@@ -1,4 +1,4 @@
-; Strict Equivalence for Iris
+; Epsilon Equivalence for Iris
; Constant declaration
(declare-const X_0 Real)
diff --git a/examples/fashion_mnist/fashion_mnist.py b/examples/mnist/mnist.py
index 680f4eb..0a81878 100644
--- a/examples/fashion_mnist/fashion_mnist.py
+++ b/examples/mnist/mnist.py
@@ -1,9 +1,9 @@
import torch, torch.nn as nn
-from torchvision.datasets import FashionMNIST
+from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
-class FashionMNIST_MLP(nn.Module):
+class MNIST_MLP(nn.Module):
def __init__(self, hidden_dim):
super().__init__()
self.layers = nn.Sequential(
@@ -15,16 +15,16 @@ class FashionMNIST_MLP(nn.Module):
def forward(self, x):
return self.layers(x)
-train_dataset = FashionMNIST('./', download=True, transform=transforms.ToTensor(), train=True)
+train_dataset = MNIST('./', download=True, transform=transforms.ToTensor(), train=True)
trainloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
def train_model(name: str, dim):
- net = FashionMNIST_MLP(hidden_dim=dim)
+ net = MNIST_MLP(hidden_dim=dim)
loss_fn = nn.CrossEntropyLoss()
- optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
+ optimizer = torch.optim.Adam(net.parameters(), lr=0.5e-4)
print(f"Training {name} ({dim} neurons)...")
- for epoch in range(10):
+ for epoch in range(100):
global loss
for data in trainloader:
inputs, targets = data
@@ -33,13 +33,13 @@ def train_model(name: str, dim):
loss = loss_fn(outputs, targets)
loss.backward()
optimizer.step()
- if (epoch + 1) % 1 == 0:
+ if (epoch + 1) % 10 == 0:
print(f" Epoch {epoch+1}, Loss: {loss.item():.4f}")
return net
if __name__ == "__main__":
- torch_net_a = train_model("Network A", 28).eval()
- torch_net_b = train_model("Network B", 56).eval()
+ torch_net_a = train_model("Network A", 6).eval()
+ torch_net_b = train_model("Network B", 12).eval()
- torch.onnx.export(torch_net_a, (torch.randn(1, 28, 28),), "fashion_mnist_a.onnx")
- torch.onnx.export(torch_net_b, (torch.randn(1, 28, 28),), "fashion_mnist_b.onnx")
+ torch.onnx.export(torch_net_a, (torch.randn(1, 28, 28),), "mnist_a.onnx")
+ torch.onnx.export(torch_net_b, (torch.randn(1, 28, 28),), "mnist_b.onnx")
diff --git a/examples/fashion_mnist/fashion_mnist_argmax.vnnlib b/examples/mnist/mnist_argmax.vnnlib
index 8d06485..4c7f0c9 100644
--- a/examples/fashion_mnist/fashion_mnist_argmax.vnnlib
+++ b/examples/mnist/mnist_argmax.vnnlib
@@ -1,4 +1,4 @@
-; Argmax Equivalence for reduced FashionMNIST
+; Argmax Equivalence for MNIST
; Constant declaration
(declare-const X_0 Real)
diff --git a/examples/fashion_mnist/fashion_mnist_epsilon.vnnlib b/examples/mnist/mnist_epsilon.vnnlib
index 9637489..ea76779 100644
--- a/examples/fashion_mnist/fashion_mnist_epsilon.vnnlib
+++ b/examples/mnist/mnist_epsilon.vnnlib
@@ -1,4 +1,4 @@
-; Strict Equivalence for reduced FashionMNIST
+; Epsilon Equivalence for MNIST
; Constant declaration
(declare-const X_0 Real)
diff --git a/examples/fashion_mnist/fashion_mnist_strict.vnnlib b/examples/mnist/mnist_strict.vnnlib
index c8b3f8e..356f176 100644
--- a/examples/fashion_mnist/fashion_mnist_strict.vnnlib
+++ b/examples/mnist/mnist_strict.vnnlib
@@ -1,4 +1,4 @@
-; Strict Equivalence for reduced FashionMNIST
+; Strict Equivalence for MNIST
; Constant declaration
(declare-const X_0 Real)
diff --git a/examples/pendulum/pendulum_epsilon.vnnlib b/examples/pendulum/pendulum_epsilon.vnnlib
index a192029..8209db5 100644
--- a/examples/pendulum/pendulum_epsilon.vnnlib
+++ b/examples/pendulum/pendulum_epsilon.vnnlib
@@ -1,4 +1,4 @@
-; Strict Equivalence for Pendulum
+; Epsilon Equivalence for Pendulum
; Constant declaration
(declare-const X_0 Real)
diff --git a/examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_0.onnx b/examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_0.onnx
deleted file mode 100644
index a6632fb..0000000
--- a/examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_0.onnx
+++ /dev/null
Binary files differ
diff --git a/examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_2.onnx b/examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_2.onnx
deleted file mode 100644
index b650a13..0000000
--- a/examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_2.onnx
+++ /dev/null
Binary files differ
diff --git a/examples/tll/tll_argmax.vnnlib b/examples/tll/tll_argmax.vnnlib
deleted file mode 100644
index c084e52..0000000
--- a/examples/tll/tll_argmax.vnnlib
+++ /dev/null
@@ -1,16 +0,0 @@
-; Argmax Equivalence for TLL
-
-; Constant declaration
-(declare-const X_0 Real)
-(declare-const X_1 Real)
-(declare-const Y_0 Real)
-(declare-const Y_1 Real)
-
-; Bounded inputs: X must be within [0, 1]
-(assert (>= X_0 0.0))
-(assert (<= X_0 1.0))
-(assert (>= X_1 0.0))
-(assert (<= X_1 1.0))
-
-; Violation of argmax equivalence
-(assert (and (> Y_0 0.5) (< Y_1 0.5)))
diff --git a/examples/tll/tll_epsilon.vnnlib b/examples/tll/tll_epsilon.vnnlib
deleted file mode 100644
index 8e0902d..0000000
--- a/examples/tll/tll_epsilon.vnnlib
+++ /dev/null
@@ -1,17 +0,0 @@
-; Strict Equivalence for TLL
-
-; Constant declaration
-(declare-const X_0 Real)
-(declare-const X_1 Real)
-(declare-const Y_0 Real)
-(declare-const Y_1 Real)
-
-; Bounded inputs: X must be within [0, 1]
-(assert (>= X_0 0.0))
-(assert (<= X_0 1.0))
-(assert (>= X_1 0.0))
-(assert (<= X_1 1.0))
-
-; Violation of epsilon equivalence (epsilon = 0.1)
-(define-fun absolute ((x Real)) Real (if (>= x 0) x (- x)))
-(assert (> (absolute (- Y_0 Y_1)) 0.1))
diff --git a/examples/tll/tll_strict.vnnlib b/examples/tll/tll_strict.vnnlib
deleted file mode 100644
index 0079b1e..0000000
--- a/examples/tll/tll_strict.vnnlib
+++ /dev/null
@@ -1,16 +0,0 @@
-; Strict Equivalence for TLL
-
-; Constant declaration
-(declare-const X_0 Real)
-(declare-const X_1 Real)
-(declare-const Y_0 Real)
-(declare-const Y_1 Real)
-
-; Bounded inputs: X must be within [0, 1]
-(assert (>= X_0 0.0))
-(assert (<= X_0 1.0))
-(assert (>= X_1 0.0))
-(assert (<= X_1 1.0))
-
-; Violation of strict equivalence
-(assert (not (= Y_0 Y_1)))
diff --git a/examples/verify.py b/examples/verify.py
deleted file mode 100644
index 65fb989..0000000
--- a/examples/verify.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import sys, os
-sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-
-import z3
-import vein
-
-def check_property(onnx_a, onnx_b, vnnlib):
- solver = vein.Solver()
-
- print(f"--- Checking {vnnlib} ---")
-
- solver.load_onnx(onnx_a)
- solver.load_onnx(onnx_b)
- solver.load_vnnlib(vnnlib)
-
- result = solver.check()
-
- if result == z3.unsat:
- print("VERIFIED (UNSAT): The networks are equivalent under this property.")
- elif result == z3.sat:
- print("FAILED (SAT): The networks are NOT equivalent.")
- print("Counter-example input:")
- print(solver.model())
- # m = solver.model()
- # sorted_symbols = sorted([s for s in m.decls() if s.name().startswith("X_")], key=lambda s: s.name())
- # for s in sorted_symbols:
- # print(f" {s.name()} = {m[s]}")
- else:
- print("UNKNOWN")
- print("")
-
-if __name__ == "__main__":
- if len(sys.argv) <= 1:
- print("Net not provided")
- print("Available Nets: 'xor', 'fashion_mnist', 'iris', 'acasxu', 'tll', 'pendulum', 'double_integrator'")
- else:
- match sys.argv[1]:
- case "xor":
- net_a = "./examples/xor/xor_a.onnx"
- net_b = "./examples/xor/xor_b.onnx"
- strict = "./examples/xor/xor_strict.vnnlib"
- epsilon = "./examples/xor/xor_epsilon.vnnlib"
- argmax = "./examples/xor/xor_argmax.vnnlib"
- case "fashion_mnist":
- net_a = "./examples/fashion_mnist/fashion_mnist_a.onnx"
- net_b = "./examples/fashion_mnist/fashion_mnist_b.onnx"
- strict = "./examples/fashion_mnist/fashion_mnist_strict.vnnlib"
- epsilon = "./examples/fashion_mnist/fashion_mnist_epsilon.vnnlib"
- argmax = "./examples/fashion_mnist/fashion_mnist_argmax.vnnlib"
- case "iris":
- net_a = "./examples/iris/iris_a.onnx"
- net_b = "./examples/iris/iris_b.onnx"
- strict = "./examples/iris/iris_strict.vnnlib"
- epsilon = "./examples/iris/iris_epsilon.vnnlib"
- argmax = "./examples/iris/iris_argmax.vnnlib"
- case "acasxu":
- net_a = "./examples/ACASXU/ACASXU_run2a_1_1_batch_2000.onnx"
- net_b = "./examples/ACASXU/ACASXU_run2a_1_1_batch_2000.onnx"
- strict = "./examples/ACASXU/ACASXU_strict.vnnlib"
- epsilon = "./examples/ACASXU/ACASXU_epsilon.vnnlib"
- argmax = "./examples/ACASXU/ACASXU_argmax.vnnlib"
- case "tll":
- net_a = "./examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_0.onnx"
- net_b = "./examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_2.onnx"
- strict = "./examples/tll/tll_strict.vnnlib"
- epsilon = "./examples/tll/tll_epsilon.vnnlib"
- argmax = "./examples/tll/tll_argmax.vnnlib"
- case "pendulum":
- net_a = "./examples/pendulum/pendulum_finetune_con.onnx"
- net_b = "./examples/pendulum/pendulum_finetune_con.onnx"
- strict = "./examples/pendulum/pendulum_strict.vnnlib"
- epsilon = "./examples/pendulum/pendulum_epsilon.vnnlib"
- argmax = "./examples/pendulum/pendulum_argmax.vnnlib"
- case "double_integrator":
- net_a = "./examples/double_integrator/double_integrator_finetune_inv.onnx"
- net_b = "./examples/double_integrator/double_integrator_finetune_inv.onnx"
- strict = "./examples/double_integrator/double_integrator_strict.vnnlib"
- epsilon = "./examples/double_integrator/double_integrator_epsilon.vnnlib"
- argmax = "./examples/double_integrator/double_integrator_argmax.vnnlib"
- case _:
- print("Available Nets: 'xor', 'fashion_mnist', 'iris', 'acasxu', 'tll', 'pendulum', 'double_integrator'")
- sys.exit()
-
- check_property(net_a, net_b, strict)
- check_property(net_a, net_b, epsilon)
- check_property(net_a, net_b, argmax)
diff --git a/vein.py b/vein.py
index 1aa86ae..c017428 100644
--- a/vein.py
+++ b/vein.py
@@ -19,11 +19,14 @@ import subprocess
import onnx
import onnx.shape_inference
from onnx import numpy_helper
-from typing import List, Dict, Optional, Tuple
+from typing import List, Dict, Optional
import os
import tempfile
import hashlib
+sat = z3.sat
+unsat = z3.unsat
+
rules = """
Linear(x, float q, float r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r);
Concrete(float k) >< Add(out, b)
@@ -66,7 +69,7 @@ rules = """
Concrete(float k) >< Materialize(out) => out ~ (*L)Concrete(k);
"""
-_INPLA_CACHE: Dict[Tuple[str, Optional[Tuple]], str] = {}
+_CACHE = {}
def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]]] = None) -> str:
# TODO: Add Range agent
@@ -86,7 +89,7 @@ def inpla_export(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]]
initializers[init.name] = numpy_helper.to_array(init)
return initializers
- def get_attrs(node: onnx.NodeProto) -> Dict:
+ def get_attrs(node) -> Dict:
return {attr.name: onnx.helper.get_attribute_value(attr) for attr in node.attribute}
def get_dim(name):
@@ -288,7 +291,7 @@ def inpla_run(model: str) -> str:
os.remove(temp_path)
-def z3_evaluate(model: str, X):
+def z3_evaluate(model: str, X: dict):
def Symbolic(id):
if id not in X:
X[id] = z3.Real(id)
@@ -305,20 +308,33 @@ def z3_evaluate(model: str, X):
'TermMul': TermMul,
'TermReLU': TermReLU
}
- lines = [line.strip() for line in model.splitlines() if line.strip()]
- def iterative_eval(expr_str):
- tokens = re.findall(r'\(|\)|\,|[^(),\s]+', expr_str)
+ def tokenize(s):
+ i = 0
+ n = len(s)
+ while i < n:
+ c = s[i]
+ if c in '(),':
+ yield c
+ i += 1
+ elif c.isspace():
+ i += 1
+ else:
+ start = i
+ while i < n and s[i] not in '(), ' and not s[i].isspace():
+ i += 1
+ yield s[start:i]
+
+ def iterative_eval(tokens_gen):
stack = [[]]
- for token in tokens:
+ for token in tokens_gen:
if token == '(':
stack.append([])
elif token == ')':
args = stack.pop()
func_name = stack[-1].pop()
func = context.get(func_name)
- if not func:
- raise ValueError(f"Unknown function: {func_name}")
+ if not func: raise ValueError(f"Unknown: {func_name}")
stack[-1].append(func(*args))
elif token == ',':
continue
@@ -332,31 +348,34 @@ def z3_evaluate(model: str, X):
stack[-1].append(token)
return stack[0][0]
- exprs = [iterative_eval(line) for line in lines]
+ exprs = []
+ for line in model.splitlines():
+ line = line.strip()
+ exprs.append(iterative_eval(tokenize(line)))
return exprs
-def net(model: onnx.ModelProto, X, bounds: Optional[Dict[str, List[float]]] = None):
+def net(model: onnx.ModelProto, bounds: Optional[Dict[str, List[float]]] = None):
model_hash = hashlib.sha256(model.SerializeToString()).hexdigest()
bounds_key = tuple(sorted((k, tuple(v)) for k, v in bounds.items())) if bounds else None
cache_key = (model_hash, bounds_key)
- if cache_key not in _INPLA_CACHE:
+ if cache_key not in _CACHE:
exported = inpla_export(model, bounds)
reduced = inpla_run(exported)
- _INPLA_CACHE[cache_key] = reduced
+ X = {}
+ evaluated = z3_evaluate(reduced, X)
+ _CACHE[cache_key] = evaluated
- res = z3_evaluate(_INPLA_CACHE[cache_key], X)
- return res if res is not None else []
+ exprs = _CACHE[cache_key]
+ return exprs if exprs is not None else []
class Solver(z3.Solver):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.X = {}
- self.Y = {}
self.bounds: Dict[str, List[float]] = {}
- self.pending_nets: List[Tuple[onnx.ModelProto, Optional[str]]] = []
+ self.pending_nets: List[onnx.ModelProto] = []
- def load_vnnlib(self, file_path):
+ def load_vnnlib(self, file_path: str):
with open(file_path, "r") as f:
content = f.read()
@@ -367,26 +386,22 @@ class Solver(z3.Solver):
if op == ">=": self.bounds[var][0] = val
else: self.bounds[var][1] = val
- content = re.sub(r"\(vnnlib-version.*?\)", "", content)
assertions = z3.parse_smt2_string(content)
self.add(assertions)
- def load_onnx(self, file, name=None):
- model = onnx.load(file)
+ def load_onnx(self, file_path: str):
+ model = onnx.load(file_path)
model = onnx.shape_inference.infer_shapes(model)
- self.pending_nets.append((model, name))
+ self.pending_nets.append(model)
def _process_nets(self):
y_count = 0
- for model, name in self.pending_nets:
- z3_outputs = net(model, self.X, bounds=self.bounds)
+ for model in self.pending_nets:
+ z3_outputs = net(model, bounds=self.bounds)
if z3_outputs:
for _, out_expr in enumerate(z3_outputs):
y_var = z3.Real(f"Y_{y_count}")
self.add(y_var == out_expr)
- if name:
- if name not in self.Y: self.Y[name] = []
- self.Y[name].append(out_expr)
y_count += 1
self.pending_nets = []
diff --git a/verify_example.py b/verify_example.py
new file mode 100644
index 0000000..1bf00fa
--- /dev/null
+++ b/verify_example.py
@@ -0,0 +1,79 @@
+import sys
+import vein
+
+def check_property(onnx_a, onnx_b, vnnlib):
+ solver = vein.Solver()
+
+ print(f"--- Checking {vnnlib} ---")
+
+ solver.load_onnx(onnx_a)
+ solver.load_onnx(onnx_b)
+ solver.load_vnnlib(vnnlib)
+
+ result = solver.check()
+
+ if result == vein.unsat:
+ print("VERIFIED (UNSAT): The networks are equivalent under this property.", end="\n\n")
+ elif result == vein.sat:
+ print("FAILED (SAT): The networks are NOT equivalent.")
+ print("Counter-example input:")
+ print(solver.model(), end="\n\n")
+ # m = solver.model()
+ # sorted_symbols = sorted([s for s in m.decls() if s.name().startswith("X_")], key=lambda s: s.name())
+ # for s in sorted_symbols:
+ # print(f" {s.name()} = {m[s]}")
+ else:
+ print("UNKNOWN", end="\n\n")
+
+if __name__ == "__main__":
+ if len(sys.argv) <= 1:
+ print("Net not provided")
+ print("Available Nets: 'xor', 'mnist', 'iris', 'acasxu', 'pendulum', 'double_integrator'")
+ sys.exit()
+
+ match sys.argv[1]:
+ case "xor":
+ net_a = "./examples/xor/xor_a.onnx"
+ net_b = "./examples/xor/xor_b.onnx"
+ strict = "./examples/xor/xor_strict.vnnlib"
+ epsilon = "./examples/xor/xor_epsilon.vnnlib"
+ argmax = "./examples/xor/xor_argmax.vnnlib"
+ case "mnist":
+ net_a = "./examples/mnist/mnist_a.onnx"
+ net_b = "./examples/mnist/mnist_b.onnx"
+ strict = "./examples/mnist/mnist_strict.vnnlib"
+ epsilon = "./examples/mnist/mnist_epsilon.vnnlib"
+ argmax = "./examples/mnist/mnist_argmax.vnnlib"
+ case "iris":
+ net_a = "./examples/iris/iris_a.onnx"
+ net_b = "./examples/iris/iris_b.onnx"
+ strict = "./examples/iris/iris_strict.vnnlib"
+ epsilon = "./examples/iris/iris_epsilon.vnnlib"
+ argmax = "./examples/iris/iris_argmax.vnnlib"
+ case "acasxu":
+ net_a = "./examples/ACASXU/ACASXU_run2a_1_1_batch_2000.onnx"
+ net_b = "./examples/ACASXU/ACASXU_run2a_1_1_batch_2000.onnx"
+ strict = "./examples/ACASXU/ACASXU_strict.vnnlib"
+ epsilon = "./examples/ACASXU/ACASXU_epsilon.vnnlib"
+ argmax = "./examples/ACASXU/ACASXU_argmax.vnnlib"
+ case "pendulum":
+ net_a = "./examples/pendulum/pendulum_finetune_con.onnx"
+ net_b = "./examples/pendulum/pendulum_finetune_con.onnx"
+ strict = "./examples/pendulum/pendulum_strict.vnnlib"
+ epsilon = "./examples/pendulum/pendulum_epsilon.vnnlib"
+ argmax = "./examples/pendulum/pendulum_argmax.vnnlib"
+ case "double_integrator":
+ net_a = "./examples/double_integrator/double_integrator_finetune_inv.onnx"
+ net_b = "./examples/double_integrator/double_integrator_finetune_inv.onnx"
+ strict = "./examples/double_integrator/double_integrator_strict.vnnlib"
+ epsilon = "./examples/double_integrator/double_integrator_epsilon.vnnlib"
+ argmax = "./examples/double_integrator/double_integrator_argmax.vnnlib"
+ case _:
+ print("Available Nets: 'xor', 'mnist', 'iris', 'acasxu', 'pendulum', 'double_integrator'")
+ sys.exit()
+
+ print(f"=== Comparing {net_a} and {net_b} ===", end="\n\n")
+
+ check_property(net_a, net_b, strict)
+ check_property(net_a, net_b, epsilon)
+ check_property(net_a, net_b, argmax)