aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--examples/ACASXU/ACASXU_epsilon.vnnlib2
-rw-r--r--examples/double_integrator/double_integrator_epsilon.vnnlib2
-rw-r--r--examples/iris/iris_epsilon.vnnlib2
-rw-r--r--examples/mnist/mnist.py (renamed from examples/fashion_mnist/fashion_mnist.py)22
-rw-r--r--examples/mnist/mnist_argmax.vnnlib (renamed from examples/fashion_mnist/fashion_mnist_argmax.vnnlib)2
-rw-r--r--examples/mnist/mnist_epsilon.vnnlib (renamed from examples/fashion_mnist/fashion_mnist_epsilon.vnnlib)2
-rw-r--r--examples/mnist/mnist_strict.vnnlib (renamed from examples/fashion_mnist/fashion_mnist_strict.vnnlib)2
-rw-r--r--examples/pendulum/pendulum_epsilon.vnnlib2
-rw-r--r--examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_0.onnxbin74221 -> 0 bytes
-rw-r--r--examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_2.onnxbin74533 -> 0 bytes
-rw-r--r--examples/tll/tll_argmax.vnnlib16
-rw-r--r--examples/tll/tll_epsilon.vnnlib17
-rw-r--r--examples/tll/tll_strict.vnnlib16
-rw-r--r--examples/verify.py86
14 files changed, 18 insertions, 153 deletions
diff --git a/examples/ACASXU/ACASXU_epsilon.vnnlib b/examples/ACASXU/ACASXU_epsilon.vnnlib
index 0ca04f8..2cbcd36 100644
--- a/examples/ACASXU/ACASXU_epsilon.vnnlib
+++ b/examples/ACASXU/ACASXU_epsilon.vnnlib
@@ -1,4 +1,4 @@
-; Strict Equivalence for ACASXU
+; Epsilon Equivalence for ACASXU
; Constant declaration
(declare-const X_0 Real)
diff --git a/examples/double_integrator/double_integrator_epsilon.vnnlib b/examples/double_integrator/double_integrator_epsilon.vnnlib
index 3af9079..f5c4ee6 100644
--- a/examples/double_integrator/double_integrator_epsilon.vnnlib
+++ b/examples/double_integrator/double_integrator_epsilon.vnnlib
@@ -1,4 +1,4 @@
-; Strict Equivalence for Double Integrator
+; Epsilon Equivalence for Double Integrator
; Constant declaration
(declare-const X_0 Real)
diff --git a/examples/iris/iris_epsilon.vnnlib b/examples/iris/iris_epsilon.vnnlib
index 9c8e825..df691c4 100644
--- a/examples/iris/iris_epsilon.vnnlib
+++ b/examples/iris/iris_epsilon.vnnlib
@@ -1,4 +1,4 @@
-; Strict Equivalence for Iris
+; Epsilon Equivalence for Iris
; Constant declaration
(declare-const X_0 Real)
diff --git a/examples/fashion_mnist/fashion_mnist.py b/examples/mnist/mnist.py
index 680f4eb..0a81878 100644
--- a/examples/fashion_mnist/fashion_mnist.py
+++ b/examples/mnist/mnist.py
@@ -1,9 +1,9 @@
import torch, torch.nn as nn
-from torchvision.datasets import FashionMNIST
+from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
-class FashionMNIST_MLP(nn.Module):
+class MNIST_MLP(nn.Module):
def __init__(self, hidden_dim):
super().__init__()
self.layers = nn.Sequential(
@@ -15,16 +15,16 @@ class FashionMNIST_MLP(nn.Module):
def forward(self, x):
return self.layers(x)
-train_dataset = FashionMNIST('./', download=True, transform=transforms.ToTensor(), train=True)
+train_dataset = MNIST('./', download=True, transform=transforms.ToTensor(), train=True)
trainloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
def train_model(name: str, dim):
- net = FashionMNIST_MLP(hidden_dim=dim)
+ net = MNIST_MLP(hidden_dim=dim)
loss_fn = nn.CrossEntropyLoss()
- optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
+ optimizer = torch.optim.Adam(net.parameters(), lr=0.5e-4)
print(f"Training {name} ({dim} neurons)...")
- for epoch in range(10):
+ for epoch in range(100):
global loss
for data in trainloader:
inputs, targets = data
@@ -33,13 +33,13 @@ def train_model(name: str, dim):
loss = loss_fn(outputs, targets)
loss.backward()
optimizer.step()
- if (epoch + 1) % 1 == 0:
+ if (epoch + 1) % 10 == 0:
print(f" Epoch {epoch+1}, Loss: {loss.item():.4f}")
return net
if __name__ == "__main__":
- torch_net_a = train_model("Network A", 28).eval()
- torch_net_b = train_model("Network B", 56).eval()
+ torch_net_a = train_model("Network A", 6).eval()
+ torch_net_b = train_model("Network B", 12).eval()
- torch.onnx.export(torch_net_a, (torch.randn(1, 28, 28),), "fashion_mnist_a.onnx")
- torch.onnx.export(torch_net_b, (torch.randn(1, 28, 28),), "fashion_mnist_b.onnx")
+ torch.onnx.export(torch_net_a, (torch.randn(1, 28, 28),), "mnist_a.onnx")
+ torch.onnx.export(torch_net_b, (torch.randn(1, 28, 28),), "mnist_b.onnx")
diff --git a/examples/fashion_mnist/fashion_mnist_argmax.vnnlib b/examples/mnist/mnist_argmax.vnnlib
index 8d06485..4c7f0c9 100644
--- a/examples/fashion_mnist/fashion_mnist_argmax.vnnlib
+++ b/examples/mnist/mnist_argmax.vnnlib
@@ -1,4 +1,4 @@
-; Argmax Equivalence for reduced FashionMNIST
+; Argmax Equivalence for MNIST
; Constant declaration
(declare-const X_0 Real)
diff --git a/examples/fashion_mnist/fashion_mnist_epsilon.vnnlib b/examples/mnist/mnist_epsilon.vnnlib
index 9637489..ea76779 100644
--- a/examples/fashion_mnist/fashion_mnist_epsilon.vnnlib
+++ b/examples/mnist/mnist_epsilon.vnnlib
@@ -1,4 +1,4 @@
-; Strict Equivalence for reduced FashionMNIST
+; Epsilon Equivalence for MNIST
; Constant declaration
(declare-const X_0 Real)
diff --git a/examples/fashion_mnist/fashion_mnist_strict.vnnlib b/examples/mnist/mnist_strict.vnnlib
index c8b3f8e..356f176 100644
--- a/examples/fashion_mnist/fashion_mnist_strict.vnnlib
+++ b/examples/mnist/mnist_strict.vnnlib
@@ -1,4 +1,4 @@
-; Strict Equivalence for reduced FashionMNIST
+; Strict Equivalence for MNIST
; Constant declaration
(declare-const X_0 Real)
diff --git a/examples/pendulum/pendulum_epsilon.vnnlib b/examples/pendulum/pendulum_epsilon.vnnlib
index a192029..8209db5 100644
--- a/examples/pendulum/pendulum_epsilon.vnnlib
+++ b/examples/pendulum/pendulum_epsilon.vnnlib
@@ -1,4 +1,4 @@
-; Strict Equivalence for Pendulum
+; Epsilon Equivalence for Pendulum
; Constant declaration
(declare-const X_0 Real)
diff --git a/examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_0.onnx b/examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_0.onnx
deleted file mode 100644
index a6632fb..0000000
--- a/examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_0.onnx
+++ /dev/null
Binary files differ
diff --git a/examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_2.onnx b/examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_2.onnx
deleted file mode 100644
index b650a13..0000000
--- a/examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_2.onnx
+++ /dev/null
Binary files differ
diff --git a/examples/tll/tll_argmax.vnnlib b/examples/tll/tll_argmax.vnnlib
deleted file mode 100644
index c084e52..0000000
--- a/examples/tll/tll_argmax.vnnlib
+++ /dev/null
@@ -1,16 +0,0 @@
-; Argmax Equivalence for TLL
-
-; Constant declaration
-(declare-const X_0 Real)
-(declare-const X_1 Real)
-(declare-const Y_0 Real)
-(declare-const Y_1 Real)
-
-; Bounded inputs: X must be within [0, 1]
-(assert (>= X_0 0.0))
-(assert (<= X_0 1.0))
-(assert (>= X_1 0.0))
-(assert (<= X_1 1.0))
-
-; Violation of argmax equivalence
-(assert (and (> Y_0 0.5) (< Y_1 0.5)))
diff --git a/examples/tll/tll_epsilon.vnnlib b/examples/tll/tll_epsilon.vnnlib
deleted file mode 100644
index 8e0902d..0000000
--- a/examples/tll/tll_epsilon.vnnlib
+++ /dev/null
@@ -1,17 +0,0 @@
-; Strict Equivalence for TLL
-
-; Constant declaration
-(declare-const X_0 Real)
-(declare-const X_1 Real)
-(declare-const Y_0 Real)
-(declare-const Y_1 Real)
-
-; Bounded inputs: X must be within [0, 1]
-(assert (>= X_0 0.0))
-(assert (<= X_0 1.0))
-(assert (>= X_1 0.0))
-(assert (<= X_1 1.0))
-
-; Violation of epsilon equivalence (epsilon = 0.1)
-(define-fun absolute ((x Real)) Real (if (>= x 0) x (- x)))
-(assert (> (absolute (- Y_0 Y_1)) 0.1))
diff --git a/examples/tll/tll_strict.vnnlib b/examples/tll/tll_strict.vnnlib
deleted file mode 100644
index 0079b1e..0000000
--- a/examples/tll/tll_strict.vnnlib
+++ /dev/null
@@ -1,16 +0,0 @@
-; Strict Equivalence for TLL
-
-; Constant declaration
-(declare-const X_0 Real)
-(declare-const X_1 Real)
-(declare-const Y_0 Real)
-(declare-const Y_1 Real)
-
-; Bounded inputs: X must be within [0, 1]
-(assert (>= X_0 0.0))
-(assert (<= X_0 1.0))
-(assert (>= X_1 0.0))
-(assert (<= X_1 1.0))
-
-; Violation of strict equivalence
-(assert (not (= Y_0 Y_1)))
diff --git a/examples/verify.py b/examples/verify.py
deleted file mode 100644
index 65fb989..0000000
--- a/examples/verify.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import sys, os
-sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-
-import z3
-import vein
-
-def check_property(onnx_a, onnx_b, vnnlib):
- solver = vein.Solver()
-
- print(f"--- Checking {vnnlib} ---")
-
- solver.load_onnx(onnx_a)
- solver.load_onnx(onnx_b)
- solver.load_vnnlib(vnnlib)
-
- result = solver.check()
-
- if result == z3.unsat:
- print("VERIFIED (UNSAT): The networks are equivalent under this property.")
- elif result == z3.sat:
- print("FAILED (SAT): The networks are NOT equivalent.")
- print("Counter-example input:")
- print(solver.model())
- # m = solver.model()
- # sorted_symbols = sorted([s for s in m.decls() if s.name().startswith("X_")], key=lambda s: s.name())
- # for s in sorted_symbols:
- # print(f" {s.name()} = {m[s]}")
- else:
- print("UNKNOWN")
- print("")
-
-if __name__ == "__main__":
- if len(sys.argv) <= 1:
- print("Net not provided")
- print("Available Nets: 'xor', 'fashion_mnist', 'iris', 'acasxu', 'tll', 'pendulum', 'double_integrator'")
- else:
- match sys.argv[1]:
- case "xor":
- net_a = "./examples/xor/xor_a.onnx"
- net_b = "./examples/xor/xor_b.onnx"
- strict = "./examples/xor/xor_strict.vnnlib"
- epsilon = "./examples/xor/xor_epsilon.vnnlib"
- argmax = "./examples/xor/xor_argmax.vnnlib"
- case "fashion_mnist":
- net_a = "./examples/fashion_mnist/fashion_mnist_a.onnx"
- net_b = "./examples/fashion_mnist/fashion_mnist_b.onnx"
- strict = "./examples/fashion_mnist/fashion_mnist_strict.vnnlib"
- epsilon = "./examples/fashion_mnist/fashion_mnist_epsilon.vnnlib"
- argmax = "./examples/fashion_mnist/fashion_mnist_argmax.vnnlib"
- case "iris":
- net_a = "./examples/iris/iris_a.onnx"
- net_b = "./examples/iris/iris_b.onnx"
- strict = "./examples/iris/iris_strict.vnnlib"
- epsilon = "./examples/iris/iris_epsilon.vnnlib"
- argmax = "./examples/iris/iris_argmax.vnnlib"
- case "acasxu":
- net_a = "./examples/ACASXU/ACASXU_run2a_1_1_batch_2000.onnx"
- net_b = "./examples/ACASXU/ACASXU_run2a_1_1_batch_2000.onnx"
- strict = "./examples/ACASXU/ACASXU_strict.vnnlib"
- epsilon = "./examples/ACASXU/ACASXU_epsilon.vnnlib"
- argmax = "./examples/ACASXU/ACASXU_argmax.vnnlib"
- case "tll":
- net_a = "./examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_0.onnx"
- net_b = "./examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_2.onnx"
- strict = "./examples/tll/tll_strict.vnnlib"
- epsilon = "./examples/tll/tll_epsilon.vnnlib"
- argmax = "./examples/tll/tll_argmax.vnnlib"
- case "pendulum":
- net_a = "./examples/pendulum/pendulum_finetune_con.onnx"
- net_b = "./examples/pendulum/pendulum_finetune_con.onnx"
- strict = "./examples/pendulum/pendulum_strict.vnnlib"
- epsilon = "./examples/pendulum/pendulum_epsilon.vnnlib"
- argmax = "./examples/pendulum/pendulum_argmax.vnnlib"
- case "double_integrator":
- net_a = "./examples/double_integrator/double_integrator_finetune_inv.onnx"
- net_b = "./examples/double_integrator/double_integrator_finetune_inv.onnx"
- strict = "./examples/double_integrator/double_integrator_strict.vnnlib"
- epsilon = "./examples/double_integrator/double_integrator_epsilon.vnnlib"
- argmax = "./examples/double_integrator/double_integrator_argmax.vnnlib"
- case _:
- print("Available Nets: 'xor', 'fashion_mnist', 'iris', 'acasxu', 'tll', 'pendulum', 'double_integrator'")
- sys.exit()
-
- check_property(net_a, net_b, strict)
- check_property(net_a, net_b, epsilon)
- check_property(net_a, net_b, argmax)