summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--compiler/semfold.nim96
-rw-r--r--compiler/types.nim12
-rw-r--r--tests/misc/tsemfold.nim23
-rw-r--r--tests/typerel/ttypelessemptyset.nim2
4 files changed, 100 insertions, 33 deletions
diff --git a/compiler/semfold.nim b/compiler/semfold.nim
index 6fcc9a0a4..c4d79a4a3 100644
--- a/compiler/semfold.nim
+++ b/compiler/semfold.nim
@@ -13,7 +13,7 @@
 import
   strutils, options, ast, astalgo, trees, treetab, nimsets, times,
   nversion, platform, math, msgs, os, condsyms, idents, renderer, types,
-  commands, magicsys, saturate
+  commands, magicsys
 
 proc getConstExpr*(m: PSym, n: PNode): PNode
   # evaluates the constant expression or returns nil if it is no constant
@@ -24,6 +24,63 @@ proc newIntNodeT*(intVal: BiggestInt, n: PNode): PNode
 proc newFloatNodeT(floatVal: BiggestFloat, n: PNode): PNode
 proc newStrNodeT*(strVal: string, n: PNode): PNode
 
+proc checkInRange(n: PNode, res: BiggestInt): bool =
+  if res in firstOrd(n.typ)..lastOrd(n.typ):
+    result = true
+
+proc foldAdd(a, b: BiggestInt, n: PNode): PNode =
+  let res = a +% b
+  if ((res xor a) >= 0'i64 or (res xor b) >= 0'i64) and
+      checkInRange(n, res):
+    result = newIntNodeT(res, n)     
+
+proc foldSub*(a, b: BiggestInt, n: PNode): PNode =
+  let res = a -% b
+  if ((res xor a) >= 0'i64 or (res xor not b) >= 0'i64) and
+      checkInRange(n, res):
+    result = newIntNodeT(res, n)
+
+proc foldAbs*(a: BiggestInt, n: PNode): PNode =
+  if a != firstOrd(n.typ):
+    result = newIntNodeT(a, n)
+  
+proc foldMod*(a, b: BiggestInt, n: PNode): PNode =
+  if b != 0'i64:
+    result = newIntNodeT(a mod b, n)
+
+proc foldModU*(a, b: BiggestInt, n: PNode): PNode =
+  if b != 0'i64:
+    result = newIntNodeT(a %% b, n)
+
+proc foldDiv*(a, b: BiggestInt, n: PNode): PNode =
+  if b != 0'i64 and (a != firstOrd(n.typ) or b != -1'i64):
+    result = newIntNodeT(a div b, n)
+
+proc foldDivU*(a, b: BiggestInt, n: PNode): PNode =
+  if b != 0'i64:
+    result = newIntNodeT(a /% b, n)
+
+proc foldMul*(a, b: BiggestInt, n: PNode): PNode =
+  let res = a *% b
+  let floatProd = toBiggestFloat(a) * toBiggestFloat(b)
+  let resAsFloat = toBiggestFloat(res)
+
+  # Fast path for normal case: small multiplicands, and no info
+  # is lost in either method.
+  if resAsFloat == floatProd and checkInRange(n, res):
+    return newIntNodeT(res, n)
+
+  # Somebody somewhere lost info. Close enough, or way off? Note
+  # that a != 0 and b != 0 (else resAsFloat == floatProd == 0).
+  # The difference either is or isn't significant compared to the
+  # true value (of which floatProd is a good approximation).
+
+  # abs(diff)/abs(prod) <= 1/32 iff
+  #   32 * abs(diff) <= abs(prod) -- 5 good bits is "close enough"
+  if 32.0 * abs(resAsFloat - floatProd) <= abs(floatProd) and
+      checkInRange(n, res):
+    return newIntNodeT(res, n)
+
 # implementation
 
 proc newIntNodeT(intVal: BiggestInt, n: PNode): PNode =
@@ -172,23 +229,22 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode =
   of mUnaryPlusI, mUnaryPlusF64: result = a # throw `+` away
   of mToFloat, mToBiggestFloat:
     result = newFloatNodeT(toFloat(int(getInt(a))), n)
