diff options
Diffstat (limited to 'xor.py')
| -rw-r--r-- | xor.py | 6 |
1 files changed, 3 insertions, 3 deletions
@@ -4,7 +4,7 @@ import torch.onnx import nneq class xor_mlp(nn.Module): - def __init__(self, hidden_dim=8): + def __init__(self, hidden_dim): super().__init__() self.layers = nn.Sequential( nn.Linear(2, hidden_dim), @@ -37,8 +37,8 @@ if __name__ == "__main__": torch_net_a = train_model("Network A", 8).eval() torch_net_b = train_model("Network B", 16).eval() - onnx_net_a = torch.onnx.export(torch_net_a, (torch.randn(1, 2),), verbose=False, dynamo=True).model_proto # type: ignore - onnx_net_b = torch.onnx.export(torch_net_b, (torch.randn(1, 2),), verbose=False, dynamo=True).model_proto # type: ignore + onnx_net_a = torch.onnx.export(torch_net_a, (torch.randn(1, 2),), verbose=False).model_proto # type: ignore + onnx_net_b = torch.onnx.export(torch_net_b, (torch.randn(1, 2),), verbose=False).model_proto # type: ignore z3_net_a = nneq.net(onnx_net_a) z3_net_b = nneq.net(onnx_net_b) |
