aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nneq/nneq.py1
-rw-r--r--notes.norg6
-rw-r--r--proof.norg784
-rw-r--r--xor.py6
4 files changed, 442 insertions, 355 deletions
diff --git a/nneq/nneq.py b/nneq/nneq.py
index d9d7d30..4f46cbf 100644
--- a/nneq/nneq.py
+++ b/nneq/nneq.py
@@ -135,7 +135,6 @@ def inpla_export(model: onnx.ModelProto) -> inpla_str:
return "\n".join(input_script + list(reversed(node_script)) + result_lines)
def inpla_run(model: inpla_str) -> z3_str:
- print(model)
return subprocess.run(["./inpla"], input=f"{rules}\n{model}", capture_output=True, text=True).stdout
syms = {}
diff --git a/notes.norg b/notes.norg
index 4427354..4471531 100644
--- a/notes.norg
+++ b/notes.norg
@@ -4,13 +4,13 @@ description: WIP tool to prove NNEQ using Interaction Nets as pre-processor fo m
authors: ericmarin
categories: research
created: 2026-03-14T09:21:24
-updated: 2026-03-17T11:18:11
+updated: 2026-03-18T13:18:15
version: 1.1.1
@end
* TODO
- (?) Scalability %Maybe done? I have increased the limits of Inpla, but I have yet to test%
- - (x) Soundness of translated NN
+ - (x) Soundness of translated NN: {:proof.norg:}[PROOF]
- ( ) Compatibility with other types of NN
- ( ) Comparison with other tool ({https://github.com/NeuralNetworkVerification/Marabou}[Marabou], {https://github.com/guykatzz/ReluplexCav2017}[Reluplex])
- ( ) Add Range agent to enable ReLU optimization
@@ -26,7 +26,7 @@ version: 1.1.1
- Symbolic(id): represent the variable id
- Add(out, b): represent the addition (has various steps AddCheckLinear/AddCheckConcrete)
- Mul(out, b): represent the multiplication (has various steps MulCheckLinear/MulCheckConcrete)
- - ReLU(out): represent "if x > 0 ? x ; 0"
+ - ReLU(out): represent "IF x > 0 THEN x ELSE 0"
- Materialize(out): transforms a Linear packet into a final representation of TermAdd/TermMul
* Rules
diff --git a/proof.norg b/proof.norg
index 4ec18ce..1379f78 100644
--- a/proof.norg
+++ b/proof.norg
@@ -4,364 +4,452 @@ description:
authors: ericmarin
categories:
created: 2026-03-16T11:34:52
-updated: 2026-03-16T18:31:41
+updated: 2026-03-17T18:59:09
version: 1.1.1
@end
+* Mathematical Definitions
+ - Linear(x, q, r) ~ out => out = q*x + r %with q,r Real%
+ - Concrete(k) ~ out => out = k %with k Real%
+ - Add(out, b) ~ a => out = a + b
+ - AddCheckLinear(out, x, q, r) ~ b => out = q*x + (r + b) %with q,r Real%
+ - AddCheckConcrete(out, k) ~ b => out = k + b %with k Real%
+ - Mul(out, b) ~ a => out = a * b
+ - MulCheckLinear(out, x, q, r) ~ b => out = q*b*x + r*b %with q,r Real%
+ - MulCheckConcrete(out, k) ~ b => out = k*b %with k Real%
+ - ReLU(out) ~ x => out = IF (x > 0) THEN x ELSE 0
+ - Materialize(out) ~ x => out = x
+
* Proof for translation from Pytorch representation to Interaction Net graph
+** ReLU
+ ONNX ReLU node is defined as:
+ /Y = X if X > 0 else 0/
+
+ The translation defines the interactions:
+ /x_i ~ ReLU(y_i)/
+
+ By definition this interaction is equal to:
+ /y_i = IF (x_i >0) THEN x_i ELSE 0/
+
+** Gemm
+ ONNX Gemm node is defined as:
+ Y = alpha * A * B + beta * C
+
+ The translation defines the interactions:
+ /a_i ~ Mul(v_i, Concrete(alpha * b_i))/
+ /Add(...(Add(y_i, v_1), ...), v_n) ~ Concrete(beta * c_i)/
+
+ By definition this interaction is equal to:
+ /v_i = alpha * a_i * b_i/
+ /y_i = v_1 + v_2 + ... + v_n + beta * c_i/
+
+ By grouping the operations we get:
+ /Y = alpha * A * B + beta * C/
* Proof for the Interaction Rules
-** Mathematical Definitions
- - Linear(x, q, r) = q*x + r %with q,r Real%
- - Concrete(k) = k %with k Real%
- - Add(a, b) = a + b
- - AddCheckLinear(x, q, r, b) = q*x + (r + b) %with q,r Real%
- - AddCheckConcrete(k, b) = k + b %with k Real%
- - Mul(a, b) = a * b
- - MulCheckLinear(x, q, r, b) = q*b*x + r*b %with q,r Real%
- - MulCheckConcrete(k, b) = k*b %with k Real%
- - ReLU(x) = IF (x > 0) THEN x ELSE 0
- - Materialize(x) = x
-
-** Rules
-*** Formatting
- Agent1 >< Agent2 => Wiring
-
- LEFT SIDE MATHEMATICAL INTERPRETATION
-
- RIGHT SIDE MATHEMATICAL INTERPRETATION
-
- SHOWING EQUIVALENCE
-
-*** Materialize
- The Materialize agent transforms a Linear agent into a tree of explicit mathematical operations
- that are used as final representation for the solver.
- In the Python module the terms are defined as:
- @code python
+** Materialize
+ The Materialize agent transforms a Linear agent into a tree of explicit mathematical operations
+ that are used as final representation for the solver.
+ In the Python module the terms are defined as:
+ @code python
def TermAdd(a, b):
return a + b
def TermMul(a, b):
return a * b
def TermReLU(x):
return z3.If(x > 0, x, 0)
- @end
-**** Linear(x, q, r) >< Materialize(out) => (1), (2), (3), (4), (5)
-
- Linear(x, q, r) = term
- Materialize(term) = out
- out = q*x + r
-
- $$ Case 1: q = 0 => out ~ Concrete(r), x ~ Eraser
- Concrete(r) = out
- out = r
-
- 0*x + r = r => r = r
- $$
-
- $$ Case 2: q = 1, r = 0 => out ~ x
- x = out
- out = x
-
- 1*x + 0 = x => x = x
- $$
-
- $$ Case 3: q = 1 => out ~ TermAdd(x, Concrete(r))
- TermAdd(x, Concrete(r)) = out
- out = x + r
-
- 1*x + r = x + r => x + r = x + r
- $$
-
- $$ Case 4: r = 0 => out ~ TermMul(Concrete(q), x)
- TermMul(Concrete(q), x) = out
- out = q*x
-
- q*x + 0 = q*x => q*x = q*x
- $$
-
- $$ Case 5: otherwise => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r))
- TermAdd(TermMul(Concrete(q), x), r) = out
- out = q*x + r
-
- q*x + r = q*x + r
- $$
-
-**** Concrete(k) >< Materialize(out) => out ~ Concrete(k)
-
- Concrete(k) = term
- Materialize(term) = out
- out = k
-
- Concrete(k) = out
- out = k
-
- k = k
-
-*** Add
-**** Linear(x, q, r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r)
-
- Linear(x, q, r) = a
- Add(a, b) = out
- out = q*x + r + b
-
- AddCheckLinear(x, q, r, b) = out
- out = q*x + (r + b)
-
- q*x + r + b = q*x + (r + b) => q*x + (r + b) = q*x + (r + b)
-
-**** Concrete(k) >< Add(out, b) => (1), (2)
-
- Concrete(k) = a
- Add(a, b) = out
- out = k + b
-
- $$ Case 1: k = 0 => out ~ b
- b = out
- out = b
-
- 0 + b = b => b = b
- $$
-
- $$ Case 2: otherwise => b ~ AddCheckConcrete(out, k)
- AddCheckConcrete(k, b) = out
- out = k + b
-
- k + b = k + b
- $$
-
-**** Linear(y, s, t) >< AddCheckLinear(out, x, q, r) => (1), (2), (3), (4)
-
- Linear(y, s, t) = b
- AddCheckLinear(x, q, r, b) = out
- out = q*x + (r + s*y + t)
-
- $$ Case 1: q,r,s,t = 0 => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
- Concrete(0) = out
- out = 0
-
- 0*x + (0 + 0*y + 0) = 0 => 0 = 0
- $$
-
- $$ Case 2: s,t = 0 => out ~ Linear(x, q, r), y ~ Eraser
- Linear(x, q, r) = out
- out = q*x + r
-
- q*x + (r + 0*y + 0) = q*x + r => q*x + r = q*x + r
- $$
-
- $$ Case 3: q, r = 0 => out ~ Linear(y, s, t), x ~ Eraser
- Linear(y, s, t) = out
- out = s*y + t
-
- 0*x + (0 + s*y + t) = s*y + t => s*y + t = s*y + t
- $$
-
- $$ Case 4: otherwise => Linear(x, q, r) ~ Materialize(out_x), Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermAdd(out_x, out_y), 1, 0)
- Materialize(Linear(x, q, r)) = out_x
- Materialize(Linear(y, s, t)) = out_y
- Linear(TermAdd(out_x, out_y), 1, 0) = out
- out_x = q*x + r
- out_y = s*y + t
- out = 1*TermAdd(q*x + r, s*y + t) + 0
- Because TermAdd(a, b) is defined as "a+b":
- out = 1*(q*x + r + s*y + t) + 0
-
- q*x + (r + s*y + t) = 1*(q*x + r + s*y + t) + 0 => q*x + r + s*y + t = q*x + r + s*y + t
- $$
-
-**** Concrete(j) >< AddCheckLinear(out, x, q, r) => out ~ Linear(x, q, r + j)
-
- Concrete(j) = b
- AddCheckLinear(x, q, r, b) = out
- out = q*x + (r + j)
-
- Linear(x, q, r + j) = out
- out = q*x + (r + j)
-
- q*x + (r + j) = q*x + (r + j)
-
-**** Linear(y, s, t) >< AddCheckConcrete(out, k) => out ~ Linear(y, s, t + k)
-
- Linear(y, s, t) = b
- AddCheckConcrete(k, b) = out
- out = k + s*y + t
-
- Linear(y, s, t + k)
- out = s*y + (t + k)
-
- k + s*y + t = s*y + (t + k) => s*y + (t + k) = s*y + (t + k)
-
-**** Concrete(j) >< AddCheckConcrete(out, k) => (1), (2)
-
- Concrete(j) = b
- AddCheckConcrete(k, b) = out
- out = k + j
-
- $$ Case 1: j = 0 => out ~ Concrete(k)
- Concrete(k) = out
- out = k
-
- k + 0 = k => k = k
- $$
-
- $$ Case 2: otherwise => out ~ Concrete(k + j)
- Concrete(k + j) = out
- out = k + j
-
- k + j = k + j
- $$
-
-*** Mul
-**** Linear(x, q, r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r)
-
- Linear(x, q, r) = a
- Mul(a, b) = out
- out = (q*x + r) * b
-
- MulCheckLinear(x, q, r, b) = out
- out = q*b*x + r*b
-
- (q*x + r) * b = q*b*x + r*b => q*b*x + r*b = q*b*x + r*b
-
-**** Concrete(k) >< Mul(out, b) => (1), (2), (3)
-
- Concrete(k) = a
- Mul(a, b) = out
- out = k * b
-
- $$ Case 1: k = 0 => out ~ Concrete(0), b ~ Eraser
- Concrete(0) = out
- out = 0
-
- 0 * b = 0 => 0 = 0
- $$
-
- $$ Case 2: k = 1 => out ~ b
- b = out
- out = b
-
- 1 * b = b => b = b
- $$
-
- $$ Case 3: otherwise => b ~ MulCheckConcrete(out, k)
- MulCheckConcrete(k, b) = out
- out = k * b
-
- k * b = k * b
- $$
-
-**** Linear(y, s, t) >< MulCheckLinear(out, x, q, r) => (1), (2)
-
- Linear(y, s, t) = b
- MulCheckLinear(x, q, r, b) = out
- out = q\*(s*y + t)\*x + r*(s*y + t)
-
- $$ Case 1: (q,r = 0) or (s,t = 0) => x ~ Eraser, y ~ Eraser, out ~ Concrete(0)
- Concrete(0) = out
- out = 0
-
- 0\*(s*y + t)\*x + 0*(s*y + t) = 0 => 0 = 0
- or
- q\*(0*y + 0)\*x + r*(0*y + 0) = 0 => 0 = 0
- $$
-
- $$ Case 2: otherwise => Linear(x, q, r) ~ Materialize(out_x), Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermMul(out_x, out_y), 1, 0)
- Materialize(Linear(x, q, r)) = out_x
- Materialize(Linear(y, s, t)) = out_y
- Linear(TermMul(out_x, out_y), 1, 0) = out
- out_x = q*x + r
- out_y = s*y + t
- out = 1*TermMul(q*x + r, s*y + t) + 0
- Because TermMul(a, b) is defined as "a*b":
- out = 1*(q*x + r)*(s*y + t) + 0
-
- q*(s*y + t)\*x + r*(s*y + t) = 1*(q*x + r)\*(s*y + t) =>
- q\*(s*y + t)\*x + r*(s*y + t) = (q*x + r)\*(s*y + t) =>
- q\*(s*y + t)\*x + r*(s*y + t) = q\*(s*y + t)\*x + r*(s*y + t)
- $$
-
-**** Concrete(j) >< MulCheckLinear(out, x, q, r) => out ~ Linear(x, q * j, r * j)
-
- Concrete(j) = b
- MulCheckLinear(x, q, r, b) = out
- out = q*j*x + r*j
-
- Linear(x, q * j, r * j) = out
- out = q*j*x + r*j
-
- q*j*x + r*j = q*j*x + r*j
-
-**** Linear(y, s, t) >< MulCheckConcrete(out, k) => out ~ Linear(y, s * k, t * k)
-
- Linear(y, s, t) = b
- MulCheckConcrete(k, b) = out
- out = k * (s*y + t)
-
- Linear(y, s * k, t * k) = out
- out = s*k*y + t*k
-
- k * (s*y + t) = s*k*y + t*k => s*k*y + t*k = s*k*y + t*k
-
-**** Concrete(j) >< MulCheckConcrete(out, k) => (1), (2), (3)
-
- Concrete(j) = b
- MulCheckConcrete(k, b) = out
- out = k * j
-
- $$ Case 1: j = 0 => out ~ Concrete(0)
- Concrete(0) = out
- out = 0
-
- k * 0 = 0 => 0 = 0
- $$
-
- $$ Case 2: j = 1 => out ~ Concrete(k)
- Concrete(k) = out
- out = k
-
- k * 1 = k => k = k
- $$
-
- $$ Case 3: otherwise => out ~ Concrete(k * j)
- Concrete(k * j) = out
- out = k * j
-
- k * j = k * j
-
-*** ReLU
-**** Linear(x, q, r) >< ReLU(out) => Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0)
-
- Linear(x, q, r) = x
- ReLU(x) = out
- out = IF (q*x + r) > 0 THEN (q*x + r) ELSE 0
-
- Materialize(Linear(x, q, r)) = out_x
- Linear(TermReLU(out_x), 1, 0) = out
- out_x = q*x + r
- out = 1*TermReLU(q*x + r) + 0
- Because TermReLU(x) is defined as "z3.If(x > 0, x, 0)":
- out = 1*(IF (q*x + r) > 0 THEN (q*x + r) ELSE 0) + 0
-
- IF (q*x + r) > 0 THEN (q*x + r) ELSE 0 = 1*(IF (q*x + r) > 0 THEN (q*x + r) ELSE 0) + 0 =>
- IF (q*x + r) > 0 THEN (q*x + r) ELSE 0 = IF (q*x + r) > 0 THEN (q*x + r) ELSE 0
-
-**** Concrete(k) >< ReLU(out) => (1), (2)
-
- Concrete(k) = x
- ReLU(x) = out
- out = IF k > 0 THEN k ELSE 0
-
- $$ Case 1: k > 0 => out ~ Concrete(k)
- Concrete(k) = out
- out = k
-
- IF true THEN k ELSE 0 = k => k = k
- $$
-
- $$ Case 2: k <= 0 => out ~ Concrete(0)
- Concrete(0) = out
- out = 0
-
- IF false THEN k ELSE 0 = 0 => 0 = 0
- $$
+ @end
+*** Linear(x, q, r) >< Materialize(out) => (1), (2), (3), (4), (5)
+ LHS:
+ Linear(x, q, r) ~ wire
+ Materialize(out) ~ wire
+ q*x + r = wire
+ out = wire
+ out = q*x + r
+
+ $$ Case 1: q = 0 => out ~ Concrete(r), x ~ Eraser
+ RHS:
+ out = r
+
+ EQUIVALENCE:
+ 0*x + r = r => r = r
+ $$
+
+ $$ Case 2: q = 1, r = 0 => out ~ x
+ RHS:
+ x = out
+ out = x
+
+ EQUIVALENCE:
+ 1*x + 0 = x => x = x
+ $$
+
+ $$ Case 3: q = 1 => out ~ TermAdd(x, Concrete(r))
+ RHS:
+ out = x + r
+
+ EQUIVALENCE:
+ 1*x + r = x + r => x + r = x + r
+ $$
+
+ $$ Case 4: r = 0 => out ~ TermMul(Concrete(q), x)
+ RHS:
+ out = q*x
+
+ EQUIVALENCE:
+ q*x + 0 = q*x => q*x = q*x
+ $$
+
+ $$ Case 5: otherwise => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r))
+ RHS:
+ out = q*x + r
+
+ EQUIVALENCE:
+ q*x + r = q*x + r
+ $$
+
+*** Concrete(k) >< Materialize(out) => out ~ Concrete(k)
+ LHS:
+ Concrete(k) ~ wire
+ Materialize(out) ~ wire
+ k = wire
+ out = wire
+ out = k
+
+ RHS:
+ out = k
+
+ EQUIVALENCE:
+ k = k
+
+** Add
+*** Linear(x, q, r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r)
+ LHS:
+ Linear(x, q, r) ~ wire
+ Add(out, b) ~ wire
+ q*x + r = wire
+ out = wire + b
+ out = q*x + r + b
+
+ RHS:
+ out = q*x + (r + b)
+
+ EQUIVALENCE:
+ q*x + r + b = q*x + (r + b) => q*x + (r + b) = q*x + (r + b)
+
+*** Concrete(k) >< Add(out, b) => (1), (2)
+ LHS:
+ Concrete(k) ~ wire
+ Add(out, b) ~ wire
+ k = wire
+ out = wire + b
+ out = k + b
+
+ $$ Case 1: k = 0 => out ~ b
+ RHS:
+ out = b
+
+ EQUIVALENCE:
+ 0 + b = b => b = b
+ $$
+
+ $$ Case 2: otherwise => b ~ AddCheckConcrete(out, k)
+ RHS:
+ out = k + b
+
+ EQUIVALENCE:
+ k + b = k + b
+ $$
+
+*** Linear(y, s, t) >< AddCheckLinear(out, x, q, r) => (1), (2), (3), (4)
+ LHS:
+ Linear(y, s, t) ~ wire
+ AddCheckLinear(out, x, q, r) ~ wire
+ s*y + t = wire
+ out = q*x + (r + wire)
+ out = q*x + (r + s*y + t)
+
+ $$ Case 1: q,r,s,t = 0 => out ~ Concrete(0), x ~ Eraser, y ~ Eraser
+ RHS:
+ out = 0
+
+ EQUIVALENCE:
+ 0*x + (0 + 0*y + 0) = 0 => 0 = 0
+ $$
+
+ $$ Case 2: s,t = 0 => out ~ Linear(x, q, r), y ~ Eraser
+ RHS:
+ out = q*x + r
+
+ EQUIVALENCE:
+ q*x + (r + 0*y + 0) = q*x + r => q*x + r = q*x + r
+ $$
+
+ $$ Case 3: q, r = 0 => out ~ Linear(y, s, t), x ~ Eraser
+ RHS:
+ out = s*y + t
+
+ EQUIVALENCE:
+ 0*x + (0 + s*y + t) = s*y + t => s*y + t = s*y + t
+ $$
+
+ $$ Case 4: otherwise => Linear(x, q, r) ~ Materialize(out_x), Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermAdd(out_x, out_y), 1, 0)
+ RHS:
+ Linear(x, q, r) ~ wire1
+ Materialize(out_x) ~ wire1
+ q*x + r = wire1
+ out_x = wire1
+ Linear(y, s, t) ~ wire2
+ Materialize(out_y) ~ wire2
+ s*y + t = wire2
+ out_y = wire2
+ out = 1*TermAdd(out_x, out_y) + 0
+ Because TermAdd(a, b) is defined as "a+b":
+ out = 1*(q*x + r + s*y + t) + 0
+
+ EQUIVALENCE:
+ q*x + (r + s*y + t) = 1*(q*x + r + s*y + t) + 0 => q*x + r + s*y + t = q*x + r + s*y + t
+ $$
+
+*** Concrete(j) >< AddCheckLinear(out, x, q, r) => out ~ Linear(x, q, r + j)
+ LHS:
+ Concrete(j) ~ wire
+ AddCheckLinear(out, x, q, r) ~ wire
+ j = wire
+ out = q*x + (r + wire)
+ out = q*x + (r + j)
+
+ RHS:
+ out = q*x + (r + j)
+
+ EQUIVALENCE:
+ q*x + (r + j) = q*x + (r + j)
+
+*** Linear(y, s, t) >< AddCheckConcrete(out, k) => out ~ Linear(y, s, t + k)
+ LHS:
+ Linear(y, s, t) ~ wire
+ AddCheckConcrete(out, k) ~ wire
+ s*y + t = wire
+ out = k + wire
+ out = k + s*y + t
+
+ RHS:
+ out = s*y + (t + k)
+
+ EQUIVALENCE:
+ k + s*y + t = s*y + (t + k) => s*y + (t + k) = s*y + (t + k)
+
+*** Concrete(j) >< AddCheckConcrete(out, k) => (1), (2)
+ LHS:
+ Concrete(j) ~ wire
+ AddCheckConcrete(out, k) ~ wire
+ j = wire
+ out = k + wire
+ out = k + j
+
+ $$ Case 1: j = 0 => out ~ Concrete(k)
+ RHS:
+ out = k
+
+ EQUIVALENCE:
+ k + 0 = k => k = k
+ $$
+
+ $$ Case 2: otherwise => out ~ Concrete(k + j)
+ RHS:
+ out = k + j
+
+ EQUIVALENCE:
+ k + j = k + j
+ $$
+
+** Mul
+*** Linear(x, q, r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r)
+ LHS:
+ Linear(x, q, r) ~ wire
+ Mul(out, b) ~ wire
+ q*x + r = wire
+ out = wire * b
+ out = (q*x + r) * b
+
+ RHS:
+ out = q*b*x + r*b
+
+ EQUIVALENCE:
+ (q*x + r) * b = q*b*x + r*b => q*b*x + r*b = q*b*x + r*b
+
+*** Concrete(k) >< Mul(out, b) => (1), (2), (3)
+ LHS:
+ Concrete(k) ~ wire
+ Mul(out, b) ~ wire
+ k = wire
+ out = wire * b
+ out = k * b
+
+ $$ Case 1: k = 0 => out ~ Concrete(0), b ~ Eraser
+ RHS:
+ out = 0
+
+ EQUIVALENCE:
+ 0 * b = 0 => 0 = 0
+ $$
+
+ $$ Case 2: k = 1 => out ~ b
+ RHS:
+ out = b
+
+ EQUIVALENCE:
+ 1 * b = b => b = b
+ $$
+
+ $$ Case 3: otherwise => b ~ MulCheckConcrete(out, k)
+ RHS:
+ out = k * b
+
+ EQUIVALENCE:
+ k * b = k * b
+ $$
+
+*** Linear(y, s, t) >< MulCheckLinear(out, x, q, r) => (1), (2)
+ LHS:
+ Linear(y, s, t) ~ wire
+ MulCheckLinear(out, x, q, r) ~ wire
+ s*y + t = wire
+ out = q*wire*x + r*wire
+ out = q\*(s*y + t)\*x + r*(s*y + t)
+
+ $$ Case 1: (q,r = 0) or (s,t = 0) => x ~ Eraser, y ~ Eraser, out ~ Concrete(0)
+ RHS:
+ out = 0
+
+ EQUIVALENCE:
+ 0\*(s*y + t)\*x + 0*(s*y + t) = 0 => 0 = 0
+ or
+ q\*(0*y + 0)\*x + r*(0*y + 0) = 0 => 0 = 0
+ $$
+
+ $$ Case 2: otherwise => Linear(x, q, r) ~ Materialize(out_x), Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermMul(out_x, out_y), 1, 0)
+ RHS:
+ Linear(x, q, r) ~ wire1
+ Materialize(out_x) ~ wire1
+ q*x + r = wire1
+ out_x = wire1
+ Linear(y, s, t) ~ wire2
+ Materialize(out_y) ~ wire2
+ s*y + t = wire2
+ out_y = wire2
+ out = 1*TermMul(out_x, out_y) + 0
+ Because TermMul(a, b) is defined as "a*b":
+ out = 1*(q*x + r)*(s*y + t) + 0
+
+ EQUIVALENCE:
+ q*(s*y + t)\*x + r*(s*y + t) = 1*(q*x + r)\*(s*y + t) =>
+ q\*(s*y + t)\*x + r*(s*y + t) = (q*x + r)\*(s*y + t) =>
+ q\*(s*y + t)\*x + r*(s*y + t) = q\*(s*y + t)\*x + r*(s*y + t)
+ $$
+
+*** Concrete(j) >< MulCheckLinear(out, x, q, r) => out ~ Linear(x, q * j, r * j)
+ LHS:
+ Concrete(j) ~ wire
+ MulCheckLinear(out, x, q, r) ~ wire
+ j = wire
+ out = q*wire*x + r*wire
+ out = q*j*x + r*j
+
+ RHS:
+ out = q*j*x + r*j
+
+ EQUIVALENCE:
+ q*j*x + r*j = q*j*x + r*j
+
+*** Linear(y, s, t) >< MulCheckConcrete(out, k) => out ~ Linear(y, s * k, t * k)
+ LHS:
+ Linear(y, s, t) ~ wire
+ MulCheckConcrete(out, k) ~ wire
+ s*y + t = wire
+ out = k * wire
+ out = k * (s*y + t)
+
+ RHS:
+ out = s*k*y + t*k
+
+ EQUIVALENCE:
+ k * (s*y + t) = s*k*y + t*k => s*k*y + t*k = s*k*y + t*k
+
+*** Concrete(j) >< MulCheckConcrete(out, k) => (1), (2), (3)
+ LHS:
+ Concrete(j) ~ wire
+ MulCheckConcrete(out, k) ~ wire
+ j = wire
+ out = k * wire
+ out = k * j
+
+ $$ Case 1: j = 0 => out ~ Concrete(0)
+ RHS:
+ out = 0
+
+ EQUIVALENCE:
+ k * 0 = 0 => 0 = 0
+ $$
+
+ $$ Case 2: j = 1 => out ~ Concrete(k)
+ RHS:
+ out = k
+
+ EQUIVALENCE:
+ k * 1 = k => k = k
+ $$
+
+ $$ Case 3: otherwise => out ~ Concrete(k * j)
+ RHS:
+ out = k * j
+
+ EQUIVALENCE:
+ k * j = k * j
+
+** ReLU
+*** Linear(x, q, r) >< ReLU(out) => Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0)
+ LHS:
+ Linear(x, q, r) ~ wire
+ ReLU(out) ~ wire
+ q*x + r = wire
+ out = IF wire > 0 THEN wire ELSE 0
+ out = IF (q*x + r) > 0 THEN (q*x + r) ELSE 0
+
+ RHS:
+ Linear(x, q, r) ~ wire
+ Materialize(out_x) ~ wire
+ q*x + r = wire
+ out_x = wire
+ out = 1*TermReLU(out_x) + 0
+ Because TermReLU(x) is defined as "z3.If(x > 0, x, 0)":
+ out = 1*(IF (q*x + r) > 0 THEN (q*x + r) ELSE 0) + 0
+
+ EQUIVALENCE:
+ IF (q*x + r) > 0 THEN (q*x + r) ELSE 0 = 1*(IF (q*x + r) > 0 THEN (q*x + r) ELSE 0) + 0 =>
+ IF (q*x + r) > 0 THEN (q*x + r) ELSE 0 = IF (q*x + r) > 0 THEN (q*x + r) ELSE 0
+
+*** Concrete(k) >< ReLU(out) => (1), (2)
+ LHS:
+ Concrete(k) ~ wire
+ ReLU(out) ~ wire
+ k = wire
+ out = IF wire > 0 THEN wire ELSE 0
+ out = IF k > 0 THEN k ELSE 0
+
+ $$ Case 1: k > 0 => out ~ Concrete(k)
+ RHS:
+ out = k
+
+ EQUIVALENCE:
+ IF true THEN k ELSE 0 = k => k = k
+ $$
+
+ $$ Case 2: k <= 0 => out ~ Concrete(0)
+ RHS:
+ out = 0
+
+ EQUIVALENCE:
+ IF false THEN k ELSE 0 = 0 => 0 = 0
+ $$
diff --git a/xor.py b/xor.py
index 0f8390d..82a16b8 100644
--- a/xor.py
+++ b/xor.py
@@ -4,7 +4,7 @@ import torch.onnx
import nneq
class xor_mlp(nn.Module):
- def __init__(self, hidden_dim=8):
+ def __init__(self, hidden_dim):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(2, hidden_dim),
@@ -37,8 +37,8 @@ if __name__ == "__main__":
torch_net_a = train_model("Network A", 8).eval()
torch_net_b = train_model("Network B", 16).eval()
- onnx_net_a = torch.onnx.export(torch_net_a, (torch.randn(1, 2),), verbose=False, dynamo=True).model_proto # type: ignore
- onnx_net_b = torch.onnx.export(torch_net_b, (torch.randn(1, 2),), verbose=False, dynamo=True).model_proto # type: ignore
+ onnx_net_a = torch.onnx.export(torch_net_a, (torch.randn(1, 2),), verbose=False).model_proto # type: ignore
+ onnx_net_b = torch.onnx.export(torch_net_b, (torch.randn(1, 2),), verbose=False).model_proto # type: ignore
z3_net_a = nneq.net(onnx_net_a)
z3_net_b = nneq.net(onnx_net_b)