summary refs log tree commit diff stats
path: root/compiler/semfold.nim
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/semfold.nim')
-rw-r--r--compiler/semfold.nim119
1 files changed, 65 insertions, 54 deletions
diff --git a/compiler/semfold.nim b/compiler/semfold.nim
index 2e7179673..5fe4e3299 100644
--- a/compiler/semfold.nim
+++ b/compiler/semfold.nim
@@ -118,6 +118,9 @@ proc makeRange(typ: PType, first, last: BiggestInt): PType =
   let lowerNode = newIntNode(nkIntLit, minA)
   if typ.kind == tyInt and minA == maxA:
     result = getIntLitType(lowerNode)
+  elif typ.kind in {tyUint, tyUInt64}:
+    # these are not ordinal types, so you get no subrange type for these:
+    result = typ
   else:
     var n = newNode(nkRange)
     addSon(n, lowerNode)
@@ -135,8 +138,9 @@ proc makeRangeF(typ: PType, first, last: BiggestFloat): PType =
   addSonSkipIntLit(result, skipTypes(typ, {tyRange}))
 
 proc getIntervalType*(m: TMagic, n: PNode): PType =
-  # Nimrod requires interval arithmetic for ``range`` types. Lots of tedious
+  # Nim requires interval arithmetic for ``range`` types. Lots of tedious
   # work but the feature is very nice for reducing explicit conversions.
+  const ordIntLit = {nkIntLit..nkUInt64Lit}
   result = n.typ
 
   template commutativeOp(opr: expr) {.immediate.} =
@@ -170,13 +174,19 @@ proc getIntervalType*(m: TMagic, n: PNode): PType =
     let a = n.sons[1].typ
     if isFloatRange(a):
       # abs(-5.. 1) == (1..5)
-      result = makeRangeF(a, abs(getFloat(a.n.sons[1])),
-                             abs(getFloat(a.n.sons[0])))
-  of mAbsI, mAbsI64:
+      if a.n[0].floatVal <= 0.0:
+        result = makeRangeF(a, 0.0, abs(getFloat(a.n.sons[0])))
+      else:
+        result = makeRangeF(a, abs(getFloat(a.n.sons[1])),
+                               abs(getFloat(a.n.sons[0])))
+  of mAbsI:
     let a = n.sons[1].typ
     if isIntRange(a):
-      result = makeRange(a, `|abs|`(getInt(a.n.sons[1])),
-                            `|abs|`(getInt(a.n.sons[0])))
+      if a.n[0].intVal <= 0:
+        result = makeRange(a, 0, `|abs|`(getInt(a.n.sons[0])))
+      else:
+        result = makeRange(a, `|abs|`(getInt(a.n.sons[1])),
+                              `|abs|`(getInt(a.n.sons[0])))
   of mSucc:
     let a = n.sons[1].typ
     let b = n.sons[2].typ
@@ -190,32 +200,32 @@ proc getIntervalType*(m: TMagic, n: PNode): PType =
     if isIntRange(a) and isIntLit(b):
       result = makeRange(a, pickMinInt(n.sons[1]) |-| pickMinInt(n.sons[2]),
                             pickMaxInt(n.sons[1]) |-| pickMaxInt(n.sons[2]))
-  of mAddI, mAddI64, mAddU:
+  of mAddI, mAddU:
     commutativeOp(`|+|`)
-  of mMulI, mMulI64, mMulU:
+  of mMulI, mMulU:
     commutativeOp(`|*|`)
-  of mSubI, mSubI64, mSubU:
+  of mSubI, mSubU:
     binaryOp(`|-|`)
-  of mBitandI, mBitandI64:
+  of mBitandI:
     # since uint64 is still not even valid for 'range' (since it's no ordinal
     # yet), we exclude it from the list (see bug #1638) for now:
     var a = n.sons[1]
     var b = n.sons[2]
     # symmetrical:
-    if b.kind notin {nkIntLit..nkUInt32Lit}: swap(a, b)
-    if b.kind in {nkIntLit..nkUInt32Lit}:
+    if b.kind notin ordIntLit: swap(a, b)
+    if b.kind in ordIntLit:
       let x = b.intVal|+|1
       if (x and -x) == x and x >= 0:
         result = makeRange(a.typ, 0, b.intVal)
   of mModU:
     let a = n.sons[1]
     let b = n.sons[2]
