aboutsummaryrefslogtreecommitdiff
path: root/xor.py
diff options
context:
space:
mode:
authorericmarin <maarin.eric@gmail.com>2026-03-12 15:37:53 +0100
committerericmarin <maarin.eric@gmail.com>2026-03-12 15:56:24 +0100
commit19652ec48be4c6faf3f7815a9281b611aed94727 (patch)
tree94ab6abc2c50835d2d9a51242025f8ad13b7f7c5 /xor.py
parentfb544e2089e0c52bd83ffe56f2f4e8d7176564ee (diff)
downloadvein-19652ec48be4c6faf3f7815a9281b611aed94727.tar.gz
vein-19652ec48be4c6faf3f7815a9281b611aed94727.zip
changed to float
Diffstat (limited to 'xor.py')
-rw-r--r--xor.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/xor.py b/xor.py
index 905ddfb..493eaef 100644
--- a/xor.py
+++ b/xor.py
@@ -37,7 +37,7 @@ def get_rules() -> str:
rules_lines.append(line)
return "".join(rules_lines)
-def export_to_inpla_wiring(model: nn.Module, input_shape: tuple, scale: int = 1000) -> str:
+def export_to_inpla_wiring(model: nn.Module, input_shape: tuple) -> str:
traced = fx.symbolic_trace(model)
name_gen = NameGen()
script: List[str] = []
@@ -46,7 +46,7 @@ def export_to_inpla_wiring(model: nn.Module, input_shape: tuple, scale: int = 10
for node in traced.graph.nodes:
if node.op == 'placeholder':
num_inputs = int(np.prod(input_shape))
- wire_map[node.name] = [f"Linear(Symbolic({i}), 1, 0)" for i in range(num_inputs)]
+ wire_map[node.name] = [f"Linear(Symbolic(X_{i}), 1.0, 0.0)" for i in range(num_inputs)]
elif node.op == 'call_module':
target_str = str(node.target)
@@ -61,8 +61,8 @@ def export_to_inpla_wiring(model: nn.Module, input_shape: tuple, scale: int = 10
wire_map[node.name] = input_wires
elif isinstance(module, nn.Linear):
- W = (module.weight.data.detach().cpu().numpy() * scale).astype(int)
- B = (module.bias.data.detach().cpu().numpy() * scale).astype(int)
+ W = (module.weight.data.detach().cpu().numpy()).astype(float)
+ B = (module.bias.data.detach().cpu().numpy()).astype(float)
out_dim, in_dim = W.shape
neuron_wires = [f"Concrete({B[j]})" for j in range(out_dim)]
@@ -70,7 +70,7 @@ def export_to_inpla_wiring(model: nn.Module, input_shape: tuple, scale: int = 10
for i in range(in_dim):
in_term = input_wires[i]
if out_dim == 1:
- weight = int(W[0, i])
+ weight = float(W[0, i])
if weight == 0:
script.append(f"Eraser ~ {in_term};")
elif weight == 1:
@@ -94,7 +94,7 @@ def export_to_inpla_wiring(model: nn.Module, input_shape: tuple, scale: int = 10
script.append(f"{nest_dups(branch_wires)} ~ {in_term};")
for j in range(out_dim):
- weight = int(W[j, i])
+ weight = float(W[j, i])
if weight == 0:
script.append(f"Eraser ~ {branch_wires[j]};")
elif weight == 1: