diff options
Diffstat (limited to '')
| -rw-r--r-- | examples/ACASXU/ACASXU_epsilon.vnnlib | 2 | ||||
| -rw-r--r-- | examples/double_integrator/double_integrator_epsilon.vnnlib | 2 | ||||
| -rw-r--r-- | examples/iris/iris_epsilon.vnnlib | 2 | ||||
| -rw-r--r-- | examples/mnist/mnist.py (renamed from examples/fashion_mnist/fashion_mnist.py) | 22 | ||||
| -rw-r--r-- | examples/mnist/mnist_argmax.vnnlib (renamed from examples/fashion_mnist/fashion_mnist_argmax.vnnlib) | 2 | ||||
| -rw-r--r-- | examples/mnist/mnist_epsilon.vnnlib (renamed from examples/fashion_mnist/fashion_mnist_epsilon.vnnlib) | 2 | ||||
| -rw-r--r-- | examples/mnist/mnist_strict.vnnlib (renamed from examples/fashion_mnist/fashion_mnist_strict.vnnlib) | 2 | ||||
| -rw-r--r-- | examples/pendulum/pendulum_epsilon.vnnlib | 2 | ||||
| -rw-r--r-- | examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_0.onnx | bin | 74221 -> 0 bytes | |||
| -rw-r--r-- | examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_2.onnx | bin | 74533 -> 0 bytes | |||
| -rw-r--r-- | examples/tll/tll_argmax.vnnlib | 16 | ||||
| -rw-r--r-- | examples/tll/tll_epsilon.vnnlib | 17 | ||||
| -rw-r--r-- | examples/tll/tll_strict.vnnlib | 16 | ||||
| -rw-r--r-- | examples/verify.py | 86 |
14 files changed, 18 insertions, 153 deletions
diff --git a/examples/ACASXU/ACASXU_epsilon.vnnlib b/examples/ACASXU/ACASXU_epsilon.vnnlib index 0ca04f8..2cbcd36 100644 --- a/examples/ACASXU/ACASXU_epsilon.vnnlib +++ b/examples/ACASXU/ACASXU_epsilon.vnnlib @@ -1,4 +1,4 @@ -; Strict Equivalence for ACASXU +; Epsilon Equivalence for ACASXU ; Constant declaration (declare-const X_0 Real) diff --git a/examples/double_integrator/double_integrator_epsilon.vnnlib b/examples/double_integrator/double_integrator_epsilon.vnnlib index 3af9079..f5c4ee6 100644 --- a/examples/double_integrator/double_integrator_epsilon.vnnlib +++ b/examples/double_integrator/double_integrator_epsilon.vnnlib @@ -1,4 +1,4 @@ -; Strict Equivalence for Double Integrator +; Epsilon Equivalence for Double Integrator ; Constant declaration (declare-const X_0 Real) diff --git a/examples/iris/iris_epsilon.vnnlib b/examples/iris/iris_epsilon.vnnlib index 9c8e825..df691c4 100644 --- a/examples/iris/iris_epsilon.vnnlib +++ b/examples/iris/iris_epsilon.vnnlib @@ -1,4 +1,4 @@ -; Strict Equivalence for Iris +; Epsilon Equivalence for Iris ; Constant declaration (declare-const X_0 Real) diff --git a/examples/fashion_mnist/fashion_mnist.py b/examples/mnist/mnist.py index 680f4eb..0a81878 100644 --- a/examples/fashion_mnist/fashion_mnist.py +++ b/examples/mnist/mnist.py @@ -1,9 +1,9 @@ import torch, torch.nn as nn -from torchvision.datasets import FashionMNIST +from torchvision.datasets import MNIST from torch.utils.data import DataLoader from torchvision import transforms -class FashionMNIST_MLP(nn.Module): +class MNIST_MLP(nn.Module): def __init__(self, hidden_dim): super().__init__() self.layers = nn.Sequential( @@ -15,16 +15,16 @@ class FashionMNIST_MLP(nn.Module): def forward(self, x): return self.layers(x) -train_dataset = FashionMNIST('./', download=True, transform=transforms.ToTensor(), train=True) +train_dataset = MNIST('./', download=True, transform=transforms.ToTensor(), train=True) trainloader = DataLoader(train_dataset, batch_size=128, shuffle=True) def train_model(name: str, dim): - net = FashionMNIST_MLP(hidden_dim=dim) + net = MNIST_MLP(hidden_dim=dim) loss_fn = nn.CrossEntropyLoss() - optimizer = torch.optim.Adam(net.parameters(), lr=1e-4) + optimizer = torch.optim.Adam(net.parameters(), lr=0.5e-4) print(f"Training {name} ({dim} neurons)...") - for epoch in range(10): + for epoch in range(100): global loss for data in trainloader: inputs, targets = data @@ -33,13 +33,13 @@ def train_model(name: str, dim): loss = loss_fn(outputs, targets) loss.backward() optimizer.step() - if (epoch + 1) % 1 == 0: + if (epoch + 1) % 10 == 0: print(f" Epoch {epoch+1}, Loss: {loss.item():.4f}") return net if __name__ == "__main__": - torch_net_a = train_model("Network A", 28).eval() - torch_net_b = train_model("Network B", 56).eval() + torch_net_a = train_model("Network A", 6).eval() + torch_net_b = train_model("Network B", 12).eval() - torch.onnx.export(torch_net_a, (torch.randn(1, 28, 28),), "fashion_mnist_a.onnx") - torch.onnx.export(torch_net_b, (torch.randn(1, 28, 28),), "fashion_mnist_b.onnx") + torch.onnx.export(torch_net_a, (torch.randn(1, 28, 28),), "mnist_a.onnx") + torch.onnx.export(torch_net_b, (torch.randn(1, 28, 28),), "mnist_b.onnx") diff --git a/examples/fashion_mnist/fashion_mnist_argmax.vnnlib b/examples/mnist/mnist_argmax.vnnlib index 8d06485..4c7f0c9 100644 --- a/examples/fashion_mnist/fashion_mnist_argmax.vnnlib +++ b/examples/mnist/mnist_argmax.vnnlib @@ -1,4 +1,4 @@ -; Argmax Equivalence for reduced FashionMNIST +; Argmax Equivalence for MNIST ; Constant declaration (declare-const X_0 Real) diff --git a/examples/fashion_mnist/fashion_mnist_epsilon.vnnlib b/examples/mnist/mnist_epsilon.vnnlib index 9637489..ea76779 100644 --- a/examples/fashion_mnist/fashion_mnist_epsilon.vnnlib +++ b/examples/mnist/mnist_epsilon.vnnlib @@ -1,4 +1,4 @@ -; Strict Equivalence for reduced FashionMNIST +; Epsilon Equivalence for MNIST ; Constant declaration (declare-const X_0 Real) diff --git a/examples/fashion_mnist/fashion_mnist_strict.vnnlib b/examples/mnist/mnist_strict.vnnlib index c8b3f8e..356f176 100644 --- a/examples/fashion_mnist/fashion_mnist_strict.vnnlib +++ b/examples/mnist/mnist_strict.vnnlib @@ -1,4 +1,4 @@ -; Strict Equivalence for reduced FashionMNIST +; Strict Equivalence for MNIST ; Constant declaration (declare-const X_0 Real) diff --git a/examples/pendulum/pendulum_epsilon.vnnlib b/examples/pendulum/pendulum_epsilon.vnnlib index a192029..8209db5 100644 --- a/examples/pendulum/pendulum_epsilon.vnnlib +++ b/examples/pendulum/pendulum_epsilon.vnnlib @@ -1,4 +1,4 @@ -; Strict Equivalence for Pendulum +; Epsilon Equivalence for Pendulum ; Constant declaration (declare-const X_0 Real) diff --git a/examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_0.onnx b/examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_0.onnx Binary files differdeleted file mode 100644 index a6632fb..0000000 --- a/examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_0.onnx +++ /dev/null diff --git a/examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_2.onnx b/examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_2.onnx Binary files differdeleted file mode 100644 index b650a13..0000000 --- a/examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_2.onnx +++ /dev/null diff --git a/examples/tll/tll_argmax.vnnlib b/examples/tll/tll_argmax.vnnlib deleted file mode 100644 index c084e52..0000000 --- a/examples/tll/tll_argmax.vnnlib +++ /dev/null @@ -1,16 +0,0 @@ -; Argmax Equivalence for TLL - -; Constant declaration -(declare-const X_0 Real) -(declare-const X_1 Real) -(declare-const Y_0 Real) -(declare-const Y_1 Real) - -; Bounded inputs: X must be within [0, 1] -(assert (>= X_0 0.0)) -(assert (<= X_0 1.0)) -(assert (>= X_1 0.0)) -(assert (<= X_1 1.0)) - -; Violation of argmax equivalence -(assert (and (> Y_0 0.5) (< Y_1 0.5))) diff --git a/examples/tll/tll_epsilon.vnnlib b/examples/tll/tll_epsilon.vnnlib deleted file mode 100644 index 8e0902d..0000000 --- a/examples/tll/tll_epsilon.vnnlib +++ /dev/null @@ -1,17 +0,0 @@ -; Strict Equivalence for TLL - -; Constant declaration -(declare-const X_0 Real) -(declare-const X_1 Real) -(declare-const Y_0 Real) -(declare-const Y_1 Real) - -; Bounded inputs: X must be within [0, 1] -(assert (>= X_0 0.0)) -(assert (<= X_0 1.0)) -(assert (>= X_1 0.0)) -(assert (<= X_1 1.0)) - -; Violation of epsilon equivalence (epsilon = 0.1) -(define-fun absolute ((x Real)) Real (if (>= x 0) x (- x))) -(assert (> (absolute (- Y_0 Y_1)) 0.1)) diff --git a/examples/tll/tll_strict.vnnlib b/examples/tll/tll_strict.vnnlib deleted file mode 100644 index 0079b1e..0000000 --- a/examples/tll/tll_strict.vnnlib +++ /dev/null @@ -1,16 +0,0 @@ -; Strict Equivalence for TLL - -; Constant declaration -(declare-const X_0 Real) -(declare-const X_1 Real) -(declare-const Y_0 Real) -(declare-const Y_1 Real) - -; Bounded inputs: X must be within [0, 1] -(assert (>= X_0 0.0)) -(assert (<= X_0 1.0)) -(assert (>= X_1 0.0)) -(assert (<= X_1 1.0)) - -; Violation of strict equivalence -(assert (not (= Y_0 Y_1))) diff --git a/examples/verify.py b/examples/verify.py deleted file mode 100644 index 65fb989..0000000 --- a/examples/verify.py +++ /dev/null @@ -1,86 +0,0 @@ -import sys, os -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -import z3 -import vein - -def check_property(onnx_a, onnx_b, vnnlib): - solver = vein.Solver() - - print(f"--- Checking {vnnlib} ---") - - solver.load_onnx(onnx_a) - solver.load_onnx(onnx_b) - solver.load_vnnlib(vnnlib) - - result = solver.check() - - if result == z3.unsat: - print("VERIFIED (UNSAT): The networks are equivalent under this property.") - elif result == z3.sat: - print("FAILED (SAT): The networks are NOT equivalent.") - print("Counter-example input:") - print(solver.model()) - # m = solver.model() - # sorted_symbols = sorted([s for s in m.decls() if s.name().startswith("X_")], key=lambda s: s.name()) - # for s in sorted_symbols: - # print(f" {s.name()} = {m[s]}") - else: - print("UNKNOWN") - print("") - -if __name__ == "__main__": - if len(sys.argv) <= 1: - print("Net not provided") - print("Available Nets: 'xor', 'fashion_mnist', 'iris', 'acasxu', 'tll', 'pendulum', 'double_integrator'") - else: - match sys.argv[1]: - case "xor": - net_a = "./examples/xor/xor_a.onnx" - net_b = "./examples/xor/xor_b.onnx" - strict = "./examples/xor/xor_strict.vnnlib" - epsilon = "./examples/xor/xor_epsilon.vnnlib" - argmax = "./examples/xor/xor_argmax.vnnlib" - case "fashion_mnist": - net_a = "./examples/fashion_mnist/fashion_mnist_a.onnx" - net_b = "./examples/fashion_mnist/fashion_mnist_b.onnx" - strict = "./examples/fashion_mnist/fashion_mnist_strict.vnnlib" - epsilon = "./examples/fashion_mnist/fashion_mnist_epsilon.vnnlib" - argmax = "./examples/fashion_mnist/fashion_mnist_argmax.vnnlib" - case "iris": - net_a = "./examples/iris/iris_a.onnx" - net_b = "./examples/iris/iris_b.onnx" - strict = "./examples/iris/iris_strict.vnnlib" - epsilon = "./examples/iris/iris_epsilon.vnnlib" - argmax = "./examples/iris/iris_argmax.vnnlib" - case "acasxu": - net_a = "./examples/ACASXU/ACASXU_run2a_1_1_batch_2000.onnx" - net_b = "./examples/ACASXU/ACASXU_run2a_1_1_batch_2000.onnx" - strict = "./examples/ACASXU/ACASXU_strict.vnnlib" - epsilon = "./examples/ACASXU/ACASXU_epsilon.vnnlib" - argmax = "./examples/ACASXU/ACASXU_argmax.vnnlib" - case "tll": - net_a = "./examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_0.onnx" - net_b = "./examples/tll/tllBench_n=2_N=M=8_m=1_instance_0_2.onnx" - strict = "./examples/tll/tll_strict.vnnlib" - epsilon = "./examples/tll/tll_epsilon.vnnlib" - argmax = "./examples/tll/tll_argmax.vnnlib" - case "pendulum": - net_a = "./examples/pendulum/pendulum_finetune_con.onnx" - net_b = "./examples/pendulum/pendulum_finetune_con.onnx" - strict = "./examples/pendulum/pendulum_strict.vnnlib" - epsilon = "./examples/pendulum/pendulum_epsilon.vnnlib" - argmax = "./examples/pendulum/pendulum_argmax.vnnlib" - case "double_integrator": - net_a = "./examples/double_integrator/double_integrator_finetune_inv.onnx" - net_b = "./examples/double_integrator/double_integrator_finetune_inv.onnx" - strict = "./examples/double_integrator/double_integrator_strict.vnnlib" - epsilon = "./examples/double_integrator/double_integrator_epsilon.vnnlib" - argmax = "./examples/double_integrator/double_integrator_argmax.vnnlib" - case _: - print("Available Nets: 'xor', 'fashion_mnist', 'iris', 'acasxu', 'tll', 'pendulum', 'double_integrator'") - sys.exit() - - check_property(net_a, net_b, strict) - check_property(net_a, net_b, epsilon) - check_property(net_a, net_b, argmax) |
