diff options
Diffstat (limited to '')
| -rw-r--r-- | nneq/nneq.py | 1 | ||||
| -rw-r--r-- | notes.norg | 6 | ||||
| -rw-r--r-- | proof.norg | 784 | ||||
| -rw-r--r-- | xor.py | 6 |
4 files changed, 442 insertions, 355 deletions
diff --git a/nneq/nneq.py b/nneq/nneq.py index d9d7d30..4f46cbf 100644 --- a/nneq/nneq.py +++ b/nneq/nneq.py @@ -135,7 +135,6 @@ def inpla_export(model: onnx.ModelProto) -> inpla_str: return "\n".join(input_script + list(reversed(node_script)) + result_lines) def inpla_run(model: inpla_str) -> z3_str: - print(model) return subprocess.run(["./inpla"], input=f"{rules}\n{model}", capture_output=True, text=True).stdout syms = {} @@ -4,13 +4,13 @@ description: WIP tool to prove NNEQ using Interaction Nets as pre-processor fo m authors: ericmarin categories: research created: 2026-03-14T09:21:24 -updated: 2026-03-17T11:18:11 +updated: 2026-03-18T13:18:15 version: 1.1.1 @end * TODO - (?) Scalability %Maybe done? I have increased the limits of Inpla, but I have yet to test% - - (x) Soundness of translated NN + - (x) Soundness of translated NN: {:proof.norg:}[PROOF] - ( ) Compatibility with other types of NN - ( ) Comparison with other tool ({https://github.com/NeuralNetworkVerification/Marabou}[Marabou], {https://github.com/guykatzz/ReluplexCav2017}[Reluplex]) - ( ) Add Range agent to enable ReLU optimization @@ -26,7 +26,7 @@ version: 1.1.1 - Symbolic(id): represent the variable id - Add(out, b): represent the addition (has various steps AddCheckLinear/AddCheckConcrete) - Mul(out, b): represent the multiplication (has various steps MulCheckLinear/MulCheckConcrete) - - ReLU(out): represent "if x > 0 ? x ; 0" + - ReLU(out): represent "IF x > 0 THEN x ELSE 0" - Materialize(out): transforms a Linear packet into a final representation of TermAdd/TermMul * Rules @@ -4,364 +4,452 @@ description: authors: ericmarin categories: created: 2026-03-16T11:34:52 -updated: 2026-03-16T18:31:41 +updated: 2026-03-17T18:59:09 version: 1.1.1 @end +* Mathematical Definitions + - Linear(x, q, r) ~ out => out = q*x + r %with q,r Real% + - Concrete(k) ~ out => out = k %with k Real% + - Add(out, b) ~ a => out = a + b + - AddCheckLinear(out, x, q, r) ~ b => out = q*x + (r + b) %with q,r Real% + - AddCheckConcrete(out, k) ~ b => out = k + b %with k Real% + - Mul(out, b) ~ a => out = a * b + - MulCheckLinear(out, x, q, r) ~ b => out = q*b*x + r*b %with q,r Real% + - MulCheckConcrete(out, k) ~ b => out = k*b %with k Real% + - ReLU(out) ~ x => out = IF (x > 0) THEN x ELSE 0 + - Materialize(out) ~ x => out = x + * Proof for translation from Pytorch representation to Interaction Net graph +** ReLU + ONNX ReLU node is defined as: + /Y = X if X > 0 else 0/ + + The translation defines the interactions: + /x_i ~ ReLU(y_i)/ + + By definition this interaction is equal to: + /y_i = IF (x_i >0) THEN x_i ELSE 0/ + +** Gemm + ONNX Gemm node is defined as: + Y = alpha * A * B + beta * C + + The translation defines the interactions: + /a_i ~ Mul(v_i, Concrete(alpha * b_i))/ + /Add(...(Add(y_i, v_1), ...), v_n) ~ Concrete(beta * c_i)/ + + By definition this interaction is equal to: + /v_i = alpha * a_i * b_i/ + /y_i = v_1 + v_2 + ... + v_n + beta * c_i/ + + By grouping the operations we get: + /Y = alpha * A * B + beta * C/ * Proof for the Interaction Rules -** Mathematical Definitions - - Linear(x, q, r) = q*x + r %with q,r Real% - - Concrete(k) = k %with k Real% - - Add(a, b) = a + b - - AddCheckLinear(x, q, r, b) = q*x + (r + b) %with q,r Real% - - AddCheckConcrete(k, b) = k + b %with k Real% - - Mul(a, b) = a * b - - MulCheckLinear(x, q, r, b) = q*b*x + r*b %with q,r Real% - - MulCheckConcrete(k, b) = k*b %with k Real% - - ReLU(x) = IF (x > 0) THEN x ELSE 0 - - Materialize(x) = x - -** Rules -*** Formatting - Agent1 >< Agent2 => Wiring - - LEFT SIDE MATHEMATICAL INTERPRETATION - - RIGHT SIDE MATHEMATICAL INTERPRETATION - - SHOWING EQUIVALENCE - -*** Materialize - The Materialize agent transforms a Linear agent into a tree of explicit mathematical operations - that are used as final representation for the solver. - In the Python module the terms are defined as: - @code python +** Materialize + The Materialize agent transforms a Linear agent into a tree of explicit mathematical operations + that are used as final representation for the solver. + In the Python module the terms are defined as: + @code python def TermAdd(a, b): return a + b def TermMul(a, b): return a * b def TermReLU(x): return z3.If(x > 0, x, 0) - @end -**** Linear(x, q, r) >< Materialize(out) => (1), (2), (3), (4), (5) - - Linear(x, q, r) = term - Materialize(term) = out - out = q*x + r - - $$ Case 1: q = 0 => out ~ Concrete(r), x ~ Eraser - Concrete(r) = out - out = r - - 0*x + r = r => r = r - $$ - - $$ Case 2: q = 1, r = 0 => out ~ x - x = out - out = x - - 1*x + 0 = x => x = x - $$ - - $$ Case 3: q = 1 => out ~ TermAdd(x, Concrete(r)) - TermAdd(x, Concrete(r)) = out - out = x + r - - 1*x + r = x + r => x + r = x + r - $$ - - $$ Case 4: r = 0 => out ~ TermMul(Concrete(q), x) - TermMul(Concrete(q), x) = out - out = q*x - - q*x + 0 = q*x => q*x = q*x - $$ - - $$ Case 5: otherwise => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r)) - TermAdd(TermMul(Concrete(q), x), r) = out - out = q*x + r - - q*x + r = q*x + r - $$ - -**** Concrete(k) >< Materialize(out) => out ~ Concrete(k) - - Concrete(k) = term - Materialize(term) = out - out = k - - Concrete(k) = out - out = k - - k = k - -*** Add -**** Linear(x, q, r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r) - - Linear(x, q, r) = a - Add(a, b) = out - out = q*x + r + b - - AddCheckLinear(x, q, r, b) = out - out = q*x + (r + b) - - q*x + r + b = q*x + (r + b) => q*x + (r + b) = q*x + (r + b) - -**** Concrete(k) >< Add(out, b) => (1), (2) - - Concrete(k) = a - Add(a, b) = out - out = k + b - - $$ Case 1: k = 0 => out ~ b - b = out - out = b - - 0 + b = b => b = b - $$ - - $$ Case 2: otherwise => b ~ AddCheckConcrete(out, k) - AddCheckConcrete(k, b) = out - out = k + b - - k + b = k + b - $$ - -**** Linear(y, s, t) >< AddCheckLinear(out, x, q, r) => (1), (2), (3), (4) - - Linear(y, s, t) = b - AddCheckLinear(x, q, r, b) = out - out = q*x + (r + s*y + t) - - $$ Case 1: q,r,s,t = 0 => out ~ Concrete(0), x ~ Eraser, y ~ Eraser - Concrete(0) = out - out = 0 - - 0*x + (0 + 0*y + 0) = 0 => 0 = 0 - $$ - - $$ Case 2: s,t = 0 => out ~ Linear(x, q, r), y ~ Eraser - Linear(x, q, r) = out - out = q*x + r - - q*x + (r + 0*y + 0) = q*x + r => q*x + r = q*x + r - $$ - - $$ Case 3: q, r = 0 => out ~ Linear(y, s, t), x ~ Eraser - Linear(y, s, t) = out - out = s*y + t - - 0*x + (0 + s*y + t) = s*y + t => s*y + t = s*y + t - $$ - - $$ Case 4: otherwise => Linear(x, q, r) ~ Materialize(out_x), Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermAdd(out_x, out_y), 1, 0) - Materialize(Linear(x, q, r)) = out_x - Materialize(Linear(y, s, t)) = out_y - Linear(TermAdd(out_x, out_y), 1, 0) = out - out_x = q*x + r - out_y = s*y + t - out = 1*TermAdd(q*x + r, s*y + t) + 0 - Because TermAdd(a, b) is defined as "a+b": - out = 1*(q*x + r + s*y + t) + 0 - - q*x + (r + s*y + t) = 1*(q*x + r + s*y + t) + 0 => q*x + r + s*y + t = q*x + r + s*y + t - $$ - -**** Concrete(j) >< AddCheckLinear(out, x, q, r) => out ~ Linear(x, q, r + j) - - Concrete(j) = b - AddCheckLinear(x, q, r, b) = out - out = q*x + (r + j) - - Linear(x, q, r + j) = out - out = q*x + (r + j) - - q*x + (r + j) = q*x + (r + j) - -**** Linear(y, s, t) >< AddCheckConcrete(out, k) => out ~ Linear(y, s, t + k) - - Linear(y, s, t) = b - AddCheckConcrete(k, b) = out - out = k + s*y + t - - Linear(y, s, t + k) - out = s*y + (t + k) - - k + s*y + t = s*y + (t + k) => s*y + (t + k) = s*y + (t + k) - -**** Concrete(j) >< AddCheckConcrete(out, k) => (1), (2) - - Concrete(j) = b - AddCheckConcrete(k, b) = out - out = k + j - - $$ Case 1: j = 0 => out ~ Concrete(k) - Concrete(k) = out - out = k - - k + 0 = k => k = k - $$ - - $$ Case 2: otherwise => out ~ Concrete(k + j) - Concrete(k + j) = out - out = k + j - - k + j = k + j - $$ - -*** Mul -**** Linear(x, q, r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r) - - Linear(x, q, r) = a - Mul(a, b) = out - out = (q*x + r) * b - - MulCheckLinear(x, q, r, b) = out - out = q*b*x + r*b - - (q*x + r) * b = q*b*x + r*b => q*b*x + r*b = q*b*x + r*b - -**** Concrete(k) >< Mul(out, b) => (1), (2), (3) - - Concrete(k) = a - Mul(a, b) = out - out = k * b - - $$ Case 1: k = 0 => out ~ Concrete(0), b ~ Eraser - Concrete(0) = out - out = 0 - - 0 * b = 0 => 0 = 0 - $$ - - $$ Case 2: k = 1 => out ~ b - b = out - out = b - - 1 * b = b => b = b - $$ - - $$ Case 3: otherwise => b ~ MulCheckConcrete(out, k) - MulCheckConcrete(k, b) = out - out = k * b - - k * b = k * b - $$ - -**** Linear(y, s, t) >< MulCheckLinear(out, x, q, r) => (1), (2) - - Linear(y, s, t) = b - MulCheckLinear(x, q, r, b) = out - out = q\*(s*y + t)\*x + r*(s*y + t) - - $$ Case 1: (q,r = 0) or (s,t = 0) => x ~ Eraser, y ~ Eraser, out ~ Concrete(0) - Concrete(0) = out - out = 0 - - 0\*(s*y + t)\*x + 0*(s*y + t) = 0 => 0 = 0 - or - q\*(0*y + 0)\*x + r*(0*y + 0) = 0 => 0 = 0 - $$ - - $$ Case 2: otherwise => Linear(x, q, r) ~ Materialize(out_x), Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermMul(out_x, out_y), 1, 0) - Materialize(Linear(x, q, r)) = out_x - Materialize(Linear(y, s, t)) = out_y - Linear(TermMul(out_x, out_y), 1, 0) = out - out_x = q*x + r - out_y = s*y + t - out = 1*TermMul(q*x + r, s*y + t) + 0 - Because TermMul(a, b) is defined as "a*b": - out = 1*(q*x + r)*(s*y + t) + 0 - - q*(s*y + t)\*x + r*(s*y + t) = 1*(q*x + r)\*(s*y + t) => - q\*(s*y + t)\*x + r*(s*y + t) = (q*x + r)\*(s*y + t) => - q\*(s*y + t)\*x + r*(s*y + t) = q\*(s*y + t)\*x + r*(s*y + t) - $$ - -**** Concrete(j) >< MulCheckLinear(out, x, q, r) => out ~ Linear(x, q * j, r * j) - - Concrete(j) = b - MulCheckLinear(x, q, r, b) = out - out = q*j*x + r*j - - Linear(x, q * j, r * j) = out - out = q*j*x + r*j - - q*j*x + r*j = q*j*x + r*j - -**** Linear(y, s, t) >< MulCheckConcrete(out, k) => out ~ Linear(y, s * k, t * k) - - Linear(y, s, t) = b - MulCheckConcrete(k, b) = out - out = k * (s*y + t) - - Linear(y, s * k, t * k) = out - out = s*k*y + t*k - - k * (s*y + t) = s*k*y + t*k => s*k*y + t*k = s*k*y + t*k - -**** Concrete(j) >< MulCheckConcrete(out, k) => (1), (2), (3) - - Concrete(j) = b - MulCheckConcrete(k, b) = out - out = k * j - - $$ Case 1: j = 0 => out ~ Concrete(0) - Concrete(0) = out - out = 0 - - k * 0 = 0 => 0 = 0 - $$ - - $$ Case 2: j = 1 => out ~ Concrete(k) - Concrete(k) = out - out = k - - k * 1 = k => k = k - $$ - - $$ Case 3: otherwise => out ~ Concrete(k * j) - Concrete(k * j) = out - out = k * j - - k * j = k * j - -*** ReLU -**** Linear(x, q, r) >< ReLU(out) => Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0) - - Linear(x, q, r) = x - ReLU(x) = out - out = IF (q*x + r) > 0 THEN (q*x + r) ELSE 0 - - Materialize(Linear(x, q, r)) = out_x - Linear(TermReLU(out_x), 1, 0) = out - out_x = q*x + r - out = 1*TermReLU(q*x + r) + 0 - Because TermReLU(x) is defined as "z3.If(x > 0, x, 0)": - out = 1*(IF (q*x + r) > 0 THEN (q*x + r) ELSE 0) + 0 - - IF (q*x + r) > 0 THEN (q*x + r) ELSE 0 = 1*(IF (q*x + r) > 0 THEN (q*x + r) ELSE 0) + 0 => - IF (q*x + r) > 0 THEN (q*x + r) ELSE 0 = IF (q*x + r) > 0 THEN (q*x + r) ELSE 0 - -**** Concrete(k) >< ReLU(out) => (1), (2) - - Concrete(k) = x - ReLU(x) = out - out = IF k > 0 THEN k ELSE 0 - - $$ Case 1: k > 0 => out ~ Concrete(k) - Concrete(k) = out - out = k - - IF true THEN k ELSE 0 = k => k = k - $$ - - $$ Case 2: k <= 0 => out ~ Concrete(0) - Concrete(0) = out - out = 0 - - IF false THEN k ELSE 0 = 0 => 0 = 0 - $$ + @end +*** Linear(x, q, r) >< Materialize(out) => (1), (2), (3), (4), (5) + LHS: + Linear(x, q, r) ~ wire + Materialize(out) ~ wire + q*x + r = wire + out = wire + out = q*x + r + + $$ Case 1: q = 0 => out ~ Concrete(r), x ~ Eraser + RHS: + out = r + + EQUIVALENCE: + 0*x + r = r => r = r + $$ + + $$ Case 2: q = 1, r = 0 => out ~ x + RHS: + x = out + out = x + + EQUIVALENCE: + 1*x + 0 = x => x = x + $$ + + $$ Case 3: q = 1 => out ~ TermAdd(x, Concrete(r)) + RHS: + out = x + r + + EQUIVALENCE: + 1*x + r = x + r => x + r = x + r + $$ + + $$ Case 4: r = 0 => out ~ TermMul(Concrete(q), x) + RHS: + out = q*x + + EQUIVALENCE: + q*x + 0 = q*x => q*x = q*x + $$ + + $$ Case 5: otherwise => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r)) + RHS: + out = q*x + r + + EQUIVALENCE: + q*x + r = q*x + r + $$ + +*** Concrete(k) >< Materialize(out) => out ~ Concrete(k) + LHS: + Concrete(k) ~ wire + Materialize(out) ~ wire + k = wire + out = wire + out = k + + RHS: + out = k + + EQUIVALENCE: + k = k + +** Add +*** Linear(x, q, r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r) + LHS: + Linear(x, q, r) ~ wire + Add(out, b) ~ wire + q*x + r = wire + out = wire + b + out = q*x + r + b + + RHS: + out = q*x + (r + b) + + EQUIVALENCE: + q*x + r + b = q*x + (r + b) => q*x + (r + b) = q*x + (r + b) + +*** Concrete(k) >< Add(out, b) => (1), (2) + LHS: + Concrete(k) ~ wire + Add(out, b) ~ wire + k = wire + out = wire + b + out = k + b + + $$ Case 1: k = 0 => out ~ b + RHS: + out = b + + EQUIVALENCE: + 0 + b = b => b = b + $$ + + $$ Case 2: otherwise => b ~ AddCheckConcrete(out, k) + RHS: + out = k + b + + EQUIVALENCE: + k + b = k + b + $$ + +*** Linear(y, s, t) >< AddCheckLinear(out, x, q, r) => (1), (2), (3), (4) + LHS: + Linear(y, s, t) ~ wire + AddCheckLinear(out, x, q, r) ~ wire + s*y + t = wire + out = q*x + (r + wire) + out = q*x + (r + s*y + t) + + $$ Case 1: q,r,s,t = 0 => out ~ Concrete(0), x ~ Eraser, y ~ Eraser + RHS: + out = 0 + + EQUIVALENCE: + 0*x + (0 + 0*y + 0) = 0 => 0 = 0 + $$ + + $$ Case 2: s,t = 0 => out ~ Linear(x, q, r), y ~ Eraser + RHS: + out = q*x + r + + EQUIVALENCE: + q*x + (r + 0*y + 0) = q*x + r => q*x + r = q*x + r + $$ + + $$ Case 3: q, r = 0 => out ~ Linear(y, s, t), x ~ Eraser + RHS: + out = s*y + t + + EQUIVALENCE: + 0*x + (0 + s*y + t) = s*y + t => s*y + t = s*y + t + $$ + + $$ Case 4: otherwise => Linear(x, q, r) ~ Materialize(out_x), Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermAdd(out_x, out_y), 1, 0) + RHS: + Linear(x, q, r) ~ wire1 + Materialize(out_x) ~ wire1 + q*x + r = wire1 + out_x = wire1 + Linear(y, s, t) ~ wire2 + Materialize(out_y) ~ wire2 + s*y + t = wire2 + out_y = wire2 + out = 1*TermAdd(out_x, out_y) + 0 + Because TermAdd(a, b) is defined as "a+b": + out = 1*(q*x + r + s*y + t) + 0 + + EQUIVALENCE: + q*x + (r + s*y + t) = 1*(q*x + r + s*y + t) + 0 => q*x + r + s*y + t = q*x + r + s*y + t + $$ + +*** Concrete(j) >< AddCheckLinear(out, x, q, r) => out ~ Linear(x, q, r + j) + LHS: + Concrete(j) ~ wire + AddCheckLinear(out, x, q, r) ~ wire + j = wire + out = q*x + (r + wire) + out = q*x + (r + j) + + RHS: + out = q*x + (r + j) + + EQUIVALENCE: + q*x + (r + j) = q*x + (r + j) + +*** Linear(y, s, t) >< AddCheckConcrete(out, k) => out ~ Linear(y, s, t + k) + LHS: + Linear(y, s, t) ~ wire + AddCheckConcrete(out, k) ~ wire + s*y + t = wire + out = k + wire + out = k + s*y + t + + RHS: + out = s*y + (t + k) + + EQUIVALENCE: + k + s*y + t = s*y + (t + k) => s*y + (t + k) = s*y + (t + k) + +*** Concrete(j) >< AddCheckConcrete(out, k) => (1), (2) + LHS: + Concrete(j) ~ wire + AddCheckConcrete(out, k) ~ wire + j = wire + out = k + wire + out = k + j + + $$ Case 1: j = 0 => out ~ Concrete(k) + RHS: + out = k + + EQUIVALENCE: + k + 0 = k => k = k + $$ + + $$ Case 2: otherwise => out ~ Concrete(k + j) + RHS: + out = k + j + + EQUIVALENCE: + k + j = k + j + $$ + +** Mul +*** Linear(x, q, r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r) + LHS: + Linear(x, q, r) ~ wire + Mul(out, b) ~ wire + q*x + r = wire + out = wire * b + out = (q*x + r) * b + + RHS: + out = q*b*x + r*b + + EQUIVALENCE: + (q*x + r) * b = q*b*x + r*b => q*b*x + r*b = q*b*x + r*b + +*** Concrete(k) >< Mul(out, b) => (1), (2), (3) + LHS: + Concrete(k) ~ wire + Mul(out, b) ~ wire + k = wire + out = wire * b + out = k * b + + $$ Case 1: k = 0 => out ~ Concrete(0), b ~ Eraser + RHS: + out = 0 + + EQUIVALENCE: + 0 * b = 0 => 0 = 0 + $$ + + $$ Case 2: k = 1 => out ~ b + RHS: + out = b + + EQUIVALENCE: + 1 * b = b => b = b + $$ + + $$ Case 3: otherwise => b ~ MulCheckConcrete(out, k) + RHS: + out = k * b + + EQUIVALENCE: + k * b = k * b + $$ + +*** Linear(y, s, t) >< MulCheckLinear(out, x, q, r) => (1), (2) + LHS: + Linear(y, s, t) ~ wire + MulCheckLinear(out, x, q, r) ~ wire + s*y + t = wire + out = q*wire*x + r*wire + out = q\*(s*y + t)\*x + r*(s*y + t) + + $$ Case 1: (q,r = 0) or (s,t = 0) => x ~ Eraser, y ~ Eraser, out ~ Concrete(0) + RHS: + out = 0 + + EQUIVALENCE: + 0\*(s*y + t)\*x + 0*(s*y + t) = 0 => 0 = 0 + or + q\*(0*y + 0)\*x + r*(0*y + 0) = 0 => 0 = 0 + $$ + + $$ Case 2: otherwise => Linear(x, q, r) ~ Materialize(out_x), Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermMul(out_x, out_y), 1, 0) + RHS: + Linear(x, q, r) ~ wire1 + Materialize(out_x) ~ wire1 + q*x + r = wire1 + out_x = wire1 + Linear(y, s, t) ~ wire2 + Materialize(out_y) ~ wire2 + s*y + t = wire2 + out_y = wire2 + out = 1*TermMul(out_x, out_y) + 0 + Because TermMul(a, b) is defined as "a*b": + out = 1*(q*x + r)*(s*y + t) + 0 + + EQUIVALENCE: + q*(s*y + t)\*x + r*(s*y + t) = 1*(q*x + r)\*(s*y + t) => + q\*(s*y + t)\*x + r*(s*y + t) = (q*x + r)\*(s*y + t) => + q\*(s*y + t)\*x + r*(s*y + t) = q\*(s*y + t)\*x + r*(s*y + t) + $$ + +*** Concrete(j) >< MulCheckLinear(out, x, q, r) => out ~ Linear(x, q * j, r * j) + LHS: + Concrete(j) ~ wire + MulCheckLinear(out, x, q, r) ~ wire + j = wire + out = q*wire*x + r*wire + out = q*j*x + r*j + + RHS: + out = q*j*x + r*j + + EQUIVALENCE: + q*j*x + r*j = q*j*x + r*j + +*** Linear(y, s, t) >< MulCheckConcrete(out, k) => out ~ Linear(y, s * k, t * k) + LHS: + Linear(y, s, t) ~ wire + MulCheckConcrete(out, k) ~ wire + s*y + t = wire + out = k * wire + out = k * (s*y + t) + + RHS: + out = s*k*y + t*k + + EQUIVALENCE: + k * (s*y + t) = s*k*y + t*k => s*k*y + t*k = s*k*y + t*k + +*** Concrete(j) >< MulCheckConcrete(out, k) => (1), (2), (3) + LHS: + Concrete(j) ~ wire + MulCheckConcrete(out, k) ~ wire + j = wire + out = k * wire + out = k * j + + $$ Case 1: j = 0 => out ~ Concrete(0) + RHS: + out = 0 + + EQUIVALENCE: + k * 0 = 0 => 0 = 0 + $$ + + $$ Case 2: j = 1 => out ~ Concrete(k) + RHS: + out = k + + EQUIVALENCE: + k * 1 = k => k = k + $$ + + $$ Case 3: otherwise => out ~ Concrete(k * j) + RHS: + out = k * j + + EQUIVALENCE: + k * j = k * j + +** ReLU +*** Linear(x, q, r) >< ReLU(out) => Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0) + LHS: + Linear(x, q, r) ~ wire + ReLU(out) ~ wire + q*x + r = wire + out = IF wire > 0 THEN wire ELSE 0 + out = IF (q*x + r) > 0 THEN (q*x + r) ELSE 0 + + RHS: + Linear(x, q, r) ~ wire + Materialize(out_x) ~ wire + q*x + r = wire + out_x = wire + out = 1*TermReLU(out_x) + 0 + Because TermReLU(x) is defined as "z3.If(x > 0, x, 0)": + out = 1*(IF (q*x + r) > 0 THEN (q*x + r) ELSE 0) + 0 + + EQUIVALENCE: + IF (q*x + r) > 0 THEN (q*x + r) ELSE 0 = 1*(IF (q*x + r) > 0 THEN (q*x + r) ELSE 0) + 0 => + IF (q*x + r) > 0 THEN (q*x + r) ELSE 0 = IF (q*x + r) > 0 THEN (q*x + r) ELSE 0 + +*** Concrete(k) >< ReLU(out) => (1), (2) + LHS: + Concrete(k) ~ wire + ReLU(out) ~ wire + k = wire + out = IF wire > 0 THEN wire ELSE 0 + out = IF k > 0 THEN k ELSE 0 + + $$ Case 1: k > 0 => out ~ Concrete(k) + RHS: + out = k + + EQUIVALENCE: + IF true THEN k ELSE 0 = k => k = k + $$ + + $$ Case 2: k <= 0 => out ~ Concrete(0) + RHS: + out = 0 + + EQUIVALENCE: + IF false THEN k ELSE 0 = 0 => 0 = 0 + $$ @@ -4,7 +4,7 @@ import torch.onnx import nneq class xor_mlp(nn.Module): - def __init__(self, hidden_dim=8): + def __init__(self, hidden_dim): super().__init__() self.layers = nn.Sequential( nn.Linear(2, hidden_dim), @@ -37,8 +37,8 @@ if __name__ == "__main__": torch_net_a = train_model("Network A", 8).eval() torch_net_b = train_model("Network B", 16).eval() - onnx_net_a = torch.onnx.export(torch_net_a, (torch.randn(1, 2),), verbose=False, dynamo=True).model_proto # type: ignore - onnx_net_b = torch.onnx.export(torch_net_b, (torch.randn(1, 2),), verbose=False, dynamo=True).model_proto # type: ignore + onnx_net_a = torch.onnx.export(torch_net_a, (torch.randn(1, 2),), verbose=False).model_proto # type: ignore + onnx_net_b = torch.onnx.export(torch_net_b, (torch.randn(1, 2),), verbose=False).model_proto # type: ignore z3_net_a = nneq.net(onnx_net_a) z3_net_b = nneq.net(onnx_net_b) |
