From 0fca69965786db7deee2e976551b5156531e8ed5 Mon Sep 17 00:00:00 2001 From: ericmarin Date: Tue, 17 Mar 2026 18:59:35 +0100 Subject: added proof for ONNX translation --- xor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'xor.py') diff --git a/xor.py b/xor.py index 0f8390d..82a16b8 100644 --- a/xor.py +++ b/xor.py @@ -4,7 +4,7 @@ import torch.onnx import nneq class xor_mlp(nn.Module): - def __init__(self, hidden_dim=8): + def __init__(self, hidden_dim): super().__init__() self.layers = nn.Sequential( nn.Linear(2, hidden_dim), @@ -37,8 +37,8 @@ if __name__ == "__main__": torch_net_a = train_model("Network A", 8).eval() torch_net_b = train_model("Network B", 16).eval() - onnx_net_a = torch.onnx.export(torch_net_a, (torch.randn(1, 2),), verbose=False, dynamo=True).model_proto # type: ignore - onnx_net_b = torch.onnx.export(torch_net_b, (torch.randn(1, 2),), verbose=False, dynamo=True).model_proto # type: ignore + onnx_net_a = torch.onnx.export(torch_net_a, (torch.randn(1, 2),), verbose=False).model_proto # type: ignore + onnx_net_b = torch.onnx.export(torch_net_b, (torch.randn(1, 2),), verbose=False).model_proto # type: ignore z3_net_a = nneq.net(onnx_net_a) z3_net_b = nneq.net(onnx_net_b) -- cgit v1.2.3