summary refs log tree commit diff stats
path: root/compiler/semfold.nim
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/semfold.nim')
-rw-r--r--compiler/semfold.nim96
1 files changed, 70 insertions, 26 deletions
diff --git a/compiler/semfold.nim b/compiler/semfold.nim
index 6fcc9a0a4..c4d79a4a3 100644
--- a/compiler/semfold.nim
+++ b/compiler/semfold.nim
@@ -13,7 +13,7 @@
 import
   strutils, options, ast, astalgo, trees, treetab, nimsets, times,
   nversion, platform, math, msgs, os, condsyms, idents, renderer, types,
-  commands, magicsys, saturate
+  commands, magicsys
 
 proc getConstExpr*(m: PSym, n: PNode): PNode
   # evaluates the constant expression or returns nil if it is no constant
@@ -24,6 +24,63 @@ proc newIntNodeT*(intVal: BiggestInt, n: PNode): PNode
 proc newFloatNodeT(floatVal: BiggestFloat, n: PNode): PNode
 proc newStrNodeT*(strVal: string, n: PNode): PNode
 
+proc checkInRange(n: PNode, res: BiggestInt): bool =
+  if res in firstOrd(n.typ)..lastOrd(n.typ):
+    result = true
+
+proc foldAdd(a, b: BiggestInt, n: PNode): PNode =
+  let res = a +% b
+  if ((res xor a) >= 0'i64 or (res xor b) >= 0'i64) and
+      checkInRange(n, res):
+    result = newIntNodeT(res, n)     
+
+proc foldSub*(a, b: BiggestInt, n: PNode): PNode =
+  let res = a -% b
+  if ((res xor a) >= 0'i64 or (res xor not b) >= 0'i64) and
+      checkInRange(n, res):
+    result = newIntNodeT(res, n)
+
+proc foldAbs*(a: BiggestInt, n: PNode): PNode =
+  if a != firstOrd(n.typ):
+    result = newIntNodeT(a, n)
+  
+proc foldMod*(a, b: BiggestInt, n: PNode): PNode =
+  if b != 0'i64:
+    result = newIntNodeT(a mod b, n)
+
+proc foldModU*(a, b: BiggestInt, n: PNode): PNode =
+  if b != 0'i64:
+    result = newIntNodeT(a %% b, n)
+
+proc foldDiv*(a, b: BiggestInt, n: PNode): PNode =
+  if b != 0'i64 and (a != firstOrd(n.typ) or b != -1'i64):
+    result = newIntNodeT(a div b, n)
+
+proc foldDivU*(a, b: BiggestInt, n: PNode): PNode =
+  if b != 0'i64:
+    result = newIntNodeT(a /% b, n)
+
+proc foldMul*(a, b: BiggestInt, n: PNode): PNode =
+  let res = a *% b
+  let floatProd = toBiggestFloat(a) * toBiggestFloat(b)
+  let resAsFloat = toBiggestFloat(res)
+
+  # Fast path for normal case: small multiplicands, and no info
+  # is lost in either method.
+  if resAsFloat == floatProd and checkInRange(n, res):
+    return newIntNodeT(res, n)
+
+  # Somebody somewhere lost info. Close enough, or way off? Note
+  # that a != 0 and b != 0 (else resAsFloat == floatProd == 0).
+  # The difference either is or isn't significant compared to the
+  # true value (of which floatProd is a good approximation).
+
+  # abs(diff)/abs(prod) <= 1/32 iff
+  #   32 * abs(diff) <= abs(prod) -- 5 good bits is "close enough"
+  if 32.0 * abs(resAsFloat - floatProd) <= abs(floatProd) and
+      checkInRange(n, res):
+    return newIntNodeT(res, n)
+
 # implementation
 
 proc newIntNodeT(intVal: BiggestInt, n: PNode): PNode =
@@ -172,23 +229,22 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode =
   of mUnaryPlusI, mUnaryPlusF64: result = a # throw `+` away
   of mToFloat, mToBiggestFloat:
     result = newFloatNodeT(toFloat(int(getInt(a))), n)
