aboutsummaryrefslogtreecommitdiff
path: root/examples/iris/iris.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/iris/iris.py')
-rw-r--r--examples/iris/iris.py22
1 files changed, 4 insertions, 18 deletions
diff --git a/examples/iris/iris.py b/examples/iris/iris.py
index db631c0..84d1ac6 100644
--- a/examples/iris/iris.py
+++ b/examples/iris/iris.py
@@ -1,5 +1,4 @@
-import torch
-import torch.nn as nn
+import torch, torch.nn as nn
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset
@@ -29,7 +28,7 @@ def train_model(name: str, dim):
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)
print(f"Training {name} ({dim} neurons)...")
- for epoch in range(200):
+ for epoch in range(100):
global loss
for data in trainloader:
inputs, targets = data
@@ -38,26 +37,13 @@ def train_model(name: str, dim):
loss = loss_fn(outputs, targets)
loss.backward()
optimizer.step()
- if (epoch + 1) % 100 == 0:
+ if (epoch + 1) % 10 == 0:
print(f" Epoch {epoch+1}, Loss: {loss.item():.4f}")
return net
if __name__ == "__main__":
torch_net_a = train_model("Network A", 10).eval()
- torch_net_b = Iris_MLP(hidden_dim=20).eval()
-
- with torch.no_grad():
- torch_net_b.layers[0].weight[:10].copy_(torch_net_a.layers[0].weight) # pyright: ignore
- torch_net_b.layers[0].bias[:10].copy_(torch_net_a.layers[0].bias) # pyright: ignore
- torch_net_b.layers[0].weight[10:].copy_(torch_net_a.layers[0].weight) # pyright: ignore
- torch_net_b.layers[0].bias[10:].copy_(torch_net_a.layers[0].bias) # pyright: ignore
-
- half_weights = torch_net_a.layers[2].weight / 2.0 # pyright: ignore
-
- torch_net_b.layers[2].weight[:, :10].copy_(half_weights) # pyright: ignore
- torch_net_b.layers[2].weight[:, 10:].copy_(half_weights) # pyright: ignore
-
- torch_net_b.layers[2].bias.copy_(torch_net_a.layers[2].bias) # pyright: ignore
+ torch_net_b = train_model("Network B", 20).eval()
torch.onnx.export(torch_net_a, (torch.randn(1, 4),), "iris_a.onnx")
torch.onnx.export(torch_net_b, (torch.randn(1, 4),), "iris_b.onnx")