-    if b.kind in {nkIntLit..nkUInt32Lit}:
+    if a.kind in ordIntLit:
       if b.intVal >= 0:
         result = makeRange(a.typ, 0, b.intVal-1)
       else:
         result = makeRange(a.typ, b.intVal+1, 0)
-  of mModI, mModI64:
+  of mModI:
     # so ... if you ever wondered about modulo's signedness; this defines it:
     let a = n.sons[1]
     let b = n.sons[2]
@@ -224,17 +234,17 @@ proc getIntervalType*(m: TMagic, n: PNode): PType =
         result = makeRange(a.typ, -(b.intVal-1), b.intVal-1)
       else:
         result = makeRange(a.typ, b.intVal+1, -(b.intVal+1))
-  of mDivI, mDivI64, mDivU:
+  of mDivI, mDivU:
     binaryOp(`|div|`)
-  of mMinI, mMinI64:
+  of mMinI:
     commutativeOp(min)
-  of mMaxI, mMaxI64:
+  of mMaxI:
     commutativeOp(max)
   else: discard
 
 discard """
-  mShlI, mShlI64,
-  mShrI, mShrI64, mAddF64, mSubF64, mMulF64, mDivF64, mMaxF64, mMinF64
+  mShlI,
+  mShrI, mAddF64, mSubF64, mMulF64, mDivF64, mMaxF64, mMinF64
 """
 
 proc evalIs(n, a: PNode): PNode =
@@ -275,16 +285,20 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode =
   of mUnaryMinusF64: result = newFloatNodeT(- getFloat(a), n)
   of mNot: result = newIntNodeT(1 - getInt(a), n)
   of mCard: result = newIntNodeT(nimsets.cardSet(a), n)
-  of mBitnotI, mBitnotI64: result = newIntNodeT(not getInt(a), n)
-  of mLengthStr: result = newIntNodeT(len(getStr(a)), n)
+  of mBitnotI: result = newIntNodeT(not getInt(a), n)
+  of mLengthStr, mXLenStr:
+    if a.kind == nkNilLit: result = newIntNodeT(0, n)
+    else: result = newIntNodeT(len(getStr(a)), n)
   of mLengthArray: result = newIntNodeT(lengthOrd(a.typ), n)
-  of mLengthSeq, mLengthOpenArray: result = newIntNodeT(sonsLen(a), n) # BUGFIX
-  of mUnaryPlusI, mUnaryPlusI64, mUnaryPlusF64: result = a # throw `+` away
+  of mLengthSeq, mLengthOpenArray, mXLenSeq:
+    if a.kind == nkNilLit: result = newIntNodeT(0, n)
+    else: result = newIntNodeT(sonsLen(a), n) # BUGFIX
+  of mUnaryPlusI, mUnaryPlusF64: result = a # throw `+` away
   of mToFloat, mToBiggestFloat:
     result = newFloatNodeT(toFloat(int(getInt(a))), n)
   of mToInt, mToBiggestInt: result = newIntNodeT(system.toInt(getFloat(a)), n)
   of mAbsF64: result = newFloatNodeT(abs(getFloat(a)), n)
-  of mAbsI, mAbsI64:
+  of mAbsI:
     if getInt(a) >= 0: result = a
     else: result = newIntNodeT(- getInt(a), n)
   of mZe8ToI, mZe8ToI64, mZe16ToI, mZe16ToI64, mZe32ToI64, mZeIToI64:
