diff options
Diffstat (limited to 'rod/semfold.nim')
-rwxr-xr-x | rod/semfold.nim | 277 |
1 files changed, 91 insertions, 186 deletions
diff --git a/rod/semfold.nim b/rod/semfold.nim index 78968d705..455ddf2b8 100755 --- a/rod/semfold.nim +++ b/rod/semfold.nim @@ -14,7 +14,7 @@ import strutils, lists, options, ast, astalgo, trees, treetab, nimsets, times, nversion, platform, math, msgs, os, condsyms, idents, rnimsyn, types -proc getConstExpr*(module: PSym, n: PNode): PNode +proc getConstExpr*(m: PSym, n: PNode): PNode # evaluates the constant expression or returns nil if it is no constant # expression proc evalOp*(m: TMagic, n, a, b, c: PNode): PNode @@ -22,10 +22,7 @@ proc leValueConv*(a, b: PNode): bool proc newIntNodeT*(intVal: BiggestInt, n: PNode): PNode proc newFloatNodeT*(floatVal: BiggestFloat, n: PNode): PNode proc newStrNodeT*(strVal: string, n: PNode): PNode -proc getInt*(a: PNode): biggestInt -proc getFloat*(a: PNode): biggestFloat -proc getStr*(a: PNode): string -proc getStrOrChar*(a: PNode): string + # implementation proc newIntNodeT(intVal: BiggestInt, n: PNode): PNode = @@ -46,35 +43,6 @@ proc newStrNodeT(strVal: string, n: PNode): PNode = result.typ = n.typ result.info = n.info -proc getInt(a: PNode): biggestInt = - case a.kind - of nkIntLit..nkInt64Lit: result = a.intVal - else: - internalError(a.info, "getInt") - result = 0 - -proc getFloat(a: PNode): biggestFloat = - case a.kind - of nkFloatLit..nkFloat64Lit: result = a.floatVal - else: - internalError(a.info, "getFloat") - result = 0.0 - -proc getStr(a: PNode): string = - case a.kind - of nkStrLit..nkTripleStrLit: result = a.strVal - else: - internalError(a.info, "getStr") - result = "" - -proc getStrOrChar(a: PNode): string = - case a.kind - of nkStrLit..nkTripleStrLit: result = a.strVal - of nkCharLit: result = chr(int(a.intVal)) & "" - else: - internalError(a.info, "getStrOrChar") - result = "" - proc enumValToString(a: PNode): string = var n: PNode @@ -93,56 +61,35 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode = # b and c may be nil result = nil case m - of mOrd: - result = newIntNodeT(getOrdValue(a), n) - of mChr: - result = newIntNodeT(getInt(a), n) - of mUnaryMinusI, mUnaryMinusI64: - result = newIntNodeT(- getInt(a), n) - of mUnaryMinusF64: - result = newFloatNodeT(- getFloat(a), n) - of mNot: - result = newIntNodeT(1 - getInt(a), n) - of mCard: - result = newIntNodeT(nimsets.cardSet(a), n) - of mBitnotI, mBitnotI64: - result = newIntNodeT(not getInt(a), n) - of mLengthStr: - result = newIntNodeT(len(getStr(a)), n) - of mLengthArray: - result = newIntNodeT(lengthOrd(a.typ), n) - of mLengthSeq, mLengthOpenArray: - result = newIntNodeT(sonsLen(a), n) # BUGFIX - of mUnaryPlusI, mUnaryPlusI64, mUnaryPlusF64: - result = a # throw `+` away + of mOrd: result = newIntNodeT(getOrdValue(a), n) + of mChr: result = newIntNodeT(getInt(a), n) + of mUnaryMinusI, mUnaryMinusI64: result = newIntNodeT(- getInt(a), n) + of mUnaryMinusF64: result = newFloatNodeT(- getFloat(a), n) + of mNot: result = newIntNodeT(1 - getInt(a), n) + of mCard: result = newIntNodeT(nimsets.cardSet(a), n) + of mBitnotI, mBitnotI64: result = newIntNodeT(not getInt(a), n) + of mLengthStr: result = newIntNodeT(len(getStr(a)), n) + of mLengthArray: result = newIntNodeT(lengthOrd(a.typ), n) + of mLengthSeq, mLengthOpenArray: result = newIntNodeT(sonsLen(a), n) # BUGFIX + of mUnaryPlusI, mUnaryPlusI64, mUnaryPlusF64: result = a # throw `+` away of mToFloat, mToBiggestFloat: result = newFloatNodeT(toFloat(int(getInt(a))), n) - of mToInt, mToBiggestInt: - result = newIntNodeT(system.toInt(getFloat(a)), n) - of mAbsF64: - result = newFloatNodeT(abs(getFloat(a)), n) + of mToInt, mToBiggestInt: result = newIntNodeT(system.toInt(getFloat(a)), n) + of mAbsF64: result = newFloatNodeT(abs(getFloat(a)), n) of mAbsI, mAbsI64: if getInt(a) >= 0: result = a else: result = newIntNodeT(- getInt(a), n) of mZe8ToI, mZe8ToI64, mZe16ToI, mZe16ToI64, mZe32ToI64, mZeIToI64: # byte(-128) = 1...1..1000_0000'64 --> 0...0..1000_0000'64 result = newIntNodeT(getInt(a) and (`shl`(1, getSize(a.typ) * 8) - 1), n) - of mToU8: - result = newIntNodeT(getInt(a) and 0x000000FF, n) - of mToU16: - result = newIntNodeT(getInt(a) and 0x0000FFFF, n) - of mToU32: - result = newIntNodeT(getInt(a) and 0x00000000FFFFFFFF'i64, n) - of mSucc: - result = newIntNodeT(getOrdValue(a) + getInt(b), n) - of mPred: - result = newIntNodeT(getOrdValue(a) - getInt(b), n) - of mAddI, mAddI64: - result = newIntNodeT(getInt(a) + getInt(b), n) - of mSubI, mSubI64: - result = newIntNodeT(getInt(a) - getInt(b), n) - of mMulI, mMulI64: - result = newIntNodeT(getInt(a) * getInt(b), n) + of mToU8: result = newIntNodeT(getInt(a) and 0x000000FF, n) + of mToU16: result = newIntNodeT(getInt(a) and 0x0000FFFF, n) + of mToU32: result = newIntNodeT(getInt(a) and 0x00000000FFFFFFFF'i64, n) + of mSucc: result = newIntNodeT(getOrdValue(a) + getInt(b), n) + of mPred: result = newIntNodeT(getOrdValue(a) - getInt(b), n) + of mAddI, mAddI64: result = newIntNodeT(getInt(a) + getInt(b), n) + of mSubI, mSubI64: result = newIntNodeT(getInt(a) - getInt(b), n) + of mMulI, mMulI64: result = newIntNodeT(getInt(a) * getInt(b), n) of mMinI, mMinI64: if getInt(a) > getInt(b): result = newIntNodeT(getInt(b), n) else: result = newIntNodeT(getInt(a), n) @@ -163,16 +110,11 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode = of tyInt32: result = newIntNodeT(int32(getInt(a)) shr int32(getInt(b)), n) of tyInt64, tyInt: result = newIntNodeT(`shr`(getInt(a), getInt(b)), n) else: InternalError(n.info, "constant folding for shl") - of mDivI, mDivI64: - result = newIntNodeT(getInt(a) div getInt(b), n) - of mModI, mModI64: - result = newIntNodeT(getInt(a) mod getInt(b), n) - of mAddF64: - result = newFloatNodeT(getFloat(a) + getFloat(b), n) - of mSubF64: - result = newFloatNodeT(getFloat(a) - getFloat(b), n) - of mMulF64: - result = newFloatNodeT(getFloat(a) * getFloat(b), n) + of mDivI, mDivI64: result = newIntNodeT(getInt(a) div getInt(b), n) + of mModI, mModI64: result = newIntNodeT(getInt(a) mod getInt(b), n) + of mAddF64: result = newFloatNodeT(getFloat(a) + getFloat(b), n) + of mSubF64: result = newFloatNodeT(getFloat(a) - getFloat(b), n) + of mMulF64: result = newFloatNodeT(getFloat(a) * getFloat(b), n) of mDivF64: if getFloat(b) == 0.0: if getFloat(a) == 0.0: result = newFloatNodeT(NaN, n) @@ -185,50 +127,33 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode = of mMinF64: if getFloat(a) > getFloat(b): result = newFloatNodeT(getFloat(b), n) else: result = newFloatNodeT(getFloat(a), n) - of mIsNil: - result = newIntNodeT(ord(a.kind == nkNilLit), n) + of mIsNil: result = newIntNodeT(ord(a.kind == nkNilLit), n) of mLtI, mLtI64, mLtB, mLtEnum, mLtCh: result = newIntNodeT(ord(getOrdValue(a) < getOrdValue(b)), n) of mLeI, mLeI64, mLeB, mLeEnum, mLeCh: result = newIntNodeT(ord(getOrdValue(a) <= getOrdValue(b)), n) of mEqI, mEqI64, mEqB, mEqEnum, mEqCh: - result = newIntNodeT(ord(getOrdValue(a) == getOrdValue(b)), n) # operators for floats - of mLtF64: - result = newIntNodeT(ord(getFloat(a) < getFloat(b)), n) - of mLeF64: - result = newIntNodeT(ord(getFloat(a) <= getFloat(b)), n) - of mEqF64: - result = newIntNodeT(ord(getFloat(a) == getFloat(b)), n) # operators for strings - of mLtStr: - result = newIntNodeT(ord(getStr(a) < getStr(b)), n) - of mLeStr: - result = newIntNodeT(ord(getStr(a) <= getStr(b)), n) - of mEqStr: - result = newIntNodeT(ord(getStr(a) == getStr(b)), n) + result = newIntNodeT(ord(getOrdValue(a) == getOrdValue(b)), n) + of mLtF64: result = newIntNodeT(ord(getFloat(a) < getFloat(b)), n) + of mLeF64: result = newIntNodeT(ord(getFloat(a) <= getFloat(b)), n) + of mEqF64: result = newIntNodeT(ord(getFloat(a) == getFloat(b)), n) + of mLtStr: result = newIntNodeT(ord(getStr(a) < getStr(b)), n) + of mLeStr: result = newIntNodeT(ord(getStr(a) <= getStr(b)), n) + of mEqStr: result = newIntNodeT(ord(getStr(a) == getStr(b)), n) of mLtU, mLtU64: result = newIntNodeT(ord(`<%`(getOrdValue(a), getOrdValue(b))), n) of mLeU, mLeU64: result = newIntNodeT(ord(`<=%`(getOrdValue(a), getOrdValue(b))), n) - of mBitandI, mBitandI64, mAnd: - result = newIntNodeT(getInt(a) and getInt(b), n) - of mBitorI, mBitorI64, mOr: - result = newIntNodeT(getInt(a) or getInt(b), n) - of mBitxorI, mBitxorI64, mXor: - result = newIntNodeT(getInt(a) xor getInt(b), n) - of mAddU, mAddU64: - result = newIntNodeT(`+%`(getInt(a), getInt(b)), n) - of mSubU, mSubU64: - result = newIntNodeT(`-%`(getInt(a), getInt(b)), n) - of mMulU, mMulU64: - result = newIntNodeT(`*%`(getInt(a), getInt(b)), n) - of mModU, mModU64: - result = newIntNodeT(`%%`(getInt(a), getInt(b)), n) - of mDivU, mDivU64: - result = newIntNodeT(`/%`(getInt(a), getInt(b)), n) - of mLeSet: - result = newIntNodeT(Ord(containsSets(a, b)), n) - of mEqSet: - result = newIntNodeT(Ord(equalSets(a, b)), n) + of mBitandI, mBitandI64, mAnd: result = newIntNodeT(getInt(a) and getInt(b), n) + of mBitorI, mBitorI64, mOr: result = newIntNodeT(getInt(a) or getInt(b), n) + of mBitxorI, mBitxorI64, mXor: result = newIntNodeT(getInt(a) xor getInt(b), n) + of mAddU, mAddU64: result = newIntNodeT(`+%`(getInt(a), getInt(b)), n) + of mSubU, mSubU64: result = newIntNodeT(`-%`(getInt(a), getInt(b)), n) + of mMulU, mMulU64: result = newIntNodeT(`*%`(getInt(a), getInt(b)), n) + of mModU, mModU64: result = newIntNodeT(`%%`(getInt(a), getInt(b)), n) + of mDivU, mDivU64: result = newIntNodeT(`/%`(getInt(a), getInt(b)), n) + of mLeSet: result = newIntNodeT(Ord(containsSets(a, b)), n) + of mEqSet: result = newIntNodeT(Ord(equalSets(a, b)), n) of mLtSet: result = newIntNodeT(Ord(containsSets(a, b) and not equalSets(a, b)), n) of mMulSet: @@ -243,32 +168,24 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode = of mSymDiffSet: result = nimsets.symdiffSets(a, b) result.info = n.info - of mConStrStr: - result = newStrNodeT(getStrOrChar(a) & getStrOrChar(b), n) - of mInSet: - result = newIntNodeT(Ord(inSet(a, b)), n) + of mConStrStr: result = newStrNodeT(getStrOrChar(a) & getStrOrChar(b), n) + of mInSet: result = newIntNodeT(Ord(inSet(a, b)), n) of mRepr: # BUGFIX: we cannot eval mRepr here. But this means that it is not # available for interpretation. I don't know how to fix this. #result := newStrNodeT(renderTree(a, {@set}[renderNoComments]), n); - of mIntToStr, mInt64ToStr: - result = newStrNodeT($(getOrdValue(a)), n) + of mIntToStr, mInt64ToStr: result = newStrNodeT($(getOrdValue(a)), n) of mBoolToStr: if getOrdValue(a) == 0: result = newStrNodeT("false", n) else: result = newStrNodeT("true", n) - of mCopyStr: - result = newStrNodeT(copy(getStr(a), int(getOrdValue(b)) + 0), n) + of mCopyStr: result = newStrNodeT(copy(getStr(a), int(getOrdValue(b))), n) of mCopyStrLast: - result = newStrNodeT(copy(getStr(a), int(getOrdValue(b)) + 0, - int(getOrdValue(c)) + 0), n) - of mFloatToStr: - result = newStrNodeT($(getFloat(a)), n) - of mCStrToStr, mCharToStr: - result = newStrNodeT(getStrOrChar(a), n) - of mStrToStr: - result = a - of mEnumToStr: - result = newStrNodeT(enumValToString(a), n) + result = newStrNodeT(copy(getStr(a), int(getOrdValue(b)), + int(getOrdValue(c))), n) + of mFloatToStr: result = newStrNodeT($(getFloat(a)), n) + of mCStrToStr, mCharToStr: result = newStrNodeT(getStrOrChar(a), n) + of mStrToStr: result = a + of mEnumToStr: result = newStrNodeT(enumValToString(a), n) of mArrToSeq: result = copyTree(a) result.typ = n.typ @@ -278,15 +195,13 @@ proc evalOp(m: TMagic, n, a, b, c: PNode): PNode = else: InternalError(a.info, "evalOp(" & $m & ')') proc getConstIfExpr(c: PSym, n: PNode): PNode = - var it, e: PNode result = nil for i in countup(0, sonsLen(n) - 1): - it = n.sons[i] + var it = n.sons[i] case it.kind of nkElifExpr: - e = getConstExpr(c, it.sons[0]) - if e == nil: - return nil + var e = getConstExpr(c, it.sons[0]) + if e == nil: return nil if getOrdValue(e) != 0: if result == nil: result = getConstExpr(c, it.sons[1]) @@ -297,10 +212,9 @@ proc getConstIfExpr(c: PSym, n: PNode): PNode = proc partialAndExpr(c: PSym, n: PNode): PNode = # partial evaluation - var a, b: PNode result = n - a = getConstExpr(c, n.sons[1]) - b = getConstExpr(c, n.sons[2]) + var a = getConstExpr(c, n.sons[1]) + var b = getConstExpr(c, n.sons[2]) if a != nil: if getInt(a) == 0: result = a elif b != nil: result = b @@ -311,10 +225,9 @@ proc partialAndExpr(c: PSym, n: PNode): PNode = proc partialOrExpr(c: PSym, n: PNode): PNode = # partial evaluation - var a, b: PNode result = n - a = getConstExpr(c, n.sons[1]) - b = getConstExpr(c, n.sons[2]) + var a = getConstExpr(c, n.sons[1]) + var b = getConstExpr(c, n.sons[2]) if a != nil: if getInt(a) != 0: result = a elif b != nil: result = b @@ -338,20 +251,16 @@ proc leValueConv(a, b: PNode): bool = else: InternalError(a.info, "leValueConv") else: InternalError(a.info, "leValueConv") -proc getConstExpr(module: PSym, n: PNode): PNode = - var - s: PSym - a, b, c: PNode +proc getConstExpr(m: PSym, n: PNode): PNode = result = nil case n.kind of nkSym: - s = n.sym + var s = n.sym if s.kind == skEnumField: result = newIntNodeT(s.position, n) elif (s.kind == skConst): case s.magic - of mIsMainModule: result = newIntNodeT(ord(sfMainModule in module.flags), - n) + of mIsMainModule: result = newIntNodeT(ord(sfMainModule in m.flags), n) of mCompileDate: result = newStrNodeT(times.getDateStr(), n) of mCompileTime: result = newStrNodeT(times.getClockStr(), n) of mNimrodVersion: result = newStrNodeT(VersionAsString, n) @@ -364,29 +273,28 @@ proc getConstExpr(module: PSym, n: PNode): PNode = of mNaN: result = newFloatNodeT(NaN, n) of mInf: result = newFloatNodeT(Inf, n) of mNegInf: result = newFloatNodeT(NegInf, n) - else: - result = copyTree(s.ast) # BUGFIX + else: result = copyTree(s.ast) elif s.kind in {skProc, skMethod}: # BUGFIX result = n of nkCharLit..nkNilLit: result = copyNode(n) of nkIfExpr: - result = getConstIfExpr(module, n) + result = getConstIfExpr(m, n) of nkCall, nkCommand, nkCallStrLit: if (n.sons[0].kind != nkSym): return - s = n.sons[0].sym + var s = n.sons[0].sym if (s.kind != skProc): return try: case s.magic of mNone: - return # XXX: if it has no sideEffect, it should be evaluated + return # XXX: if it has no sideEffect, it should be evaluated of mSizeOf: - a = n.sons[1] + var a = n.sons[1] if computeSize(a.typ) < 0: liMessage(a.info, errCannotEvalXBecauseIncompletelyDefined, "sizeof") if a.typ.kind in {tyArray, tyObject, tyTuple}: - result = nil # XXX: size computation for complex types - # is still wrong + result = nil + # XXX: size computation for complex types is still wrong else: result = newIntNodeT(getSize(a.typ), n) of mLow: @@ -396,13 +304,14 @@ proc getConstExpr(module: PSym, n: PNode): PNode = {tyOpenArray, tySequence, tyString}): result = newIntNodeT(lastOrd(skipTypes(n.sons[1].typ, abstractVar)), n) else: - a = getConstExpr(module, n.sons[1]) + var a = getConstExpr(m, n.sons[1]) + var b, c: PNode if a == nil: return if sonsLen(n) > 2: - b = getConstExpr(module, n.sons[2]) + b = getConstExpr(m, n.sons[2]) if b == nil: return if sonsLen(n) > 3: - c = getConstExpr(module, n.sons[3]) + c = getConstExpr(m, n.sons[3]) if c == nil: return else: b = nil @@ -412,22 +321,21 @@ proc getConstExpr(module: PSym, n: PNode): PNode = except EDivByZero: liMessage(n.info, errConstantDivisionByZero) of nkAddr: - a = getConstExpr(module, n.sons[0]) + var a = getConstExpr(m, n.sons[0]) if a != nil: result = n n.sons[0] = a of nkBracket: result = copyTree(n) for i in countup(0, sonsLen(n) - 1): - a = getConstExpr(module, n.sons[i]) - if a == nil: - return nil + var a = getConstExpr(m, n.sons[i]) + if a == nil: return nil result.sons[i] = a incl(result.flags, nfAllConst) of nkRange: - a = getConstExpr(module, n.sons[0]) + var a = getConstExpr(m, n.sons[0]) if a == nil: return - b = getConstExpr(module, n.sons[1]) + var b = getConstExpr(m, n.sons[1]) if b == nil: return result = copyNode(n) addSon(result, a) @@ -435,9 +343,8 @@ proc getConstExpr(module: PSym, n: PNode): PNode = of nkCurly: result = copyTree(n) for i in countup(0, sonsLen(n) - 1): - a = getConstExpr(module, n.sons[i]) - if a == nil: - return nil + var a = getConstExpr(m, n.sons[i]) + if a == nil: return nil result.sons[i] = a incl(result.flags, nfAllConst) of nkPar: @@ -445,19 +352,17 @@ proc getConstExpr(module: PSym, n: PNode): PNode = result = copyTree(n) if (sonsLen(n) > 0) and (n.sons[0].kind == nkExprColonExpr): for i in countup(0, sonsLen(n) - 1): - a = getConstExpr(module, n.sons[i].sons[1]) - if a == nil: - return nil + var a = getConstExpr(m, n.sons[i].sons[1]) + if a == nil: return nil result.sons[i].sons[1] = a else: for i in countup(0, sonsLen(n) - 1): - a = getConstExpr(module, n.sons[i]) - if a == nil: - return nil + var a = getConstExpr(m, n.sons[i]) + if a == nil: return nil result.sons[i] = a incl(result.flags, nfAllConst) of nkChckRangeF, nkChckRange64, nkChckRange: - a = getConstExpr(module, n.sons[0]) + var a = getConstExpr(m, n.sons[0]) if a == nil: return if leValueConv(n.sons[1], a) and leValueConv(a, n.sons[2]): result = a # a <= x and x <= b @@ -467,12 +372,12 @@ proc getConstExpr(module: PSym, n: PNode): PNode = msgKindToString(errIllegalConvFromXtoY), [typeToString(n.sons[0].typ), typeToString(n.typ)])) of nkStringToCString, nkCStringToString: - a = getConstExpr(module, n.sons[0]) + var a = getConstExpr(m, n.sons[0]) if a == nil: return result = a result.typ = n.typ of nkHiddenStdConv, nkHiddenSubConv, nkConv, nkCast: - a = getConstExpr(module, n.sons[1]) + var a = getConstExpr(m, n.sons[1]) if a == nil: return case skipTypes(n.typ, abstractRange).kind of tyInt..tyInt64: |