aboutsummaryrefslogtreecommitdiff
path: root/xor
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--xor.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/xor.py b/xor.py
index 0f8390d..82a16b8 100644
--- a/xor.py
+++ b/xor.py
@@ -4,7 +4,7 @@ import torch.onnx
import nneq
class xor_mlp(nn.Module):
- def __init__(self, hidden_dim=8):
+ def __init__(self, hidden_dim):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(2, hidden_dim),
@@ -37,8 +37,8 @@ if __name__ == "__main__":
torch_net_a = train_model("Network A", 8).eval()
torch_net_b = train_model("Network B", 16).eval()
- onnx_net_a = torch.onnx.export(torch_net_a, (torch.randn(1, 2),), verbose=False, dynamo=True).model_proto # type: ignore
- onnx_net_b = torch.onnx.export(torch_net_b, (torch.randn(1, 2),), verbose=False, dynamo=True).model_proto # type: ignore
+ onnx_net_a = torch.onnx.export(torch_net_a, (torch.randn(1, 2),), verbose=False).model_proto # type: ignore
+ onnx_net_b = torch.onnx.export(torch_net_b, (torch.randn(1, 2),), verbose=False).model_proto # type: ignore
z3_net_a = nneq.net(onnx_net_a)
z3_net_b = nneq.net(onnx_net_b)