+  # XXX: Hides overflow/underflow
   of mToInt, mToBiggestInt: result = newIntNodeT(system.toInt(getFloat(a)), n)
   of mAbsF64: result = newFloatNodeT(abs(getFloat(a)), n)
-  of mAbsI:
-    if getInt(a) >= 0: result = a
-    else: result = newIntNodeT(- getInt(a), n)
+  of mAbsI: result = foldAbs(getInt(a), n)
   of mZe8ToI, mZe8ToI64, mZe16ToI, mZe16ToI64, mZe32ToI64, mZeIToI64:
     # byte(-128) = 1...1..1000_0000'64 --> 0...0..1000_0000'64
     result = newIntNodeT(getInt(a) and (`shl`(1, getSize(a.typ) * 8) - 1), n)
   of mToU8: result = newIntNodeT(getInt(a) and 0x000000FF, n)
   of mToU16: result = newIntNodeT(getInt(a) and 0x0000FFFF, n)
   of mToU32: result = newIntNodeT(getInt(a) and 0x00000000FFFFFFFF'i64, n)
-  of mUnaryLt: result = newIntNodeT(getOrdValue(a) |-| 1, n)
-  of mSucc: result = newIntNodeT(getOrdValue(a) |+| getInt(b), n)
-  of mPred: result = newIntNodeT(getOrdValue(a) |-| getInt(b), n)
-  of mAddI: result = newIntNodeT(getInt(a) |+| getInt(b), n)
-  of mSubI: result = newIntNodeT(getInt(a) |-| getInt(b), n)
-  of mMulI: result = newIntNodeT(getInt(a) |*| getInt(b), n)
+  of mUnaryLt: result = foldSub(getOrdValue(a), 1, n)
+  of mSucc: result = foldAdd(getOrdValue(a), getInt(b), n)
+  of mPred: result = foldSub(getOrdValue(a), getInt(b), n)
+  of mAddI: result = foldAdd(getInt(a), getInt(b), n)
+  of mSubI: result = foldSub(getInt(a), getInt(b), n)
+  of mMulI: result = foldMul(getInt(a), getInt(b), n)
   of mMinI:
     if getInt(a) > getInt(b): result = newIntNodeT(getInt(b), n)
     else: result = newIntNodeT(getInt(a), n)
@@ -211,14 +267,8 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode =
     of tyInt64, tyInt, tyUInt..tyUInt64:
       result = newIntNodeT(`shr`(getInt(a), getInt(b)), n)
     else: internalError(n.info, "constant folding for shr")
-  of mDivI:
-    let y = getInt(b)
-    if y != 0:
-      result = newIntNodeT(`|div|`(getInt(a), y), n)
-  of mModI:
-    let y = getInt(b)
-    if y != 0:
-      result = newIntNodeT(`|mod|`(getInt(a), y), n)
+  of mDivI: result = foldDiv(getInt(a), getInt(b), n)
+  of mModI: result = foldMod(getInt(a), getInt(b), n)
   of mAddF64: result = newFloatNodeT(getFloat(a) + getFloat(b), n)
   of mSubF64: result = newFloatNodeT(getFloat(a) - getFloat(b), n)
   of mMulF64: result = newFloatNodeT(getFloat(a) * getFloat(b), n)
@@ -258,14 +308,8 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode =
   of mAddU: result = newIntNodeT(`+%`(getInt(a), getInt(b)), n)
   of mSubU: result = newIntNodeT(`-%`(getInt(a), getInt(b)), n)
   of mMulU: result = newIntNodeT(`*%`(getInt(a), getInt(b)), n)
-  of mModU:
-    let y = getInt(b)
-    if y != 0:
-      result = newIntNodeT(`%%`(getInt(a), y), n)
-  of mDivU:
-    let y = getInt(b)
-    if y != 0:
-      result = newIntNodeT(`/%`(getInt(a), y), n)
+  of mModU: result = foldModU(getInt(a), getInt(b), n)
+  of mDivU: result = foldDivU(getInt(a), getInt(b), n)
   of mLeSet: result = newIntNodeT(ord(containsSets(a, b)), n)
   of mEqSet: result = newIntNodeT(ord(equalSets(a, b)), n)
   of mLtSet: