From 19652ec48be4c6faf3f7815a9281b611aed94727 Mon Sep 17 00:00:00 2001 From: ericmarin Date: Thu, 12 Mar 2026 15:37:53 +0100 Subject: changed to float --- xor.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'xor.py') diff --git a/xor.py b/xor.py index 905ddfb..493eaef 100644 --- a/xor.py +++ b/xor.py @@ -37,7 +37,7 @@ def get_rules() -> str: rules_lines.append(line) return "".join(rules_lines) -def export_to_inpla_wiring(model: nn.Module, input_shape: tuple, scale: int = 1000) -> str: +def export_to_inpla_wiring(model: nn.Module, input_shape: tuple) -> str: traced = fx.symbolic_trace(model) name_gen = NameGen() script: List[str] = [] @@ -46,7 +46,7 @@ def export_to_inpla_wiring(model: nn.Module, input_shape: tuple, scale: int = 10 for node in traced.graph.nodes: if node.op == 'placeholder': num_inputs = int(np.prod(input_shape)) - wire_map[node.name] = [f"Linear(Symbolic({i}), 1, 0)" for i in range(num_inputs)] + wire_map[node.name] = [f"Linear(Symbolic(X_{i}), 1.0, 0.0)" for i in range(num_inputs)] elif node.op == 'call_module': target_str = str(node.target) @@ -61,8 +61,8 @@ def export_to_inpla_wiring(model: nn.Module, input_shape: tuple, scale: int = 10 wire_map[node.name] = input_wires elif isinstance(module, nn.Linear): - W = (module.weight.data.detach().cpu().numpy() * scale).astype(int) - B = (module.bias.data.detach().cpu().numpy() * scale).astype(int) + W = (module.weight.data.detach().cpu().numpy()).astype(float) + B = (module.bias.data.detach().cpu().numpy()).astype(float) out_dim, in_dim = W.shape neuron_wires = [f"Concrete({B[j]})" for j in range(out_dim)] @@ -70,7 +70,7 @@ def export_to_inpla_wiring(model: nn.Module, input_shape: tuple, scale: int = 10 for i in range(in_dim): in_term = input_wires[i] if out_dim == 1: - weight = int(W[0, i]) + weight = float(W[0, i]) if weight == 0: script.append(f"Eraser ~ {in_term};") elif weight == 1: @@ -94,7 +94,7 @@ def export_to_inpla_wiring(model: nn.Module, input_shape: tuple, scale: int = 10 script.append(f"{nest_dups(branch_wires)} ~ {in_term};") for j in range(out_dim): - weight = int(W[j, i]) + weight = float(W[j, i]) if weight == 0: script.append(f"Eraser ~ {branch_wires[j]};") elif weight == 1: -- cgit v1.2.3