diff options
author | Zahary Karadjov <zahary@gmail.com> | 2016-08-12 03:25:59 +0300 |
---|---|---|
committer | Zahary Karadjov <zahary@gmail.com> | 2017-03-24 16:58:15 +0200 |
commit | 0f2c4be1299fc99aeea2011c57240c8cfabd83c3 (patch) | |
tree | 64fa8e58c3c9745b54c127f8fdc3b897ff3896a0 /compiler/sigmatch.nim | |
parent | 0b0a3e5f203f6b21f3790a6cd50ceeaa8786badc (diff) | |
download | Nim-0f2c4be1299fc99aeea2011c57240c8cfabd83c3.tar.gz |
infer static parameters even when more complicated arithmetic is involved
Diffstat (limited to 'compiler/sigmatch.nim')
-rw-r--r-- | compiler/sigmatch.nim | 143 |
1 files changed, 114 insertions, 29 deletions
diff --git a/compiler/sigmatch.nim b/compiler/sigmatch.nim index 162385e6d..ca9cdcaf8 100644 --- a/compiler/sigmatch.nim +++ b/compiler/sigmatch.nim @@ -681,16 +681,125 @@ proc maybeSkipDistinct(t: PType, callee: PSym): PType = else: result = t -proc tryResolvingStaticExpr(c: var TCandidate, n: PNode): PNode = +proc tryResolvingStaticExpr(c: var TCandidate, n: PNode, + allowUnresolved = false): PNode = # Consider this example: # type Value[N: static[int]] = object # proc foo[N](a: Value[N], r: range[0..(N-1)]) # Here, N-1 will be initially nkStaticExpr that can be evaluated only after # N is bound to a concrete value during the matching of the first param. # This proc is used to evaluate such static expressions. - let instantiated = replaceTypesInBody(c.c, c.bindings, n, nil) + let instantiated = replaceTypesInBody(c.c, c.bindings, n, nil, + allowMetaTypes = allowUnresolved) result = c.c.semExpr(c.c, instantiated) +proc inferStaticParam*(lhs: PNode, rhs: BiggestInt): PType = + # This is a simple integer arithimetic equation solver, + # capable of deriving the value of a static parameter in + # expressions such as (N + 5) / 2 = rhs + # + # Preconditions: + # + # * The input of this proc must be semantized + # - all templates should be expanded + # - aby constant folding possible should already be performed + # + # * There must be exactly one unresolved static parameter + # + # Result: + # + # The proc will return the inferred static type with the `n` field + # populated with the inferred value. + # + # `nil` will be returned if the inference was not possible + # + if lhs.kind in nkCallKinds and lhs[0].kind == nkSym: + case lhs[0].sym.magic + of mUnaryLt: + return inferStaticParam(lhs[1], rhs + 1) + + of mAddI, mAddU, mInc, mSucc: + if lhs[1].kind == nkIntLit: + return inferStaticParam(lhs[2], rhs - lhs[1].intVal) + elif lhs[2].kind == nkIntLit: + return inferStaticParam(lhs[1], rhs - lhs[2].intVal) + + of mDec, mSubI, mSubU, mPred: + if lhs[1].kind == nkIntLit: + return inferStaticParam(lhs[2], lhs[1].intVal - rhs) + elif lhs[2].kind == nkIntLit: + return inferStaticParam(lhs[1], rhs + lhs[2].intVal) + + of mMulI, mMulU: + if lhs[1].kind == nkIntLit: + if rhs mod lhs[1].intVal == 0: + return inferStaticParam(lhs[2], rhs div lhs[1].intVal) + elif lhs[2].kind == nkIntLit: + if rhs mod lhs[2].intVal == 0: + return inferStaticParam(lhs[1], rhs div lhs[2].intVal) + + of mDivI, mDivU: + if lhs[1].kind == nkIntLit: + if lhs[1].intVal mod rhs == 0: + return inferStaticParam(lhs[2], lhs[1].intVal div rhs) + elif lhs[2].kind == nkIntLit: + return inferStaticParam(lhs[1], lhs[2].intVal * rhs) + + of mShlI: + if lhs[2].kind == nkIntLit: + return inferStaticParam(lhs[1], rhs shr lhs[2].intVal) + + of mShrI: + if lhs[2].kind == nkIntLit: + return inferStaticParam(lhs[1], rhs shl lhs[2].intVal) + + of mUnaryMinusI: + return inferStaticParam(lhs[1], -rhs) + + of mUnaryPlusI, mToInt, mToBiggestInt: + return inferStaticParam(lhs[1], rhs) + + else: discard + + elif lhs.kind == nkSym and lhs.typ.kind == tyStatic and lhs.typ.n == nil: + lhs.typ.n = newIntNode(nkIntLit, rhs) + return lhs.typ + + return nil + +proc failureToInferStaticParam(n: PNode) = + let staticParam = n.findUnresolvedStatic + let name = if staticParam != nil: staticParam.sym.name.s + else: "unknown" + localError(n.info, errCannotInferStaticParam, name) + +proc inferStaticsInRange(c: var TCandidate, + inferred, concrete: PType): TTypeRelation = + let lowerBound = tryResolvingStaticExpr(c, inferred.n[0], + allowUnresolved = true) + let upperBound = tryResolvingStaticExpr(c, inferred.n[1], + allowUnresolved = true) + + template doInferStatic(c: var TCandidate, e: PNode, r: BiggestInt) = + var exp = e + var rhs = r + var inferred = inferStaticParam(exp, rhs) + if inferred != nil: + put(c.bindings, inferred, inferred) + return isGeneric + else: + failureToInferStaticParam exp + + if lowerBound.kind == nkIntLit: + if upperBound.kind == nkIntLit: + if lengthOrd(concrete) == upperBound.intVal - lowerBound.intVal + 1: + return isGeneric + else: + return isNone + doInferStatic(c, upperBound, lengthOrd(concrete) + lowerBound.intVal - 1) + elif upperBound.kind == nkIntLit: + doInferStatic(c, lowerBound, upperBound.intVal + 1 - lengthOrd(concrete)) + template subtypeCheck() = if result <= isSubrange and f.lastSon.skipTypes(abstractInst).kind in {tyRef, tyPtr, tyVar}: result = isNone @@ -894,34 +1003,10 @@ proc typeRel(c: var TCandidate, f, aOrig: PType, doBind = true): TTypeRelation = a.sons[1].skipTypes({tyTypeDesc})) if result < isGeneric: return isNone - proc inferStaticRange(c: var TCandidate, inferred, concrete: PType) = - var (staticT, offset) = inferred.findUnresolvedStaticInRange - var - replacementT = newTypeWithSons(c.c, tyStatic, @[tyInt.getSysType]) - concreteUpperBound = concrete.n[1].intVal - # we must correct for the off-by-one discrepancy between - # ranges and static params: - replacementT.n = newIntNode(nkIntLit, concreteUpperBound + offset) - if tfInferrableStatic in staticT.flags: - staticT.n = replacementT.n - put(c.bindings, staticT, replacementT) - - if rangeHasUnresolvedStatic(fRange): - if tfUnresolved in fRange.flags: - # This is a range from an array instantiated with a generic - # static param. We must extract the static param here and bind - # it to the size of the currently supplied array. - inferStaticRange(c, fRange, aRange) - return isGeneric - - let len = tryResolvingStaticExpr(c, fRange.n[1]) - if len.kind == nkIntLit and len.intVal+1 == lengthOrd(a): - return # if we get this far, the result is already good - else: - return isNone + if fRange.rangeHasUnresolvedStatic: + return inferStaticsInRange(c, fRange, a) elif c.c.inTypeClass > 0 and aRange.rangeHasUnresolvedStatic: - inferStaticRange(c, aRange, fRange) - return isGeneric + return inferStaticsInRange(c, aRange, f) elif lengthOrd(fRange) != lengthOrd(a): result = isNone else: discard |