+  # XXX: Hides overflow/underflow
   of mToInt, mToBiggestInt: result = newIntNodeT(system.toInt(getFloat(a)), n)
   of mAbsF64: result = newFloatNodeT(abs(getFloat(a)), n)
-  of mAbsI:
-    if getInt(a) >= 0: result = a
-    else: result = newIntNodeT(- getInt(a), n)
+  of mAbsI: result = foldAbs(getInt(a), n)
   of mZe8ToI, mZe8ToI64, mZe16ToI, mZe16ToI64, mZe32ToI64, mZeIToI64:
     # byte(-128) = 1...1..1000_0000'64 --> 0...0..1000_0000'64
     result = newIntNodeT(getInt(a) and (`shl`(1, getSize(a.typ) * 8) - 1), n)
   of mToU8: result = newIntNodeT(getInt(a) and 0x000000FF, n)
   of mToU16: result = newIntNodeT(getInt(a) and 0x0000FFFF, n)
   of mToU32: result = newIntNodeT(getInt(a) and 0x00000000FFFFFFFF'i64, n)
-  of mUnaryLt: result = newIntNodeT(getOrdValue(a) |-| 1, n)
-  of mSucc: result = newIntNodeT(getOrdValue(a) |+| getInt(b), n)
-  of mPred: result = newIntNodeT(getOrdValue(a) |-| getInt(b), n)
-  of mAddI: result = newIntNodeT(getInt(a) |+| getInt(b), n)
-  of mSubI: result = newIntNodeT(getInt(a) |-| getInt(b), n)
-  of mMulI: result = newIntNodeT(getInt(a) |*| getInt(b), n)
+  of mUnaryLt: result = foldSub(getOrdValue(a), 1, n)
+  of mSucc: result = foldAdd(getOrdValue(a), getInt(b), n)
+  of mPred: result = foldSub(getOrdValue(a), getInt(b), n)
+  of mAddI: result = foldAdd(getInt(a), getInt(b), n)
+  of mSubI: result = foldSub(getInt(a), getInt(b), n)
+  of mMulI: result = foldMul(getInt(a), getInt(b), n)
   of mMinI:
     if getInt(a) > getInt(b): result = newIntNodeT(getInt(b), n)
     else: result = newIntNodeT(getInt(a), n)
@@ -211,14 +267,8 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode =
     of tyInt64, tyInt, tyUInt..tyUInt64:
       result = newIntNodeT(`shr`(getInt(a), getInt(b)), n)
     else: internalError(n.info, "constant folding for shr")
-  of mDivI:
-    let y = getInt(b)
-    if y != 0:
-      result = newIntNodeT(`|div|`(getInt(a), y), n)
-  of mModI:
-    let y = getInt(b)
-    if y != 0:
-      result = newIntNodeT(`|mod|`(getInt(a), y), n)
+  of mDivI: result = foldDiv(getInt(a), getInt(b), n)
+  of mModI: result = foldMod(getInt(a), getInt(b), n)
   of mAddF64: result = newFloatNodeT(getFloat(a) + getFloat(b), n)
   of mSubF64: result = newFloatNodeT(getFloat(a) - getFloat(b), n)
   of mMulF64: result = newFloatNodeT(getFloat(a) * getFloat(b), n)
@@ -258,14 +308,8 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode =
   of mAddU: result = newIntNodeT(`+%`(getInt(a), getInt(b)), n)
   of mSubU: result = newIntNodeT(`-%`(getInt(a), getInt(b)), n)
   of mMulU: result = newIntNodeT(`*%`(getInt(a), getInt(b)), n)
-  of mModU:
-    let y = getInt(b)
-    if y != 0:
-      result = newIntNodeT(`%%`(getInt(a), y), n)
-  of mDivU:
-    let y = getInt(b)
-    if y != 0:
-      result = newIntNodeT(`/%`(getInt(a), y), n)
+  of mModU: result = foldModU(getInt(a), getInt(b), n)
+  of mDivU: result = foldDivU(getInt(a), getInt(b), n)
   of mLeSet: result = newIntNodeT(ord(containsSets(a, b)), n)
   of mEqSet: result = newIntNodeT(ord(equalSets(a, b)), n)
   of mLtSet:
diff --git a/compiler/types.nim b/compiler/types.nim
index 5d60fa9b4..edf4e5b46 100644
--- a/compiler/types.nim
+++ b/compiler/types.nim
@@ -612,13 +612,13 @@ proc firstOrd*(t: PType): BiggestInt =
     else:
       assert(t.n.sons[0].kind == nkSym)
       result = t.n.sons[0].sym.position
-  of tyGenericInst, tyDistinct, tyTypeDesc, tyAlias:
+  of tyGenericInst, tyDistinct, tyTypeDesc, tyAlias, tyStatic:
     result = firstOrd(lastSon(t))
   of tyOrdinal:
     if t.len > 0: result = firstOrd(lastSon(t))
-    else: internalError("invalid kind for first(" & $t.kind & ')')
+    else: internalError("invalid kind for firstOrd(" & $t.kind & ')')
   else:
-    internalError("invalid kind for first(" & $t.kind & ')')
+    internalError("invalid kind for firstOrd(" & $t.kind & ')')
     result = 0
 
 proc lastOrd*(t: PType; fixedUnsigned = false): BiggestInt =
@@ -651,14 +651,14 @@ proc lastOrd*(t: PType; fixedUnsigned = false): BiggestInt =
   of tyEnum:
     assert(t.n.sons[sonsLen(t.n) - 1].kind == nkSym)
     result = t.n.sons[sonsLen(t.n) - 1].sym.position
-  of tyGenericInst, tyDistinct, tyTypeDesc, tyAlias:
+  of tyGenericInst, tyDistinct, tyTypeDesc, tyAlias, tyStatic:
     result = lastOrd(lastSon(t))
   of tyProxy: result = 0
   of tyOrdinal:
     if t.len > 0: result = lastOrd(lastSon(t))
-    else: internalError("invalid kind for last(" & $t.kind & ')')
+    else: internalError("invalid kind for lastOrd(" & $t.kind & ')')
   else:
-    internalError("invalid kind for last(" & $t.kind & ')')
+    internalError("invalid kind for lastOrd(" & $t.kind & ')')
     result = 0
 
 proc lengthOrd*(t: PType): BiggestInt =
diff --git a/tests/misc/tsemfold.nim b/tests/misc/tsemfold.nim
new file mode 100644
index 000000000..18c282d9e
--- /dev/null
+++ b/tests/misc/tsemfold.nim
@@ -0,0 +1,23 @@
+discard """
+  action: run
+"""
+
+doAssertRaises(OverflowError): discard low(int8) - 1'i8
+doAssertRaises(OverflowError): discard high(int8) + 1'i8
+doAssertRaises(OverflowError): discard abs(low(int8))
+doAssertRaises(DivByZeroError): discard 1 mod 0
+doAssertRaises(DivByZeroError): discard 1 div 0
+doAssertRaises(OverflowError): discard low(int8) div -1'i8
+
+doAssertRaises(OverflowError): discard low(int64) - 1'i64
+doAssertRaises(OverflowError): discard high(int64) + 1'i64
+
+type E = enum eA, eB
+doAssertRaises(OverflowError): discard eA.pred
+doAssertRaises(OverflowError): discard eB.succ
+
+doAssertRaises(OverflowError): discard low(int8) * -1
+doAssertRaises(OverflowError): discard low(int64) * -1
+doAssertRaises(OverflowError): discard high(int8) * 2
+doAssertRaises(OverflowError): discard high(int64) * 2
+
diff --git a/tests/typerel/ttypelessemptyset.nim b/tests/typerel/ttypelessemptyset.nim
index 3e171387b..5f49c33fd 100644
--- a/tests/typerel/ttypelessemptyset.nim
+++ b/tests/typerel/ttypelessemptyset.nim
@@ -1,5 +1,5 @@
 discard """
-  errormsg: "internal error: invalid kind for last(tyEmpty)"
+  errormsg: "internal error: invalid kind for lastOrd(tyEmpty)"
 """
 var q = false
 discard (if q: {} else: {})