diff options
Diffstat (limited to 'rod/semfold.nim')
-rwxr-xr-x | rod/semfold.nim | 500 |
1 files changed, 500 insertions, 0 deletions
diff --git a/rod/semfold.nim b/rod/semfold.nim new file mode 100755 index 000000000..7d9cbfe52 --- /dev/null +++ b/rod/semfold.nim @@ -0,0 +1,500 @@ +# +# +# The Nimrod Compiler +# (c) Copyright 2009 Andreas Rumpf +# +# See the file "copying.txt", included in this +# distribution, for details about the copyright. +# + +# this module folds constants; used by semantic checking phase +# and evaluation phase + +import + strutils, lists, options, ast, astalgo, trees, treetab, nimsets, times, + nversion, platform, math, msgs, os, condsyms, idents, rnimsyn, types + +proc getConstExpr*(module: PSym, n: PNode): PNode + # evaluates the constant expression or returns nil if it is no constant + # expression +proc evalOp*(m: TMagic, n, a, b, c: PNode): PNode +proc leValueConv*(a, b: PNode): bool +proc newIntNodeT*(intVal: BiggestInt, n: PNode): PNode +proc newFloatNodeT*(floatVal: BiggestFloat, n: PNode): PNode +proc newStrNodeT*(strVal: string, n: PNode): PNode +proc getInt*(a: PNode): biggestInt +proc getFloat*(a: PNode): biggestFloat +proc getStr*(a: PNode): string +proc getStrOrChar*(a: PNode): string +# implementation + +proc newIntNodeT(intVal: BiggestInt, n: PNode): PNode = + if skipTypes(n.typ, abstractVarRange).kind == tyChar: + result = newIntNode(nkCharLit, intVal) + else: + result = newIntNode(nkIntLit, intVal) + result.typ = n.typ + result.info = n.info + +proc newFloatNodeT(floatVal: BiggestFloat, n: PNode): PNode = + result = newFloatNode(nkFloatLit, floatVal) + result.typ = n.typ + result.info = n.info + +proc newStrNodeT(strVal: string, n: PNode): PNode = + result = newStrNode(nkStrLit, strVal) + result.typ = n.typ + result.info = n.info + +proc getInt(a: PNode): biggestInt = + case a.kind + of nkIntLit..nkInt64Lit: result = a.intVal + else: + internalError(a.info, "getInt") + result = 0 + +proc getFloat(a: PNode): biggestFloat = + case a.kind + of nkFloatLit..nkFloat64Lit: result = a.floatVal + else: + internalError(a.info, "getFloat") + result = 0.0 + +proc getStr(a: PNode): string = + case a.kind + of nkStrLit..nkTripleStrLit: result = a.strVal + else: + internalError(a.info, "getStr") + result = "" + +proc getStrOrChar(a: PNode): string = + case a.kind + of nkStrLit..nkTripleStrLit: result = a.strVal + of nkCharLit: result = chr(int(a.intVal)) & "" + else: + internalError(a.info, "getStrOrChar") + result = "" + +proc enumValToString(a: PNode): string = + var + n: PNode + field: PSym + x: biggestInt + x = getInt(a) + n = skipTypes(a.typ, abstractInst).n + for i in countup(0, sonsLen(n) - 1): + if n.sons[i].kind != nkSym: InternalError(a.info, "enumValToString") + field = n.sons[i].sym + if field.position == x: + return field.name.s + InternalError(a.info, "no symbol for ordinal value: " & $(x)) + +proc evalOp(m: TMagic, n, a, b, c: PNode): PNode = + # b and c may be nil + result = nil + case m + of mOrd: + result = newIntNodeT(getOrdValue(a), n) + of mChr: + result = newIntNodeT(getInt(a), n) + of mUnaryMinusI, mUnaryMinusI64: + result = newIntNodeT(- getInt(a), n) + of mUnaryMinusF64: + result = newFloatNodeT(- getFloat(a), n) + of mNot: + result = newIntNodeT(1 - getInt(a), n) + of mCard: + result = newIntNodeT(nimsets.cardSet(a), n) + of mBitnotI, mBitnotI64: + result = newIntNodeT(not getInt(a), n) + of mLengthStr: + result = newIntNodeT(len(getStr(a)), n) + of mLengthArray: + result = newIntNodeT(lengthOrd(a.typ), n) + of mLengthSeq, mLengthOpenArray: + result = newIntNodeT(sonsLen(a), n) # BUGFIX + of mUnaryPlusI, mUnaryPlusI64, mUnaryPlusF64: + result = a # throw `+` away + of mToFloat, mToBiggestFloat: + result = newFloatNodeT(toFloat(int(getInt(a))), n) + of mToInt, mToBiggestInt: + result = newIntNodeT(system.toInt(getFloat(a)), n) + of mAbsF64: + result = newFloatNodeT(abs(getFloat(a)), n) + of mAbsI, mAbsI64: + if getInt(a) >= 0: result = a + else: result = newIntNodeT(- getInt(a), n) + of mZe8ToI, mZe8ToI64, mZe16ToI, mZe16ToI64, mZe32ToI64, mZeIToI64: + # byte(-128) = 1...1..1000_0000'64 --> 0...0..1000_0000'64 + result = newIntNodeT(getInt(a) and (`shl`(1, getSize(a.typ) * 8) - 1), n) + of mToU8: + result = newIntNodeT(getInt(a) and 0x000000FF, n) + of mToU16: + result = newIntNodeT(getInt(a) and 0x0000FFFF, n) + of mToU32: + result = newIntNodeT(getInt(a) and 0x00000000FFFFFFFF'i64, n) + of mSucc: + result = newIntNodeT(getOrdValue(a) + getInt(b), n) + of mPred: + result = newIntNodeT(getOrdValue(a) - getInt(b), n) + of mAddI, mAddI64: + result = newIntNodeT(getInt(a) + getInt(b), n) + of mSubI, mSubI64: + result = newIntNodeT(getInt(a) - getInt(b), n) + of mMulI, mMulI64: + result = newIntNodeT(getInt(a) * getInt(b), n) + of mMinI, mMinI64: + if getInt(a) > getInt(b): result = newIntNodeT(getInt(b), n) + else: result = newIntNodeT(getInt(a), n) + of mMaxI, mMaxI64: + if getInt(a) > getInt(b): result = newIntNodeT(getInt(a), n) + else: result = newIntNodeT(getInt(b), n) + of mShlI, mShlI64: + case skipTypes(n.typ, abstractRange).kind + of tyInt8: result = newIntNodeT(int8(getInt(a)) shl int8(getInt(b)), n) + of tyInt16: result = newIntNodeT(int16(getInt(a)) shl int16(getInt(b)), n) + of tyInt32: result = newIntNodeT(int32(getInt(a)) shl int32(getInt(b)), n) + of tyInt64, tyInt: result = newIntNodeT(`shl`(getInt(a), getInt(b)), n) + else: InternalError(n.info, "constant folding for shl") + of mShrI, mShrI64: + case skipTypes(n.typ, abstractRange).kind + of tyInt8: result = newIntNodeT(int8(getInt(a)) shr int8(getInt(b)), n) + of tyInt16: result = newIntNodeT(int16(getInt(a)) shr int16(getInt(b)), n) + of tyInt32: result = newIntNodeT(int32(getInt(a)) shr int32(getInt(b)), n) + of tyInt64, tyInt: result = newIntNodeT(`shr`(getInt(a), getInt(b)), n) + else: InternalError(n.info, "constant folding for shl") + of mDivI, mDivI64: + result = newIntNodeT(getInt(a) div getInt(b), n) + of mModI, mModI64: + result = newIntNodeT(getInt(a) mod getInt(b), n) + of mAddF64: + result = newFloatNodeT(getFloat(a) + getFloat(b), n) + of mSubF64: + result = newFloatNodeT(getFloat(a) - getFloat(b), n) + of mMulF64: + result = newFloatNodeT(getFloat(a) * getFloat(b), n) + of mDivF64: + if getFloat(b) == 0.0: + if getFloat(a) == 0.0: result = newFloatNodeT(NaN, n) + else: result = newFloatNodeT(Inf, n) + else: + result = newFloatNodeT(getFloat(a) / getFloat(b), n) + of mMaxF64: + if getFloat(a) > getFloat(b): result = newFloatNodeT(getFloat(a), n) + else: result = newFloatNodeT(getFloat(b), n) + of mMinF64: + if getFloat(a) > getFloat(b): result = newFloatNodeT(getFloat(b), n) + else: result = newFloatNodeT(getFloat(a), n) + of mIsNil: + result = newIntNodeT(ord(a.kind == nkNilLit), n) + of mLtI, mLtI64, mLtB, mLtEnum, mLtCh: + result = newIntNodeT(ord(getOrdValue(a) < getOrdValue(b)), n) + of mLeI, mLeI64, mLeB, mLeEnum, mLeCh: + result = newIntNodeT(ord(getOrdValue(a) <= getOrdValue(b)), n) + of mEqI, mEqI64, mEqB, mEqEnum, mEqCh: + result = newIntNodeT(ord(getOrdValue(a) == getOrdValue(b)), n) # operators for floats + of mLtF64: + result = newIntNodeT(ord(getFloat(a) < getFloat(b)), n) + of mLeF64: + result = newIntNodeT(ord(getFloat(a) <= getFloat(b)), n) + of mEqF64: + result = newIntNodeT(ord(getFloat(a) == getFloat(b)), n) # operators for strings + of mLtStr: + result = newIntNodeT(ord(getStr(a) < getStr(b)), n) + of mLeStr: + result = newIntNodeT(ord(getStr(a) <= getStr(b)), n) + of mEqStr: + result = newIntNodeT(ord(getStr(a) == getStr(b)), n) + of mLtU, mLtU64: + result = newIntNodeT(ord(`<%`(getOrdValue(a), getOrdValue(b))), n) + of mLeU, mLeU64: + result = newIntNodeT(ord(`<=%`(getOrdValue(a), getOrdValue(b))), n) + of mBitandI, mBitandI64, mAnd: + result = newIntNodeT(getInt(a) and getInt(b), n) + of mBitorI, mBitorI64, mOr: + result = newIntNodeT(getInt(a) or getInt(b), n) + of mBitxorI, mBitxorI64, mXor: + result = newIntNodeT(getInt(a) xor getInt(b), n) + of mAddU, mAddU64: + result = newIntNodeT(`+%`(getInt(a), getInt(b)), n) + of mSubU, mSubU64: + result = newIntNodeT(`-%`(getInt(a), getInt(b)), n) + of mMulU, mMulU64: + result = newIntNodeT(`*%`(getInt(a), getInt(b)), n) + of mModU, mModU64: + result = newIntNodeT(`%%`(getInt(a), getInt(b)), n) + of mDivU, mDivU64: + result = newIntNodeT(`/%`(getInt(a), getInt(b)), n) + of mLeSet: + result = newIntNodeT(Ord(containsSets(a, b)), n) + of mEqSet: + result = newIntNodeT(Ord(equalSets(a, b)), n) + of mLtSet: + result = newIntNodeT(Ord(containsSets(a, b) and not equalSets(a, b)), n) + of mMulSet: + result = nimsets.intersectSets(a, b) + result.info = n.info + of mPlusSet: + result = nimsets.unionSets(a, b) + result.info = n.info + of mMinusSet: + result = nimsets.diffSets(a, b) + result.info = n.info + of mSymDiffSet: + result = nimsets.symdiffSets(a, b) + result.info = n.info + of mConStrStr: + result = newStrNodeT(getStrOrChar(a) & getStrOrChar(b), n) + of mInSet: + result = newIntNodeT(Ord(inSet(a, b)), n) + of mRepr: + # BUGFIX: we cannot eval mRepr here. But this means that it is not + # available for interpretation. I don't know how to fix this. + #result := newStrNodeT(renderTree(a, {@set}[renderNoComments]), n); + of mIntToStr, mInt64ToStr: + result = newStrNodeT($(getOrdValue(a)), n) + of mBoolToStr: + if getOrdValue(a) == 0: result = newStrNodeT("false", n) + else: result = newStrNodeT("true", n) + of mCopyStr: + result = newStrNodeT(copy(getStr(a), int(getOrdValue(b)) + 0), n) + of mCopyStrLast: + result = newStrNodeT(copy(getStr(a), int(getOrdValue(b)) + 0, + int(getOrdValue(c)) + 0), n) + of mFloatToStr: + result = newStrNodeT($(getFloat(a)), n) + of mCStrToStr, mCharToStr: + result = newStrNodeT(getStrOrChar(a), n) + of mStrToStr: + result = a + of mEnumToStr: + result = newStrNodeT(enumValToString(a), n) + of mArrToSeq: + result = copyTree(a) + result.typ = n.typ + of mNewString, mExit, mInc, ast.mDec, mEcho, mAssert, mSwap, mAppendStrCh, + mAppendStrStr, mAppendSeqElem, mSetLengthStr, mSetLengthSeq, mNLen..mNError: + nil + else: InternalError(a.info, "evalOp(" & magicToStr[m] & ')') + +proc getConstIfExpr(c: PSym, n: PNode): PNode = + var it, e: PNode + result = nil + for i in countup(0, sonsLen(n) - 1): + it = n.sons[i] + case it.kind + of nkElifExpr: + e = getConstExpr(c, it.sons[0]) + if e == nil: + return nil + if getOrdValue(e) != 0: + if result == nil: + result = getConstExpr(c, it.sons[1]) + if result == nil: return + of nkElseExpr: + if result == nil: result = getConstExpr(c, it.sons[0]) + else: internalError(it.info, "getConstIfExpr()") + +proc partialAndExpr(c: PSym, n: PNode): PNode = + # partial evaluation + var a, b: PNode + result = n + a = getConstExpr(c, n.sons[1]) + b = getConstExpr(c, n.sons[2]) + if a != nil: + if getInt(a) == 0: result = a + elif b != nil: result = b + else: result = n.sons[2] + elif b != nil: + if getInt(b) == 0: result = b + else: result = n.sons[1] + +proc partialOrExpr(c: PSym, n: PNode): PNode = + # partial evaluation + var a, b: PNode + result = n + a = getConstExpr(c, n.sons[1]) + b = getConstExpr(c, n.sons[2]) + if a != nil: + if getInt(a) != 0: result = a + elif b != nil: result = b + else: result = n.sons[2] + elif b != nil: + if getInt(b) != 0: result = b + else: result = n.sons[1] + +proc leValueConv(a, b: PNode): bool = + result = false + case a.kind + of nkCharLit..nkInt64Lit: + case b.kind + of nkCharLit..nkInt64Lit: result = a.intVal <= b.intVal + of nkFloatLit..nkFloat64Lit: result = a.intVal <= round(b.floatVal) + else: InternalError(a.info, "leValueConv") + of nkFloatLit..nkFloat64Lit: + case b.kind + of nkFloatLit..nkFloat64Lit: result = a.floatVal <= b.floatVal + of nkCharLit..nkInt64Lit: result = a.floatVal <= toFloat(int(b.intVal)) + else: InternalError(a.info, "leValueConv") + else: InternalError(a.info, "leValueConv") + +proc getConstExpr(module: PSym, n: PNode): PNode = + var + s: PSym + a, b, c: PNode + result = nil + case n.kind + of nkSym: + s = n.sym + if s.kind == skEnumField: + result = newIntNodeT(s.position, n) + elif (s.kind == skConst): + case s.magic + of mIsMainModule: result = newIntNodeT(ord(sfMainModule in module.flags), + n) + of mCompileDate: result = newStrNodeT(times.getDateStr(), n) + of mCompileTime: result = newStrNodeT(times.getClockStr(), n) + of mNimrodVersion: result = newStrNodeT(VersionAsString, n) + of mNimrodMajor: result = newIntNodeT(VersionMajor, n) + of mNimrodMinor: result = newIntNodeT(VersionMinor, n) + of mNimrodPatch: result = newIntNodeT(VersionPatch, n) + of mCpuEndian: result = newIntNodeT(ord(CPU[targetCPU].endian), n) + of mHostOS: result = newStrNodeT(toLower(platform.OS[targetOS].name), n) + of mHostCPU: result = newStrNodeT(toLower(platform.CPU[targetCPU].name), n) + of mNaN: result = newFloatNodeT(NaN, n) + of mInf: result = newFloatNodeT(Inf, n) + of mNegInf: result = newFloatNodeT(NegInf, n) + else: + result = copyTree(s.ast) # BUGFIX + elif s.kind in {skProc, skMethod}: # BUGFIX + result = n + of nkCharLit..nkNilLit: + result = copyNode(n) + of nkIfExpr: + result = getConstIfExpr(module, n) + of nkCall, nkCommand, nkCallStrLit: + if (n.sons[0].kind != nkSym): return + s = n.sons[0].sym + if (s.kind != skProc): return + try: + case s.magic + of mNone: + return # XXX: if it has no sideEffect, it should be evaluated + of mSizeOf: + a = n.sons[1] + if computeSize(a.typ) < 0: + liMessage(a.info, errCannotEvalXBecauseIncompletelyDefined, "sizeof") + if a.typ.kind in {tyArray, tyObject, tyTuple}: + result = nil # XXX: size computation for complex types + # is still wrong + else: + result = newIntNodeT(getSize(a.typ), n) + of mLow: + result = newIntNodeT(firstOrd(n.sons[1].typ), n) + of mHigh: + if not (skipTypes(n.sons[1].typ, abstractVar).kind in + {tyOpenArray, tySequence, tyString}): + result = newIntNodeT(lastOrd(skipTypes(n.sons[1].typ, abstractVar)), n) + else: + a = getConstExpr(module, n.sons[1]) + if a == nil: return + if sonsLen(n) > 2: + b = getConstExpr(module, n.sons[2]) + if b == nil: return + if sonsLen(n) > 3: + c = getConstExpr(module, n.sons[3]) + if c == nil: return + else: + b = nil + result = evalOp(s.magic, n, a, b, c) + except EOverflow: + liMessage(n.info, errOverOrUnderflow) + except EDivByZero: + liMessage(n.info, errConstantDivisionByZero) + of nkAddr: + a = getConstExpr(module, n.sons[0]) + if a != nil: + result = n + n.sons[0] = a + of nkBracket: + result = copyTree(n) + for i in countup(0, sonsLen(n) - 1): + a = getConstExpr(module, n.sons[i]) + if a == nil: + return nil + result.sons[i] = a + incl(result.flags, nfAllConst) + of nkRange: + a = getConstExpr(module, n.sons[0]) + if a == nil: return + b = getConstExpr(module, n.sons[1]) + if b == nil: return + result = copyNode(n) + addSon(result, a) + addSon(result, b) + of nkCurly: + result = copyTree(n) + for i in countup(0, sonsLen(n) - 1): + a = getConstExpr(module, n.sons[i]) + if a == nil: + return nil + result.sons[i] = a + incl(result.flags, nfAllConst) + of nkPar: + # tuple constructor + result = copyTree(n) + if (sonsLen(n) > 0) and (n.sons[0].kind == nkExprColonExpr): + for i in countup(0, sonsLen(n) - 1): + a = getConstExpr(module, n.sons[i].sons[1]) + if a == nil: + return nil + result.sons[i].sons[1] = a + else: + for i in countup(0, sonsLen(n) - 1): + a = getConstExpr(module, n.sons[i]) + if a == nil: + return nil + result.sons[i] = a + incl(result.flags, nfAllConst) + of nkChckRangeF, nkChckRange64, nkChckRange: + a = getConstExpr(module, n.sons[0]) + if a == nil: return + if leValueConv(n.sons[1], a) and leValueConv(a, n.sons[2]): + result = a # a <= x and x <= b + result.typ = n.typ + else: + liMessage(n.info, errGenerated, `%`( + msgKindToString(errIllegalConvFromXtoY), + [typeToString(n.sons[0].typ), typeToString(n.typ)])) + of nkStringToCString, nkCStringToString: + a = getConstExpr(module, n.sons[0]) + if a == nil: return + result = a + result.typ = n.typ + of nkHiddenStdConv, nkHiddenSubConv, nkConv, nkCast: + a = getConstExpr(module, n.sons[1]) + if a == nil: return + case skipTypes(n.typ, abstractRange).kind + of tyInt..tyInt64: + case skipTypes(a.typ, abstractRange).kind + of tyFloat..tyFloat64: result = newIntNodeT(system.toInt(getFloat(a)), n) + of tyChar: result = newIntNodeT(getOrdValue(a), n) + else: + result = a + result.typ = n.typ + of tyFloat..tyFloat64: + case skipTypes(a.typ, abstractRange).kind + of tyInt..tyInt64, tyEnum, tyBool, tyChar: + result = newFloatNodeT(toFloat(int(getOrdValue(a))), n) + else: + result = a + result.typ = n.typ + of tyOpenArray, tyProc: + nil + else: + #n.sons[1] := a; + #result := n; + result = a + result.typ = n.typ + else: + nil |