aboutsummaryrefslogtreecommitdiff
path: root/examples/mnist
diff options
context:
space:
mode:
authorericmarin <maarin.eric@gmail.com>2026-04-10 15:06:56 +0200
committerericmarin <maarin.eric@gmail.com>2026-04-13 10:55:06 +0200
commit8f4f24523235965cfa2041ed00cc40fc0b4bd367 (patch)
tree716862b4c898431861c27ab69165edfb245467dc /examples/mnist
parent9fb816496d392638fa6981e71800466d71434680 (diff)
downloadvein-8f4f24523235965cfa2041ed00cc40fc0b4bd367.tar.gz
vein-8f4f24523235965cfa2041ed00cc40fc0b4bd367.zip
added MNIST, changed cache and parser
Diffstat (limited to '')
-rw-r--r--examples/mnist/mnist.py (renamed from examples/fashion_mnist/fashion_mnist.py)22
-rw-r--r--examples/mnist/mnist_argmax.vnnlib (renamed from examples/fashion_mnist/fashion_mnist_argmax.vnnlib)2
-rw-r--r--examples/mnist/mnist_epsilon.vnnlib (renamed from examples/fashion_mnist/fashion_mnist_epsilon.vnnlib)2
-rw-r--r--examples/mnist/mnist_strict.vnnlib (renamed from examples/fashion_mnist/fashion_mnist_strict.vnnlib)2
4 files changed, 14 insertions, 14 deletions
diff --git a/examples/fashion_mnist/fashion_mnist.py b/examples/mnist/mnist.py
index 680f4eb..0a81878 100644
--- a/examples/fashion_mnist/fashion_mnist.py
+++ b/examples/mnist/mnist.py
@@ -1,9 +1,9 @@
import torch, torch.nn as nn
-from torchvision.datasets import FashionMNIST
+from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
-class FashionMNIST_MLP(nn.Module):
+class MNIST_MLP(nn.Module):
def __init__(self, hidden_dim):
super().__init__()
self.layers = nn.Sequential(
@@ -15,16 +15,16 @@ class FashionMNIST_MLP(nn.Module):
def forward(self, x):
return self.layers(x)
-train_dataset = FashionMNIST('./', download=True, transform=transforms.ToTensor(), train=True)
+train_dataset = MNIST('./', download=True, transform=transforms.ToTensor(), train=True)
trainloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
def train_model(name: str, dim):
- net = FashionMNIST_MLP(hidden_dim=dim)
+ net = MNIST_MLP(hidden_dim=dim)
loss_fn = nn.CrossEntropyLoss()
- optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
+ optimizer = torch.optim.Adam(net.parameters(), lr=0.5e-4)
print(f"Training {name} ({dim} neurons)...")
- for epoch in range(10):
+ for epoch in range(100):
global loss
for data in trainloader:
inputs, targets = data
@@ -33,13 +33,13 @@ def train_model(name: str, dim):
loss = loss_fn(outputs, targets)
loss.backward()
optimizer.step()
- if (epoch + 1) % 1 == 0:
+ if (epoch + 1) % 10 == 0:
print(f" Epoch {epoch+1}, Loss: {loss.item():.4f}")
return net
if __name__ == "__main__":
- torch_net_a = train_model("Network A", 28).eval()
- torch_net_b = train_model("Network B", 56).eval()
+ torch_net_a = train_model("Network A", 6).eval()
+ torch_net_b = train_model("Network B", 12).eval()
- torch.onnx.export(torch_net_a, (torch.randn(1, 28, 28),), "fashion_mnist_a.onnx")
- torch.onnx.export(torch_net_b, (torch.randn(1, 28, 28),), "fashion_mnist_b.onnx")
+ torch.onnx.export(torch_net_a, (torch.randn(1, 28, 28),), "mnist_a.onnx")
+ torch.onnx.export(torch_net_b, (torch.randn(1, 28, 28),), "mnist_b.onnx")
diff --git a/examples/fashion_mnist/fashion_mnist_argmax.vnnlib b/examples/mnist/mnist_argmax.vnnlib
index 8d06485..4c7f0c9 100644
--- a/examples/fashion_mnist/fashion_mnist_argmax.vnnlib
+++ b/examples/mnist/mnist_argmax.vnnlib
@@ -1,4 +1,4 @@
-; Argmax Equivalence for reduced FashionMNIST
+; Argmax Equivalence for MNIST
; Constant declaration
(declare-const X_0 Real)
diff --git a/examples/fashion_mnist/fashion_mnist_epsilon.vnnlib b/examples/mnist/mnist_epsilon.vnnlib
index 9637489..ea76779 100644
--- a/examples/fashion_mnist/fashion_mnist_epsilon.vnnlib
+++ b/examples/mnist/mnist_epsilon.vnnlib
@@ -1,4 +1,4 @@
-; Strict Equivalence for reduced FashionMNIST
+; Epsilon Equivalence for MNIST
; Constant declaration
(declare-const X_0 Real)
diff --git a/examples/fashion_mnist/fashion_mnist_strict.vnnlib b/examples/mnist/mnist_strict.vnnlib
index c8b3f8e..356f176 100644
--- a/examples/fashion_mnist/fashion_mnist_strict.vnnlib
+++ b/examples/mnist/mnist_strict.vnnlib
@@ -1,4 +1,4 @@
-; Strict Equivalence for reduced FashionMNIST
+; Strict Equivalence for MNIST
; Constant declaration
(declare-const X_0 Real)