aboutsummaryrefslogtreecommitdiff
path: root/prover.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--prover.py14
1 files changed, 6 insertions, 8 deletions
diff --git a/prover.py b/prover.py
index d6d1fcc..614cdb6 100644
--- a/prover.py
+++ b/prover.py
@@ -1,17 +1,14 @@
import z3
+import re
import sys
-# Scale used during export to Inpla (must match the 'scale' parameter in the exporter)
-SCALE = 1000.0
-
syms = {}
def Symbolic(id):
- id = f"x_{id}"
if id not in syms:
syms[id] = z3.Real(id)
return syms[id]
-def Concrete(val): return z3.RealVal(val) / SCALE
+def Concrete(val): return z3.RealVal(val)
def TermAdd(a, b): return a + b
def TermMul(a, b): return a * b
@@ -90,8 +87,9 @@ if __name__ == "__main__":
sys.exit(1)
try:
- net_a_str = lines[-2]
- net_b_str = lines[-1]
+ wrap = re.compile(r"Symbolic\((.*?)\)")
+ net_a_str = wrap.sub(r'Symbolic("\1")', lines[-2]);
+ net_b_str = wrap.sub(r'Symbolic("\1")', lines[-1]);
print(f"Comparing:\nA: {net_a_str}\n\nB: {net_b_str}")
@@ -101,7 +99,7 @@ if __name__ == "__main__":
print("\nStrict Equivalence")
equivalence(net_a, net_b)
print("\nEpsilon-Equivalence")
- epsilon_equivalence(net_a, net_b, 1e-2)
+ epsilon_equivalence(net_a, net_b, 1e-1)
print("\nARGMAX Equivalence")
argmax_equivalence(net_a, net_b)