diff options
Diffstat (limited to 'compiler/semfold.nim')
-rw-r--r-- | compiler/semfold.nim | 90 |
1 files changed, 51 insertions, 39 deletions
diff --git a/compiler/semfold.nim b/compiler/semfold.nim index 52931bc2b..729222220 100644 --- a/compiler/semfold.nim +++ b/compiler/semfold.nim @@ -118,6 +118,9 @@ proc makeRange(typ: PType, first, last: BiggestInt): PType = let lowerNode = newIntNode(nkIntLit, minA) if typ.kind == tyInt and minA == maxA: result = getIntLitType(lowerNode) + elif typ.kind in {tyUint, tyUInt64}: + # these are not ordinal types, so you get no subrange type for these: + result = typ else: var n = newNode(nkRange) addSon(n, lowerNode) @@ -135,8 +138,9 @@ proc makeRangeF(typ: PType, first, last: BiggestFloat): PType = addSonSkipIntLit(result, skipTypes(typ, {tyRange})) proc getIntervalType*(m: TMagic, n: PNode): PType = - # Nimrod requires interval arithmetic for ``range`` types. Lots of tedious + # Nim requires interval arithmetic for ``range`` types. Lots of tedious # work but the feature is very nice for reducing explicit conversions. + const ordIntLit = {nkIntLit..nkUInt64Lit} result = n.typ template commutativeOp(opr: expr) {.immediate.} = @@ -175,7 +179,7 @@ proc getIntervalType*(m: TMagic, n: PNode): PType = else: result = makeRangeF(a, abs(getFloat(a.n.sons[1])), abs(getFloat(a.n.sons[0]))) - of mAbsI, mAbsI64: + of mAbsI: let a = n.sons[1].typ if isIntRange(a): if a.n[0].intVal <= 0: @@ -196,32 +200,32 @@ proc getIntervalType*(m: TMagic, n: PNode): PType = if isIntRange(a) and isIntLit(b): result = makeRange(a, pickMinInt(n.sons[1]) |-| pickMinInt(n.sons[2]), pickMaxInt(n.sons[1]) |-| pickMaxInt(n.sons[2])) - of mAddI, mAddI64, mAddU: + of mAddI, mAddU: commutativeOp(`|+|`) - of mMulI, mMulI64, mMulU: + of mMulI, mMulU: commutativeOp(`|*|`) - of mSubI, mSubI64, mSubU: + of mSubI, mSubU: binaryOp(`|-|`) - of mBitandI, mBitandI64: + of mBitandI: # since uint64 is still not even valid for 'range' (since it's no ordinal # yet), we exclude it from the list (see bug #1638) for now: var a = n.sons[1] var b = n.sons[2] # symmetrical: - if b.kind notin {nkIntLit..nkUInt32Lit}: swap(a, b) - if b.kind in {nkIntLit..nkUInt32Lit}: + if b.kind notin ordIntLit: swap(a, b) + if b.kind in ordIntLit: let x = b.intVal|+|1 if (x and -x) == x and x >= 0: result = makeRange(a.typ, 0, b.intVal) of mModU: let a = n.sons[1] let b = n.sons[2] - if b.kind in {nkIntLit..nkUInt32Lit}: + if a.kind in ordIntLit: if b.intVal >= 0: result = makeRange(a.typ, 0, b.intVal-1) else: result = makeRange(a.typ, b.intVal+1, 0) - of mModI, mModI64: + of mModI: # so ... if you ever wondered about modulo's signedness; this defines it: let a = n.sons[1] let b = n.sons[2] @@ -230,17 +234,17 @@ proc getIntervalType*(m: TMagic, n: PNode): PType = result = makeRange(a.typ, -(b.intVal-1), b.intVal-1) else: result = makeRange(a.typ, b.intVal+1, -(b.intVal+1)) - of mDivI, mDivI64, mDivU: + of mDivI, mDivU: binaryOp(`|div|`) - of mMinI, mMinI64: + of mMinI: commutativeOp(min) - of mMaxI, mMaxI64: + of mMaxI: commutativeOp(max) else: discard discard """ - mShlI, mShlI64, - mShrI, mShrI64, mAddF64, mSubF64, mMulF64, mDivF64, mMaxF64, mMinF64 + mShlI, + mShrI, mAddF64, mSubF64, mMulF64, mDivF64, mMaxF64, mMinF64 """ proc evalIs(n, a: PNode): PNode = @@ -281,16 +285,20 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode = of mUnaryMinusF64: result = newFloatNodeT(- getFloat(a), n) of mNot: result = newIntNodeT(1 - getInt(a), n) of mCard: result = newIntNodeT(nimsets.cardSet(a), n) - of mBitnotI, mBitnotI64: result = newIntNodeT(not getInt(a), n) - of mLengthStr: result = newIntNodeT(len(getStr(a)), n) + of mBitnotI: result = newIntNodeT(not getInt(a), n) + of mLengthStr, mXLenStr: + if a.kind == nkNilLit: result = newIntNodeT(0, n) + else: result = newIntNodeT(len(getStr(a)), n) of mLengthArray: result = newIntNodeT(lengthOrd(a.typ), n) - of mLengthSeq, mLengthOpenArray: result = newIntNodeT(sonsLen(a), n) # BUGFIX - of mUnaryPlusI, mUnaryPlusI64, mUnaryPlusF64: result = a # throw `+` away + of mLengthSeq, mLengthOpenArray, mXLenSeq: + if a.kind == nkNilLit: result = newIntNodeT(0, n) + else: result = newIntNodeT(sonsLen(a), n) # BUGFIX + of mUnaryPlusI, mUnaryPlusF64: result = a # throw `+` away of mToFloat, mToBiggestFloat: result = newFloatNodeT(toFloat(int(getInt(a))), n) of mToInt, mToBiggestInt: result = newIntNodeT(system.toInt(getFloat(a)), n) of mAbsF64: result = newFloatNodeT(abs(getFloat(a)), n) - of mAbsI, mAbsI64: + of mAbsI: if getInt(a) >= 0: result = a else: result = newIntNodeT(- getInt(a), n) of mZe8ToI, mZe8ToI64, mZe16ToI, mZe16ToI64, mZe32ToI64, mZeIToI64: @@ -302,16 +310,16 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode = of mUnaryLt: result = newIntNodeT(getOrdValue(a) - 1, n) of mSucc: result = newIntNodeT(getOrdValue(a) + getInt(b), n) of mPred: result = newIntNodeT(getOrdValue(a) - getInt(b), n) - of mAddI, mAddI64: result = newIntNodeT(getInt(a) + getInt(b), n) - of mSubI, mSubI64: result = newIntNodeT(getInt(a) - getInt(b), n) - of mMulI, mMulI64: result = newIntNodeT(getInt(a) * getInt(b), n) - of mMinI, mMinI64: + of mAddI: result = newIntNodeT(getInt(a) + getInt(b), n) + of mSubI: result = newIntNodeT(getInt(a) - getInt(b), n) + of mMulI: result = newIntNodeT(getInt(a) * getInt(b), n) + of mMinI: if getInt(a) > getInt(b): result = newIntNodeT(getInt(b), n) else: result = newIntNodeT(getInt(a), n) - of mMaxI, mMaxI64: + of mMaxI: if getInt(a) > getInt(b): result = newIntNodeT(getInt(a), n) else: result = newIntNodeT(getInt(b), n) - of mShlI, mShlI64: + of mShlI: case skipTypes(n.typ, abstractRange).kind of tyInt8: result = newIntNodeT(int8(getInt(a)) shl int8(getInt(b)), n) of tyInt16: result = newIntNodeT(int16(getInt(a)) shl int16(getInt(b)), n) @@ -319,7 +327,7 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode = of tyInt64, tyInt, tyUInt..tyUInt64: result = newIntNodeT(`shl`(getInt(a), getInt(b)), n) else: internalError(n.info, "constant folding for shl") - of mShrI, mShrI64: + of mShrI: case skipTypes(n.typ, abstractRange).kind of tyInt8: result = newIntNodeT(int8(getInt(a)) shr int8(getInt(b)), n) of tyInt16: result = newIntNodeT(int16(getInt(a)) shr int16(getInt(b)), n) @@ -327,11 +335,11 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode = of tyInt64, tyInt, tyUInt..tyUInt64: result = newIntNodeT(`shr`(getInt(a), getInt(b)), n) else: internalError(n.info, "constant folding for shr") - of mDivI, mDivI64: + of mDivI: let y = getInt(b) if y != 0: result = newIntNodeT(getInt(a) div y, n) - of mModI, mModI64: + of mModI: let y = getInt(b) if y != 0: result = newIntNodeT(getInt(a) mod y, n) @@ -351,11 +359,11 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode = if getFloat(a) > getFloat(b): result = newFloatNodeT(getFloat(b), n) else: result = newFloatNodeT(getFloat(a), n) of mIsNil: result = newIntNodeT(ord(a.kind == nkNilLit), n) - of mLtI, mLtI64, mLtB, mLtEnum, mLtCh: + of mLtI, mLtB, mLtEnum, mLtCh: result = newIntNodeT(ord(getOrdValue(a) < getOrdValue(b)), n) - of mLeI, mLeI64, mLeB, mLeEnum, mLeCh: + of mLeI, mLeB, mLeEnum, mLeCh: result = newIntNodeT(ord(getOrdValue(a) <= getOrdValue(b)), n) - of mEqI, mEqI64, mEqB, mEqEnum, mEqCh: + of mEqI, mEqB, mEqEnum, mEqCh: result = newIntNodeT(ord(getOrdValue(a) == getOrdValue(b)), n) of mLtF64: result = newIntNodeT(ord(getFloat(a) < getFloat(b)), n) of mLeF64: result = newIntNodeT(ord(getFloat(a) <= getFloat(b)), n) @@ -367,9 +375,9 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode = result = newIntNodeT(ord(`<%`(getOrdValue(a), getOrdValue(b))), n) of mLeU, mLeU64: result = newIntNodeT(ord(`<=%`(getOrdValue(a), getOrdValue(b))), n) - of mBitandI, mBitandI64, mAnd: result = newIntNodeT(a.getInt and b.getInt, n) - of mBitorI, mBitorI64, mOr: result = newIntNodeT(getInt(a) or getInt(b), n) - of mBitxorI, mBitxorI64, mXor: result = newIntNodeT(a.getInt xor b.getInt, n) + of mBitandI, mAnd: result = newIntNodeT(a.getInt and b.getInt, n) + of mBitorI, mOr: result = newIntNodeT(getInt(a) or getInt(b), n) + of mBitxorI, mXor: result = newIntNodeT(a.getInt xor b.getInt, n) of mAddU: result = newIntNodeT(`+%`(getInt(a), getInt(b)), n) of mSubU: result = newIntNodeT(`-%`(getInt(a), getInt(b)), n) of mMulU: result = newIntNodeT(`*%`(getInt(a), getInt(b)), n) @@ -426,8 +434,12 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode = mExit, mInc, ast.mDec, mEcho, mSwap, mAppendStrCh, mAppendStrStr, mAppendSeqElem, mSetLengthStr, mSetLengthSeq, mParseExprToAst, mParseStmtToAst, mExpandToAst, mTypeTrait, mDotDot, - mNLen..mNError, mEqRef, mSlurp, mStaticExec, mNGenSym, mSpawn, mParallel: + mNLen..mNError, mEqRef, mSlurp, mStaticExec, mNGenSym, mSpawn, + mParallel, mPlugin, mGetTypeInfo: discard + of mEqProc: + result = newIntNodeT(ord( + exprStructuralEquivalent(a, b, strictSymEquality=true)), n) else: internalError(a.info, "evalOp(" & $m & ')') proc getConstIfExpr(c: PSym, n: PNode): PNode = @@ -519,7 +531,7 @@ proc rangeCheck(n: PNode, value: BiggestInt) = proc foldConv*(n, a: PNode; check = false): PNode = # XXX range checks? case skipTypes(n.typ, abstractRange).kind - of tyInt..tyInt64: + of tyInt..tyInt64, tyUInt..tyUInt64: case skipTypes(a.typ, abstractRange).kind of tyFloat..tyFloat64: result = newIntNodeT(int(getFloat(a)), n) @@ -531,7 +543,7 @@ proc foldConv*(n, a: PNode; check = false): PNode = of tyFloat..tyFloat64: case skipTypes(a.typ, abstractRange).kind of tyInt..tyInt64, tyEnum, tyBool, tyChar: - result = newFloatNodeT(toFloat(int(getOrdValue(a))), n) + result = newFloatNodeT(toBiggestFloat(getOrdValue(a)), n) else: result = a result.typ = n.typ @@ -646,7 +658,7 @@ proc getConstExpr(m: PSym, n: PNode): PNode = result = copyNode(n) of nkIfExpr: result = getConstIfExpr(m, n) - of nkCall, nkCommand, nkCallStrLit, nkPrefix, nkInfix: + of nkCallKinds: if n.sons[0].kind != nkSym: return var s = n.sons[0].sym if s.kind != skProc: return |