diff options
Diffstat (limited to '')
| -rw-r--r-- | xor/xor.py (renamed from xor.py) | 15 |
1 files changed, 2 insertions, 13 deletions
@@ -1,7 +1,6 @@ import torch import torch.nn as nn import torch.onnx -import nneq class xor_mlp(nn.Module): def __init__(self, hidden_dim): @@ -37,15 +36,5 @@ if __name__ == "__main__": torch_net_a = train_model("Network A", 8).eval() torch_net_b = train_model("Network B", 16).eval() - onnx_net_a = torch.onnx.export(torch_net_a, (torch.randn(1, 2),), verbose=False).model_proto # type: ignore - onnx_net_b = torch.onnx.export(torch_net_b, (torch.randn(1, 2),), verbose=False).model_proto # type: ignore - - z3_net_a = nneq.net(onnx_net_a) - z3_net_b = nneq.net(onnx_net_b) - - print("") - nneq.strict_equivalence(z3_net_a, z3_net_b) - print("") - nneq.epsilon_equivalence(z3_net_a, z3_net_b, 0.1) - print("") - nneq.argmax_equivalence(z3_net_a, z3_net_b) + torch.onnx.export(torch_net_a, (torch.randn(1, 2),), "xor_a.onnx") + torch.onnx.export(torch_net_b, (torch.randn(1, 2),), "xor_b.onnx") |
