diff options
-rw-r--r-- | compiler/sigmatch.nim | 48 | ||||
-rw-r--r-- | tests/generics/t5683.nim | 16 |
2 files changed, 39 insertions, 25 deletions
diff --git a/compiler/sigmatch.nim b/compiler/sigmatch.nim index 49478f5a7..50a55e860 100644 --- a/compiler/sigmatch.nim +++ b/compiler/sigmatch.nim @@ -739,7 +739,7 @@ proc tryResolvingStaticExpr(c: var TCandidate, n: PNode, allowMetaTypes = allowUnresolved) result = c.c.semExpr(c.c, instantiated) -proc inferStaticParam*(lhs: PNode, rhs: BiggestInt): PType = +proc inferStaticParam*(c: var TCandidate, lhs: PNode, rhs: BiggestInt): bool = # This is a simple integer arithimetic equation solver, # capable of deriving the value of a static parameter in # expressions such as (N + 5) / 2 = rhs @@ -754,64 +754,65 @@ proc inferStaticParam*(lhs: PNode, rhs: BiggestInt): PType = # # Result: # - # The proc will return the inferred static type with the `n` field - # populated with the inferred value. - # - # `nil` will be returned if the inference was not possible + # The proc will return true if the static types was successfully + # inferred. The result will be bound to the original static type + # in the TCandidate. # if lhs.kind in nkCallKinds and lhs[0].kind == nkSym: case lhs[0].sym.magic of mUnaryLt: - return inferStaticParam(lhs[1], rhs + 1) + return inferStaticParam(c, lhs[1], rhs + 1) of mAddI, mAddU, mInc, mSucc: if lhs[1].kind == nkIntLit: - return inferStaticParam(lhs[2], rhs - lhs[1].intVal) + return inferStaticParam(c, lhs[2], rhs - lhs[1].intVal) elif lhs[2].kind == nkIntLit: - return inferStaticParam(lhs[1], rhs - lhs[2].intVal) + return inferStaticParam(c, lhs[1], rhs - lhs[2].intVal) of mDec, mSubI, mSubU, mPred: if lhs[1].kind == nkIntLit: - return inferStaticParam(lhs[2], lhs[1].intVal - rhs) + return inferStaticParam(c, lhs[2], lhs[1].intVal - rhs) elif lhs[2].kind == nkIntLit: - return inferStaticParam(lhs[1], rhs + lhs[2].intVal) + return inferStaticParam(c, lhs[1], rhs + lhs[2].intVal) of mMulI, mMulU: if lhs[1].kind == nkIntLit: if rhs mod lhs[1].intVal == 0: - return inferStaticParam(lhs[2], rhs div lhs[1].intVal) + return inferStaticParam(c, lhs[2], rhs div lhs[1].intVal) elif lhs[2].kind == nkIntLit: if rhs mod lhs[2].intVal == 0: - return inferStaticParam(lhs[1], rhs div lhs[2].intVal) + return inferStaticParam(c, lhs[1], rhs div lhs[2].intVal) of mDivI, mDivU: if lhs[1].kind == nkIntLit: if lhs[1].intVal mod rhs == 0: - return inferStaticParam(lhs[2], lhs[1].intVal div rhs) + return inferStaticParam(c, lhs[2], lhs[1].intVal div rhs) elif lhs[2].kind == nkIntLit: - return inferStaticParam(lhs[1], lhs[2].intVal * rhs) + return inferStaticParam(c, lhs[1], lhs[2].intVal * rhs) of mShlI: if lhs[2].kind == nkIntLit: - return inferStaticParam(lhs[1], rhs shr lhs[2].intVal) + return inferStaticParam(c, lhs[1], rhs shr lhs[2].intVal) of mShrI: if lhs[2].kind == nkIntLit: - return inferStaticParam(lhs[1], rhs shl lhs[2].intVal) + return inferStaticParam(c, lhs[1], rhs shl lhs[2].intVal) of mUnaryMinusI: - return inferStaticParam(lhs[1], -rhs) + return inferStaticParam(c, lhs[1], -rhs) of mUnaryPlusI, mToInt, mToBiggestInt: - return inferStaticParam(lhs[1], rhs) + return inferStaticParam(c, lhs[1], rhs) else: discard elif lhs.kind == nkSym and lhs.typ.kind == tyStatic and lhs.typ.n == nil: - lhs.typ.n = newIntNode(nkIntLit, rhs) - return lhs.typ + var inferred = newTypeWithSons(c.c, tyStatic, lhs.typ.sons) + inferred.n = newIntNode(nkIntLit, rhs) + put(c, lhs.typ, inferred) + return true - return nil + return false proc failureToInferStaticParam(n: PNode) = let staticParam = n.findUnresolvedStatic @@ -825,13 +826,10 @@ proc inferStaticsInRange(c: var TCandidate, allowUnresolved = true) let upperBound = tryResolvingStaticExpr(c, inferred.n[1], allowUnresolved = true) - template doInferStatic(e: PNode, r: BiggestInt) = var exp = e var rhs = r - var inferred = inferStaticParam(exp, rhs) - if inferred != nil: - put(c, inferred, inferred) + if inferStaticParam(c, exp, rhs): return isGeneric else: failureToInferStaticParam exp diff --git a/tests/generics/t5683.nim b/tests/generics/t5683.nim index 08ec7f30d..38da52ec2 100644 --- a/tests/generics/t5683.nim +++ b/tests/generics/t5683.nim @@ -13,3 +13,19 @@ const ] echo "perm: ", a.perm, " det: ", a.det + +# This tests multiple instantiations of a generic +# proc involving static params: +type + Vector64*[N: static[int]] = ref array[N, float64] + Array64[N: static[int]] = array[N, float64] + +proc vector*[N: static[int]](xs: Array64[N]): Vector64[N] = + new result + for i in 0 .. < N: + result[i] = xs[i] + +let v1 = vector([1.0, 2.0, 3.0, 4.0, 5.0]) +let v2 = vector([1.0, 2.0, 3.0, 4.0, 5.0]) +let v3 = vector([1.0, 2.0, 3.0, 4.0]) + |