@@ -293,19 +307,19 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode =
   of mToU8: result = newIntNodeT(getInt(a) and 0x000000FF, n)
   of mToU16: result = newIntNodeT(getInt(a) and 0x0000FFFF, n)
   of mToU32: result = newIntNodeT(getInt(a) and 0x00000000FFFFFFFF'i64, n)
-  of mUnaryLt: result = newIntNodeT(getOrdValue(a) - 1, n)
-  of mSucc: result = newIntNodeT(getOrdValue(a) + getInt(b), n)
-  of mPred: result = newIntNodeT(getOrdValue(a) - getInt(b), n)
-  of mAddI, mAddI64: result = newIntNodeT(getInt(a) + getInt(b), n)
-  of mSubI, mSubI64: result = newIntNodeT(getInt(a) - getInt(b), n)
-  of mMulI, mMulI64: result = newIntNodeT(getInt(a) * getInt(b), n)
-  of mMinI, mMinI64:
+  of mUnaryLt: result = newIntNodeT(getOrdValue(a) |-| 1, n)
+  of mSucc: result = newIntNodeT(getOrdValue(a) |+| getInt(b), n)
+  of mPred: result = newIntNodeT(getOrdValue(a) |-| getInt(b), n)
+  of mAddI: result = newIntNodeT(getInt(a) |+| getInt(b), n)
+  of mSubI: result = newIntNodeT(getInt(a) |-| getInt(b), n)
+  of mMulI: result = newIntNodeT(getInt(a) |*| getInt(b), n)
+  of mMinI:
     if getInt(a) > getInt(b): result = newIntNodeT(getInt(b), n)
     else: result = newIntNodeT(getInt(a), n)
-  of mMaxI, mMaxI64:
+  of mMaxI:
     if getInt(a) > getInt(b): result = newIntNodeT(getInt(a), n)
     else: result = newIntNodeT(getInt(b), n)
-  of mShlI, mShlI64:
+  of mShlI:
     case skipTypes(n.typ, abstractRange).kind
     of tyInt8: result = newIntNodeT(int8(getInt(a)) shl int8(getInt(b)), n)
     of tyInt16: result = newIntNodeT(int16(getInt(a)) shl int16(getInt(b)), n)
@@ -313,7 +327,7 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode =
     of tyInt64, tyInt, tyUInt..tyUInt64:
       result = newIntNodeT(`shl`(getInt(a), getInt(b)), n)
     else: internalError(n.info, "constant folding for shl")
-  of mShrI, mShrI64:
+  of mShrI:
     case skipTypes(n.typ, abstractRange).kind
     of tyInt8: result = newIntNodeT(int8(getInt(a)) shr int8(getInt(b)), n)
     of tyInt16: result = newIntNodeT(int16(getInt(a)) shr int16(getInt(b)), n)
@@ -321,14 +335,14 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode =
     of tyInt64, tyInt, tyUInt..tyUInt64:
       result = newIntNodeT(`shr`(getInt(a), getInt(b)), n)
     else: internalError(n.info, "constant folding for shr")
-  of mDivI, mDivI64:
+  of mDivI:
     let y = getInt(b)
     if y != 0:
-      result = newIntNodeT(getInt(a) div y, n)
-  of mModI, mModI64:
+      result = newIntNodeT(`|div|`(getInt(a), y), n)
+  of mModI:
     let y = getInt(b)
     if y != 0:
-      result = newIntNodeT(getInt(a) mod y, n)
+      result = newIntNodeT(`|mod|`(getInt(a), y), n)
   of mAddF64: result = newFloatNodeT(getFloat(a) + getFloat(b), n)
   of mSubF64: result = newFloatNodeT(getFloat(a) - getFloat(b), n)
   of mMulF64: result = newFloatNodeT(getFloat(a) * getFloat(b), n)
@@ -345,11 +359,11 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode =
     if getFloat(a) > getFloat(b): result = newFloatNodeT(getFloat(b), n)
     else: result = newFloatNodeT(getFloat(a), n)
   of mIsNil: result = newIntNodeT(ord(a.kind == nkNilLit), n)
-  of mLtI, mLtI64, mLtB, mLtEnum, mLtCh:
+  of mLtI, mLtB, mLtEnum, mLtCh:
     result = newIntNodeT(ord(getOrdValue(a) < getOrdValue(b)), n)
-  of mLeI, mLeI64, mLeB, mLeEnum, mLeCh:
+  of mLeI, mLeB, mLeEnum, mLeCh:
     result = newIntNodeT(ord(getOrdValue(a) <= getOrdValue(b)), n)
-  of mEqI, mEqI64, mEqB, mEqEnum, mEqCh:
+  of mEqI, mEqB, mEqEnum, mEqCh:
     result = newIntNodeT(ord(getOrdValue(a) == getOrdValue(b)), n)
   of mLtF64: result = newIntNodeT(ord(getFloat(a) < getFloat(b)), n)
   of mLeF64: result = newIntNodeT(ord(getFloat(a) <= getFloat(b)), n)
@@ -361,9 +375,9 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode =
     result = newIntNodeT(ord(`<%`(getOrdValue(a), getOrdValue(b))), n)
   of mLeU, mLeU64:
     result = newIntNodeT(ord(`<=%`(getOrdValue(a), getOrdValue(b))), n)
-  of mBitandI, mBitandI64, mAnd: result = newIntNodeT(a.getInt and b.getInt, n)
-  of mBitorI, mBitorI64, mOr: result = newIntNodeT(getInt(a) or getInt(b), n)
-  of mBitxorI, mBitxorI64, mXor: result = newIntNodeT(a.getInt xor b.getInt, n)
+  of mBitandI, mAnd: result = newIntNodeT(a.getInt and b.getInt, n)
+  of mBitorI, mOr: result = newIntNodeT(getInt(a) or getInt(b), n)
+  of mBitxorI, mXor: result = newIntNodeT(a.getInt xor b.getInt, n)
   of mAddU: result = newIntNodeT(`+%`(getInt(a), getInt(b)), n)
   of mSubU: result = newIntNodeT(`-%`(getInt(a), getInt(b)), n)
   of mMulU: result = newIntNodeT(`*%`(getInt(a), getInt(b)), n)
@@ -416,13 +430,10 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode =
   of mCompileOptionArg:
     result = newIntNodeT(ord(
       testCompileOptionArg(getStr(a), getStr(b), n.info)), n)
-  of mNewString, mNewStringOfCap,
-     mExit, mInc, ast.mDec, mEcho, mSwap, mAppendStrCh,
-     mAppendStrStr, mAppendSeqElem, mSetLengthStr, mSetLengthSeq,
-     mParseExprToAst, mParseStmtToAst, mExpandToAst, mTypeTrait, mDotDot,
-     mNLen..mNError, mEqRef, mSlurp, mStaticExec, mNGenSym, mSpawn, mParallel:
-    discard
-  else: internalError(a.info, "evalOp(" & $m & ')')
+  of mEqProc:
+    result = newIntNodeT(ord(
+        exprStructuralEquivalent(a, b, strictSymEquality=true)), n)
+  else: discard
 
 proc getConstIfExpr(c: PSym, n: PNode): PNode =
   result = nil
@@ -513,7 +524,7 @@ proc rangeCheck(n: PNode, value: BiggestInt) =
 proc foldConv*(n, a: PNode; check = false): PNode =
   # XXX range checks?
   case skipTypes(n.typ, abstractRange).kind
-  of tyInt..tyInt64:
+  of tyInt..tyInt64, tyUInt..tyUInt64:
     case skipTypes(a.typ, abstractRange).kind
     of tyFloat..tyFloat64:
       result = newIntNodeT(int(getFloat(a)), n)
@@ -525,7 +536,7 @@ proc foldConv*(n, a: PNode; check = false): PNode =
   of tyFloat..tyFloat64:
     case skipTypes(a.typ, abstractRange).kind
     of tyInt..tyInt64, tyEnum, tyBool, tyChar:
-      result = newFloatNodeT(toFloat(int(getOrdValue(a))), n)
+      result = newFloatNodeT(toBiggestFloat(getOrdValue(a)), n)
     else:
       result = a
       result.typ = n.typ
@@ -640,7 +651,7 @@ proc getConstExpr(m: PSym, n: PNode): PNode =
     result = copyNode(n)
   of nkIfExpr:
     result = getConstIfExpr(m, n)
-  of nkCall, nkCommand, nkCallStrLit, nkPrefix, nkInfix:
+  of nkCallKinds:
     if n.sons[0].kind != nkSym: return
     var s = n.sons[0].sym
     if s.kind != skProc: return