
Soundness Proof

Mathematical Definitions

Linear(x, q, r) ~ out => out = q*x + r
Concrete(k) ~ out => out = k
Add(out, b) ~ a => out = a + b
AddCheckLinear(out, x, q, r) ~ b => out = q*x + (r + b)
AddCheckConcrete(out, k) ~ b => out = k + b
Mul(out, b) ~ a => out = a * b
MulCheckLinear(out, x, q, r) ~ b => out = q*b*x + r*b
MulCheckConcrete(out, k) ~ b => out = k*b
ReLU(out) ~ x => out = IF (x > 0) THEN x ELSE 0
Materialize(out) ~ x => out = x
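
The agent semantics above can be sketched as plain Python functions (an illustrative sketch only; these lowercase names are invented here and are not the project's actual API):

```python
# Hedged sketch: the denotation of each basic agent as a plain function.
def linear(x, q, r):
    # Linear(x, q, r) ~ out  =>  out = q*x + r
    return q * x + r

def concrete(k):
    # Concrete(k) ~ out  =>  out = k
    return k

def add(a, b):
    # Add(out, b) ~ a  =>  out = a + b
    return a + b

def mul(a, b):
    # Mul(out, b) ~ a  =>  out = a * b
    return a * b

def relu(x):
    # ReLU(out) ~ x  =>  out = x if x > 0 else 0
    return x if x > 0 else 0
```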

Soundness of Translation

ReLU

The ONNX ReLU node is defined as:
Y = X if X > 0 else 0

The translation defines the interactions:
x_i ~ ReLU(y_i)

By definition this interaction is equal to:
y_i = IF (x_i > 0) THEN x_i ELSE 0

Gemm

The ONNX Gemm node is defined as:
Y = alpha * A * B + beta * C

The translation defines, for each output element y_i, the interactions:
a_k ~ Mul(v_k, Concrete(alpha * b_k))
Add(...(Add(y_i, v_1), ...), v_n) ~ Concrete(beta * c_i)

where k = 1..n ranges over the inner dimension.

By definition these interactions are equal to:
v_k = alpha * a_k * b_k
y_i = v_1 + v_2 + ... + v_n + beta * c_i

By grouping the operations we get:
Y = alpha * A * B + beta * C
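
A quick numeric spot-check of this grouping for a single output element (pure Python; the values are arbitrary illustrations):

```python
# One output element of Gemm: the Mul interactions produce the per-wire
# products alpha * a * b, and the chained Add interactions sum them and
# add beta * c. The result must match alpha * (A·B) + beta * C elementwise.
alpha, beta = 2.0, 3.0
a = [1.0, 2.0, 3.0]   # one row of A
b = [4.0, 5.0, 6.0]   # the matching column of B
c = 7.0               # the matching element of C

v = [alpha * ai * bi for ai, bi in zip(a, b)]  # Mul interactions
y = sum(v) + beta * c                          # chained Add interactions

dot = sum(ai * bi for ai, bi in zip(a, b))
assert y == alpha * dot + beta * c
```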

Identity / Flatten / Reshape / Squeeze / Unsqueeze

These nodes reduce to an identity mapping: wires carry single elements rather than structured tensors, so only the wiring changes.
out_i ~ in_i

MatMul

Equivalent to Gemm with alpha = 1, beta = 0, and C = 0.

Add

The ONNX Add node is defined as:
C = A + B

The translation defines the interactions:
Add(c_i, b_i) ~ a_i

By definition this interaction is equal to:
c_i = a_i + b_i

By grouping the operations we get:
C = A + B

Sub

The ONNX Sub node is defined as:
C = A - B

The translation defines the interactions:
Add(c_i, neg_b_i) ~ a_i
Mul(neg_b_i, Concrete(-1)) ~ b_i

By definition this interaction is equal to:
c_i = a_i + neg_b_i
neg_b_i = -1 * b_i

By grouping the operations we get:
C = A - B

Slice

The ONNX Slice node is defined as:
out_j = in_{start + (j * step)}

The translation creates a wiring analogous to the above definition:
out_j ~ in_{start + (j * step)}

Soundness of Interaction Rules

Materialize

The Materialize agent transforms a Linear agent into a tree of explicit mathematical operations
that serves as the final representation for the solver.
In the Python module the terms are defined as:

import z3

def TermAdd(a, b):
    return a + b

def TermMul(a, b):
    return a * b

def TermReLU(x):
    return z3.If(x > 0, x, 0)

Linear >< Materialize

Linear(x, q, r) >< Materialize(out) => (1), (2), (3), (4), (5)

LHS:
Linear(x, q, r) ~ wire
Materialize(out) ~ wire
q*x + r = wire
out = wire
out = q*x + r

Case 1:

q = 0 => out ~ Concrete(r), x ~ Eraser

RHS:
out = r

EQUIVALENCE:
0*x + r = r => r = r

Case 2:

q = 1, r = 0 => out ~ x

RHS:
x = out
out = x

EQUIVALENCE:
1*x + 0 = x => x = x

Case 3:

q = 1 => out ~ TermAdd(x, Concrete(r))

RHS:
out = x + r

EQUIVALENCE:
1*x + r = x + r => x + r = x + r

Case 4:

r = 0 => out ~ TermMul(Concrete(q), x)

RHS:
out = q*x

EQUIVALENCE:
q*x + 0 = q*x => q*x = q*x

Case 5:

otherwise => out ~ TermAdd(TermMul(Concrete(q), x), Concrete(r))

RHS:
out = q*x + r

EQUIVALENCE:
q*x + r = q*x + r
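
The five-way case split can be spot-checked with plain Python values standing in for z3 terms (a sketch; `materialize_linear` is a name invented here, not the project's API):

```python
# Each branch of Linear(x, q, r) >< Materialize(out) must denote q*x + r.
def materialize_linear(x, q, r):
    if q == 0:
        return r              # Case 1: out ~ Concrete(r)
    if q == 1 and r == 0:
        return x              # Case 2: out ~ x
    if q == 1:
        return x + r          # Case 3: out ~ TermAdd(x, Concrete(r))
    if r == 0:
        return q * x          # Case 4: out ~ TermMul(Concrete(q), x)
    return q * x + r          # Case 5: full term

# Every branch agrees with the Linear semantics on a grid of samples.
for q in (-2, 0, 1, 3):
    for r in (-1, 0, 5):
        for x in (-4, 0, 7):
            assert materialize_linear(x, q, r) == q * x + r
```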

Concrete >< Materialize

Concrete(k) >< Materialize(out) => out ~ Concrete(k)

LHS:
Concrete(k) ~ wire
Materialize(out) ~ wire
k = wire
out = wire
out = k

RHS:
out = k

EQUIVALENCE:
k = k

Add

Linear >< Add

Linear(x, q, r) >< Add(out, b) => b ~ AddCheckLinear(out, x, q, r)

LHS:
Linear(x, q, r) ~ wire
Add(out, b) ~ wire
q*x + r = wire
out = wire + b
out = q*x + r + b

RHS:
out = q*x + (r + b)

EQUIVALENCE:
q*x + r + b = q*x + (r + b) => q*x + (r + b) = q*x + (r + b)

Concrete >< Add

Concrete(k) >< Add(out, b) => (1), (2)

LHS:
Concrete(k) ~ wire
Add(out, b) ~ wire
k = wire
out = wire + b
out = k + b

Case 1:

k = 0 => out ~ b

RHS:
out = b

EQUIVALENCE:
0 + b = b => b = b

Case 2:

otherwise => b ~ AddCheckConcrete(out, k)

RHS:
out = k + b

EQUIVALENCE:
k + b = k + b

Linear >< AddCheckLinear

Linear(y, s, t) >< AddCheckLinear(out, x, q, r) => (1), (2), (3), (4)

LHS:
Linear(y, s, t) ~ wire
AddCheckLinear(out, x, q, r) ~ wire
s*y + t = wire
out = q*x + (r + wire)
out = q*x + (r + s*y + t)

Case 1:

q,r,s,t = 0 => out ~ Concrete(0), x ~ Eraser, y ~ Eraser

RHS:
out = 0

EQUIVALENCE:
0*x + (0 + 0*y + 0) = 0 => 0 = 0

Case 2:

s,t = 0 => out ~ Linear(x, q, r), y ~ Eraser

RHS:
out = q*x + r

EQUIVALENCE:
q*x + (r + 0*y + 0) = q*x + r => q*x + r = q*x + r

Case 3:

q, r = 0 => out ~ Linear(y, s, t), x ~ Eraser

RHS:
out = s*y + t

EQUIVALENCE:
0*x + (0 + s*y + t) = s*y + t => s*y + t = s*y + t

Case 4:

otherwise => Linear(x, q, r) ~ Materialize(out_x), Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermAdd(out_x, out_y), 1, 0)

RHS:
Linear(x, q, r) ~ wire_1
Materialize(out_x) ~ wire_1
q*x + r = wire_1
out_x = wire_1
Linear(y, s, t) ~ wire_2
Materialize(out_y) ~ wire_2
s*y + t = wire_2
out_y = wire_2
out = 1*TermAdd(out_x, out_y) + 0
Because TermAdd(a, b) is defined as a+b:
out = 1*(q*x + r + s*y + t) + 0

EQUIVALENCE:
q*x + (r + s*y + t) = 1*(q*x + r + s*y + t) + 0 => q*x + r + s*y + t = q*x + r + s*y + t
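
Case 4 rests on the identity q*x + (r + s*y + t) = 1*((q*x + r) + (s*y + t)) + 0, which a grid of sample values confirms (plain Python standing in for z3 terms):

```python
from itertools import product

# Spot-check: materializing both linears and rewrapping the sum in
# Linear(·, 1, 0) preserves the value for all sampled coefficients.
for q, r, s, t, x, y in product((-3, 0, 1, 4), repeat=6):
    lhs = q * x + (r + s * y + t)
    rhs = 1 * ((q * x + r) + (s * y + t)) + 0
    assert lhs == rhs
```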

Concrete >< AddCheckLinear

Concrete(j) >< AddCheckLinear(out, x, q, r) => out ~ Linear(x, q, r + j)

LHS:
Concrete(j) ~ wire
AddCheckLinear(out, x, q, r) ~ wire
j = wire
out = q*x + (r + wire)
out = q*x + (r + j)

RHS:
out = q*x + (r + j)

EQUIVALENCE:
q*x + (r + j) = q*x + (r + j)

Linear >< AddCheckConcrete

Linear(y, s, t) >< AddCheckConcrete(out, k) => out ~ Linear(y, s, t + k)

LHS:
Linear(y, s, t) ~ wire
AddCheckConcrete(out, k) ~ wire
s*y + t = wire
out = k + wire
out = k + s*y + t

RHS:
out = s*y + (t + k)

EQUIVALENCE:
k + s*y + t = s*y + (t + k) => s*y + (t + k) = s*y + (t + k)

Concrete >< AddCheckConcrete

Concrete(j) >< AddCheckConcrete(out, k) => (1), (2)

LHS:
Concrete(j) ~ wire
AddCheckConcrete(out, k) ~ wire
j = wire
out = k + wire
out = k + j

Case 1:

j = 0 => out ~ Concrete(k)

RHS:
out = k

EQUIVALENCE:
k + 0 = k => k = k

Case 2:

otherwise => out ~ Concrete(k + j)

RHS:
out = k + j

EQUIVALENCE:
k + j = k + j

Mul

Linear >< Mul

Linear(x, q, r) >< Mul(out, b) => b ~ MulCheckLinear(out, x, q, r)

LHS:
Linear(x, q, r) ~ wire
Mul(out, b) ~ wire
q*x + r = wire
out = wire * b
out = (q*x + r) * b

RHS:
out = q*b*x + r*b

EQUIVALENCE:
(q*x + r) * b = q*b*x + r*b => q*b*x + r*b = q*b*x + r*b
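
The equivalence is plain distributivity; a quick spot-check over sample values (illustrative):

```python
# Spot-check: (q*x + r) * b == q*b*x + r*b, i.e. multiplication
# distributes over the linear form.
for q, x, r, b in [(2, 3, 4, 5), (-1, 7, 0, 2), (0, 9, -3, 6)]:
    assert (q * x + r) * b == q * b * x + r * b
```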

Concrete >< Mul

Concrete(k) >< Mul(out, b) => (1), (2), (3)

LHS:
Concrete(k) ~ wire
Mul(out, b) ~ wire
k = wire
out = wire * b
out = k * b

Case 1:

k = 0 => out ~ Concrete(0), b ~ Eraser

RHS:
out = 0

EQUIVALENCE:
0 * b = 0 => 0 = 0

Case 2:

k = 1 => out ~ b

RHS:
out = b

EQUIVALENCE:
1 * b = b => b = b

Case 3:

otherwise => b ~ MulCheckConcrete(out, k)

RHS:
out = k * b

EQUIVALENCE:
k * b = k * b

Linear >< MulCheckLinear

Linear(y, s, t) >< MulCheckLinear(out, x, q, r) => (1), (2)

LHS:
Linear(y, s, t) ~ wire
MulCheckLinear(out, x, q, r) ~ wire
s*y + t = wire
out = q*wire*x + r*wire
out = q*(s*y + t)*x + r*(s*y + t)

Case 1:

(q,r = 0) or (s,t = 0) => x ~ Eraser, y ~ Eraser, out ~ Concrete(0)

RHS:
out = 0

EQUIVALENCE:
0*(s*y + t)*x + 0*(s*y + t) = 0 => 0 = 0
or
q*(0*y + 0)*x + r*(0*y + 0) = 0 => 0 = 0

Case 2:

otherwise => Linear(x, q, r) ~ Materialize(out_x), Linear(y, s, t) ~ Materialize(out_y), out ~ Linear(TermMul(out_x, out_y), 1, 0)

RHS:
Linear(x, q, r) ~ wire_1
Materialize(out_x) ~ wire_1
q*x + r = wire_1
out_x = wire_1
Linear(y, s, t) ~ wire_2
Materialize(out_y) ~ wire_2
s*y + t = wire_2
out_y = wire_2
out = 1*TermMul(out_x, out_y) + 0
Because TermMul(a, b) is defined as a*b:
out = 1*(q*x + r)*(s*y + t) + 0

EQUIVALENCE:
q*(s*y + t)*x + r*(s*y + t) = 1*(q*x + r)*(s*y + t) + 0 => 
q*(s*y + t)*x + r*(s*y + t) = (q*x + r)*(s*y + t) => 
q*(s*y + t)*x + r*(s*y + t) = q*(s*y + t)*x + r*(s*y + t)
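
Case 2's identity can likewise be spot-checked on a grid of sample values (plain Python standing in for z3 terms):

```python
from itertools import product

# Spot-check: distributing (s*y + t) through q*wire*x + r*wire equals
# the materialized product wrapped in Linear(·, 1, 0).
for q, r, s, t, x, y in product((-2, 0, 1, 3), repeat=6):
    lhs = q * (s * y + t) * x + r * (s * y + t)
    rhs = 1 * ((q * x + r) * (s * y + t)) + 0
    assert lhs == rhs
```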

Concrete >< MulCheckLinear

Concrete(j) >< MulCheckLinear(out, x, q, r) => out ~ Linear(x, q * j, r * j)

LHS:
Concrete(j) ~ wire
MulCheckLinear(out, x, q, r) ~ wire
j = wire
out = q*wire*x + r*wire
out = q*j*x + r*j

RHS:
out = q*j*x + r*j

EQUIVALENCE:
q*j*x + r*j = q*j*x + r*j

Linear >< MulCheckConcrete

Linear(y, s, t) >< MulCheckConcrete(out, k) => out ~ Linear(y, s * k, t * k)

LHS:
Linear(y, s, t) ~ wire
MulCheckConcrete(out, k) ~ wire
s*y + t = wire
out = k * wire
out = k * (s*y + t)

RHS:
out = s*k*y + t*k

EQUIVALENCE:
k * (s*y + t) = s*k*y + t*k => s*k*y + t*k = s*k*y + t*k

Concrete >< MulCheckConcrete

Concrete(j) >< MulCheckConcrete(out, k) => (1), (2), (3)

LHS:
Concrete(j) ~ wire
MulCheckConcrete(out, k) ~ wire
j = wire
out = k * wire
out = k * j

Case 1:

j = 0 => out ~ Concrete(0)

RHS:
out = 0

EQUIVALENCE:
k * 0 = 0 => 0 = 0

Case 2:

j = 1 => out ~ Concrete(k)

RHS:
out = k

EQUIVALENCE:
k * 1 = k => k = k

Case 3:

otherwise => out ~ Concrete(k * j)

RHS:
out = k * j

EQUIVALENCE:
k * j = k * j

ReLU

Linear >< ReLU

Linear(x, q, r) >< ReLU(out) => Linear(x, q, r) ~ Materialize(out_x), out ~ Linear(TermReLU(out_x), 1, 0)

LHS:
Linear(x, q, r) ~ wire
ReLU(out) ~ wire
q*x + r = wire
out = IF wire > 0 THEN wire ELSE 0
out = IF (q*x + r) > 0 THEN (q*x + r) ELSE 0

RHS:
Linear(x, q, r) ~ wire
Materialize(out_x) ~ wire
q*x + r = wire
out_x = wire
out = 1*TermReLU(out_x) + 0
Because TermReLU(x) is defined as z3.If(x > 0, x, 0):
out = 1*(IF (q*x + r) > 0 THEN (q*x + r) ELSE 0) + 0

EQUIVALENCE:
IF (q*x + r) > 0 THEN (q*x + r) ELSE 0 = 1*(IF (q*x + r) > 0 THEN (q*x + r) ELSE 0) + 0 =>
IF (q*x + r) > 0 THEN (q*x + r) ELSE 0 = IF (q*x + r) > 0 THEN (q*x + r) ELSE 0
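
The same equivalence can be spot-checked numerically, with a plain Python conditional standing in for z3.If (illustrative sketch):

```python
# Spot-check: materializing Linear(x, q, r) and wrapping it in TermReLU
# matches the piecewise definition max(0, q*x + r).
def relu(v):
    return v if v > 0 else 0

for q, r, x in [(2, -5, 1), (2, -5, 4), (-1, 0, 3), (0, 7, 9)]:
    out_x = q * x + r             # Materialize the linear part
    out = 1 * relu(out_x) + 0     # out ~ Linear(TermReLU(out_x), 1, 0)
    assert out == max(0, q * x + r)
```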

Concrete >< ReLU

Concrete(k) >< ReLU(out) => (1), (2)

LHS:
Concrete(k) ~ wire
ReLU(out) ~ wire
k = wire
out = IF wire > 0 THEN wire ELSE 0
out = IF k > 0 THEN k ELSE 0

Case 1:

k > 0 => out ~ Concrete(k)

RHS:
out = k

EQUIVALENCE:
IF true THEN k ELSE 0 = k => k = k

Case 2:

k <= 0 => out ~ Concrete(0)

RHS:
out = 0

EQUIVALENCE:
IF false THEN k ELSE 0 = 0 => 0 = 0

Soundness of Reduction

Let IN_0 be the Interaction Net translated from a Neural Network NN. Let IN_n be the state of the net
after n reduction steps. Then for all n in N, [IN_n] = [NN].

Proof by Induction

  • Base Case (n = 0): By the Soundness of Translation, the initial net IN_0 is constructed such that
    its semantics [IN_0] exactly match the mathematical definition of the ONNX nodes in NN.
  • Induction Step (n -> n + 1): Assume [IN_n] = [NN]. If IN_n is in normal form, the proof is complete.
    Otherwise, there exists an active pair A that reduces IN_n to IN_{n+1}.
    By the Soundness of Interaction Rules, every reduction step preserves the mathematical definition,
    so [IN_{n+1}] = [IN_n]. By the inductive hypothesis, [IN_{n+1}] = [NN].

By the principle of mathematical induction, the Interaction Net remains semantically equivalent to the original
Neural Network at every step of the reduction process.

Since Interaction Nets are confluent, the reduced mathematical expression is unique regardless
of the order in which the rules are applied.