diff options
| author | ericmarin <maarin.eric@gmail.com> | 2026-03-10 17:57:36 +0100 |
|---|---|---|
| committer | ericmarin <maarin.eric@gmail.com> | 2026-03-10 18:05:46 +0100 |
| commit | 8619ee7a61bafb8c401087508b886e37779be07b (patch) | |
| tree | 36d318e99883aab10cc429a57c44c1c96ddbb1c1 /prover.py | |
| parent | 0882fc5328127f68a7d79c06d0c7decdee770bb9 (diff) | |
| download | vein-8619ee7a61bafb8c401087508b886e37779be07b.tar.gz vein-8619ee7a61bafb8c401087508b886e37779be07b.zip | |
added scale
Diffstat (limited to '')
| -rw-r--r-- | prover.py | 15 |
1 files changed, 9 insertions, 6 deletions
@@ -1,6 +1,9 @@ import z3 import sys +# Scale used during export to Inpla (must match the 'scale' parameter in the exporter) +SCALE = 1000.0 + syms = {} def Symbolic(id): id = f"x_{id}" @@ -8,7 +11,8 @@ def Symbolic(id): syms[id] = z3.Real(id) return syms[id] -def Concrete(val): return z3.RealVal(val) +def Concrete(val): return z3.RealVal(val) / SCALE + def TermAdd(a, b): return a + b def TermMul(a, b): return a * b def TermReLU(x): return z3.If(x > 0, x, 0) @@ -56,20 +60,19 @@ def epsilon_equivalence(net_a, net_b, epsilon): def argmax_equivalence(net_a, net_b): solver = z3.Solver() - solver.add(z3.IsInt(net_a > 0.5) != z3.IsInt(net_b > 0.5)) + solver.add((net_a > 0.5) != (net_b > 0.5)) result = solver.check() if result == z3.unsat: - print(f"VERIFIED: The networks are argmax equivalent.") + print("VERIFIED: The networks are argmax equivalent (binary).") elif result == z3.sat: - print("FAILED: The networks are different.") + print("FAILED: The networks are classification-different.") print("Counter-example input:") print(solver.model()) else: print("UNKNOWN: Solver could not decide.") - if __name__ == "__main__": lines = [line.strip() for line in sys.stdin if line.strip() and not line.startswith("(")] @@ -81,7 +84,7 @@ if __name__ == "__main__": net_a_str = lines[-2] net_b_str = lines[-1] - print(f"Comparing:\nA: {net_a_str}\nB: {net_b_str}") + print(f"Comparing:\nA: {net_a_str}\n\nB: {net_b_str}") net_a = eval(net_a_str, context) net_b = eval(net_b_str, context) |
