aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorericmarin <maarin.eric@gmail.com>2026-03-09 18:18:34 +0100
committerericmarin <maarin.eric@gmail.com>2026-03-09 20:00:47 +0100
commit51cd389b4e322313671dd0e53513ce84b72a1652 (patch)
tree55d6ec8e8ad89ab1b754525171fbcd0cd177430b
parent2152db1e181609a3c8e686ce647079c6a04c6740 (diff)
downloadvein-51cd389b4e322313671dd0e53513ce84b72a1652.tar.gz
vein-51cd389b4e322313671dd0e53513ce84b72a1652.zip
linear folding
-rw-r--r--nn.in95
-rw-r--r--nn.in~71
-rw-r--r--parser.py53
-rw-r--r--test.smt27
4 files changed, 112 insertions, 114 deletions
diff --git a/nn.in b/nn.in
index fb9a863..cfb5de9 100644
--- a/nn.in
+++ b/nn.in
@@ -1,79 +1,88 @@
// Rules
-Symbolic(id) >< Add(out, b) => b ~ AddCheck(out, (*L)Symbolic(id));
+Linear(x, int q, int r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r);
+
Concrete(int k) >< Add(out, b)
| k == 0 => out ~ b
| _ => b ~ AddCheckConcrete(out, k);
-other >< Add(out, b) => out ~ Add((*L)other, b);
-Symbolic(id) >< AddCheck(out, a) => out ~ Add(a, (*L)Symbolic(id));
-Concrete(int j) >< AddCheck(out, a)
- | j == 0 => out ~ a
- | _ => out ~ Add(a, (*L)Concrete(j));
-other >< AddCheck(out, a) => out ~ Add(a, (*L)other);
+Linear(y, int s, int t) >< AddCheckLinear(out, x, int q, int r)
+ | q == 0 && r == 0 && s == 0 && t == 0 => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
+ | s == 0 && t == 0 => Linear(x, q, r) ~ Materialize(out), y ~ Eraser
+ | q == 0 && r == 0 => (*L)Linear(y, s, t) ~ Materialize(out), x ~ Eraser
+ | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, r) ~ Materialize(out_y), out ~ TermAdd(out_x, out_y);
+
+Concrete(int j) >< AddCheckLinear(out, x, int q, int r) => out ~ Linear(x, q, r + j);
+
+Linear(y, int s, int t) >< AddCheckConcrete(out, int k) => out ~ Linear(y, s, t + k);
-Symbolic(id) >< AddCheckConcrete(out, int k) => out ~ Add(Concrete(k), (*L) Symbolic(id));
Concrete(int j) >< AddCheckConcrete(out, int k)
| j == 0 => out ~ Concrete(k)
| _ => out ~ Concrete(k + j);
-other >< AddCheckConcrete(out, int k) => out ~ Add(Concrete(k), (*L)other);
-Symbolic(id) >< Mul(out, b) => b ~ MulCheck(out, (*L)Symbolic(id));
+Linear(x, int q, int r) >< Mul(out, b) => b ~ MulCheck(out, x, q, r);
+
Concrete(int k) >< Mul(out, b)
| k == 0 => b ~ Eraser, out ~ (*L)Concrete(0)
| k == 1 => out ~ b
| _ => b ~ MulCheckConcrete(out, k);
-other >< Mul(out, b) => out ~ Mul((*L)other, b);
-Symbolic(id) >< MulCheck(out, a) => out ~ Mul(a, (*L)Symbolic(id));
-Concrete(int j) >< MulCheck(out, a)
- | j == 0 => a ~ Eraser, out ~ (*L)Concrete(0)
- | j == 1 => out ~ a
- | _ => out ~ Mul(a, (*L)Concrete(j));
-other >< MulCheck(out, a) => out ~ Mul(a, (*L)other);
+Linear(y, int s, int t) >< MulCheck(out, x, int q, int r)
+ | q == 0 && r == 0 && s == 0 && t == 0 => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
+ | s == 0 && t == 0 => Linear(x, q, r) ~ Materialize(out), y ~ Eraser
+ | q == 0 && r == 0 => (*L)Linear(y, s, t) ~ Materialize(out), x ~ Eraser
+ | _ => Linear(x, q, r) ~ Materialize(out_x), (*L)Linear(y, s, r) ~ Materialize(out_y), out ~ TermMul(out_x, out_y);
+
+Concrete(int j) >< MulCheck(out, x, int q, int r) => out ~ Linear(x, q * j, r * j);
+
+Linear(y, int s, int t) >< MulCheckConcrete(out, int k) => out ~ Linear(y, s * k, t * k);
-Symbolic(id) >< MulCheckConcrete(out, int k) => out ~ Mul(Concrete(k), (*L)Symbolic(id));
Concrete(int j) >< MulCheckConcrete(out, int k)
- | j == 0 => out ~ Concrete(0)
+ | j == 0 => out ~ Concrete(0)
| j == 1 => out ~ Concrete(k)
| _ => out ~ Concrete(k * j);
-other >< MulCheckConcrete(out, int k) => out ~ Mul(Concrete(k), (*L)other);
-Symbolic(id) >< ReLU(out) => out ~ ReLU((*L)Symbolic(id));
+Linear(x, int q, int r) >< ReLU(out) => (*L)Linear(x, q, r) ~ Materialize(out_x), out ~ TermReLU(out_x);
Concrete(int k) >< ReLU(out)
| k > 0 => out ~ (*L)Concrete(k)
| _ => out ~ Concrete(0);
-other >< ReLU(out) => out ~ ReLU((*L)other);
+
+Linear(x, int q, int r) >< Materialize(out)
+ | ? => out ~ Concrete(r), x ~ Eraser
+ | ? => out ~ Symbolic(x)
+ | ? => out ~ TermAdd(x, Concrete(r))
+ | ? => out ~ TermMul(Concrete(q), x)
+ | _ => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r));
// Net
-Symbolic(X) ~ Add(out, b);
-Concrete(0) ~ Mul(b, Symbolic(Y));
-out;
+Linear(X, 1, 0) ~ Add(a, b);
+Concrete(0) ~ Mul(b, Linear(Y, 1, 0));
+a ~ Materialize(out);
+out; // Symbolic(X)
+free out a b;
-free out;
-Symbolic(X) ~ Mul(out, Concrete(0));
-out;
+Linear(X, 1, 0) ~ Mul(a, Concrete(0));
+a ~ Materialize(out);
+out; // Concrete(0)
+free out a;
-free out;
-Symbolic(X) ~ Mul(out, b);
-Symbolic(Y) ~ Add(b, Concrete(1));
-out;
+Linear(X, 1, 0) ~ Mul(a, b);
+Linear(Y, 1, 0) ~ Add(b, Concrete(1));
+out; // Mul(Symbolic(X),Add(Symbolic(Y),Concrete(1)))
+free out a b;
-free out;
Concrete(1) ~ Add(out, b);
Concrete(2) ~ Add(b, c);
Concrete(3) ~ Mul(c, Concrete(2));
-out;
+out; // Concrete(9)
+free out a b c;
-free out;
-Symbolic(X) ~ Mul(a, Symbolic(W1));
-Symbolic(Y) ~ Mul(b, Symbolic(W2));
-b ~ Add(c, Symbolic(B));
+Linear(X, 1, 0) ~ Mul(a, Linear(W1, 1, 0));
+Linear(Y, 1, 0) ~ Mul(b, Linear(W2, 1, 0));
+b ~ Add(c, Linear(B, 1, 0));
a ~ Add(d, c);
d ~ ReLU(out);
-out;
+out; // ReLU(Add(Mul(Symbolic(X),Symbolic(W1)),Add(Mul(Symbolic(Y),Symbolic(W2)),Symbolic(B))))
+free out a b c d;
-free out;
-Symbolic(X) ~ Dup(a, b);
-a ~ Add(c, Concrete(1));
-b ~ Mul(out, c);
+Linear(X, 2, 10) ~ Materialize(out);
out;
diff --git a/nn.in~ b/nn.in~
deleted file mode 100644
index dd89d03..0000000
--- a/nn.in~
+++ /dev/null
@@ -1,71 +0,0 @@
-// Agents
-
-Any >< Optimize(out) =>
- out ~ Any;
-
-// Rules
-
-// Optimizer interacts with Const
-Const(k) >< Optimize(out) =>
- out ~ Const(k);
-
-
-// Optimizer interacts with Add
-Add(a, b) >< Optimize(out) =>
- a ~ Optimize(a_opt),
- a_opt ~ Add_CheckLeft(out, b);
-
-// CHECK LEFT
-Const(int k) >< Add_CheckLeft(out, b)
- | k == 0 => b ~ Optimize(b_opt), out ~ b_opt
- | _ => b ~ Optimize(b_opt), b_opt ~ Add_CheckRight_WithConst(out, k);
-a_opt >< Add_CheckLeft(out, b) =>
- b ~ Optimize(b_opt),
- b_opt ~ Add_CheckRight_WithAny(out, (*L)a_opt);
-
-// CHECK RIGHT
-Const(int j) >< Add_CheckRight_WithConst(out, int k) =>
- out ~ Const(k + j);
-b_opt >< Add_CheckRight_WithConst(out, int k) =>
- out ~ Add(Const(k), (*L)b_opt);
-Const(int j) >< Add_CheckRight_WithAny(out, a_opt)
- | j == 0 => out ~ a_opt
- | _ => out ~ Add(a_opt, Const(j));
-b_opt >< Add_CheckRight_WithAny(out, a_opt) =>
- out ~ Add(a_opt, (*L)b_opt);
-
-
-// Optimizer interacts with Mul
-Mul(a, b) >< Optimize(out) =>
- a ~ Optimize(a_opt),
- a_opt ~ Mul_CheckLeft(out, b);
-
-// Check LEFT
-Const(int k) >< Mul_CheckLeft(out, b)
- | k == 0 => b ~ Eraser, out ~ Const(0)
- | k == 1 => b ~ Optimize(b_opt), out ~ b_opt
- | _ => b ~ Optimize(b_opt), b_opt ~ Mul_CheckRight_WithConst(out, k);
-a_opt >< Mul_CheckLeft(out, b) =>
- b ~ Optimize(b_opt),
- b_opt ~ Mul_CheckRight_WithAny(out, (*L)a_opt);
-
-// Check RIGHT
-Const(int j) >< Mul_CheckRight_WithConst(out, int k) =>
- out ~ Const(k * j);
-b_opt >< Mul_CheckRight_WithConst(out, int k) =>
- out ~ Mul(Const(k), (*L)b_opt);
-Const(int j) >< Mul_CheckRight_WithAny(out, a_opt)
- | j == 0 => Eraser ~ a_opt, out ~ Const(0)
- | j == 1 => out ~ a_opt
- | _ => out ~ Mul(a_opt, Const(j));
-b_opt >< Mul_CheckRight_WithAny(out, a_opt) =>
- out ~ Mul(a_opt, (*L)b_opt);
-
-
-// Net
-Mul(Const(1), Add(Any, Mul(Const(0), Add(Any, Any)))) ~ Optimize(out);
-out;
-
-free out;
-Mul(Const(0), Add(Any, Any)) ~ Optimize(out);
-out;
diff --git a/parser.py b/parser.py
new file mode 100644
index 0000000..e66f9d2
--- /dev/null
+++ b/parser.py
@@ -0,0 +1,53 @@
+
+class Concrete:
+ def __init__(self, val): self.val = val
+ def __str__(self): return str(self.val)
+
+class Symbolic:
+ def __init__(self, id): self.id = id
+ def __str__(self): return f"x_{self.id}"
+
+class TermAdd:
+ def __init__(self, a, b): self.a, self.b = a, b
+ def __str__(self): return f"(+ {self.a} {self.b})"
+
+class TermMul:
+ def __init__(self, a, b): self.a, self.b = a, b
+ def __str__(self): return f"(* {self.a} {self.b})"
+
+class TermReLU:
+ def __init__(self, a): self.a = a
+ def __str__(self): return f"(ite (> {self.a} 0) {self.a} 0)"
+
+def generate_z3(net_a_str, net_b_str, epsilon=1e-5):
+ context = {
+ 'Concrete': Concrete,
+ 'Symbolic': Symbolic,
+ 'TermAdd': TermAdd,
+ 'TermMul': TermMul,
+ 'TermReLU': TermReLU
+ }
+
+ try:
+ tree_a = eval(net_a_str, context)
+ tree_b = eval(net_b_str, context)
+ except Exception as e:
+ print(f"; Error parsing Inpla output: {e}")
+ return
+
+ print("(declare-const x_0 Real)")
+ print("(declare-const x_1 Real)")
+
+ print(f"(define-fun net_a () Real {tree_a})")
+ print(f"(define-fun net_b () Real {tree_b})")
+
+ print(f"(assert (> (abs (- net_a net_b)) {epsilon:.10f}))")
+
+ print("(check-sat)")
+ print("(get-model)")
+
+if __name__ == "__main__":
+ output_net_a = "TermAdd(TermMul(Symbolic(0), Concrete(2)), Concrete(3))" # 2x + 3
+ output_net_b = "TermAdd(Concrete(3), TermMul(Concrete(2), Symbolic(0)))" # 3 + 2x
+
+ generate_z3(output_net_a, output_net_b)
diff --git a/test.smt2 b/test.smt2
new file mode 100644
index 0000000..79d05c6
--- /dev/null
+++ b/test.smt2
@@ -0,0 +1,7 @@
+(declare-const x_0 Real)
+(declare-const x_1 Real)
+(define-fun net_a () Real (+ (* x_0 2) 3))
+(define-fun net_b () Real (+ 3 (* 2 x_0)))
+(assert (> (abs (- net_a net_b)) 0.0000100000))
+(check-sat)
+(get-model)