summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--compiler/sigmatch.nim48
-rw-r--r--tests/generics/t5683.nim16
2 files changed, 39 insertions, 25 deletions
diff --git a/compiler/sigmatch.nim b/compiler/sigmatch.nim
index 49478f5a7..50a55e860 100644
--- a/compiler/sigmatch.nim
+++ b/compiler/sigmatch.nim
@@ -739,7 +739,7 @@ proc tryResolvingStaticExpr(c: var TCandidate, n: PNode,
                                         allowMetaTypes = allowUnresolved)
   result = c.c.semExpr(c.c, instantiated)
 
-proc inferStaticParam*(lhs: PNode, rhs: BiggestInt): PType =
+proc inferStaticParam*(c: var TCandidate, lhs: PNode, rhs: BiggestInt): bool =
   # This is a simple integer arithimetic equation solver,
   # capable of deriving the value of a static parameter in
   # expressions such as (N + 5) / 2 = rhs
@@ -754,64 +754,65 @@ proc inferStaticParam*(lhs: PNode, rhs: BiggestInt): PType =
   #
   # Result:
   #
-  #   The proc will return the inferred static type with the `n` field
-  #   populated with the inferred value.
-  #
-  #   `nil` will be returned if the inference was not possible
+  #   The proc will return true if the static types was successfully
+  #   inferred. The result will be bound to the original static type
+  #   in the TCandidate.
   #
   if lhs.kind in nkCallKinds and lhs[0].kind == nkSym:
     case lhs[0].sym.magic
     of mUnaryLt:
-      return inferStaticParam(lhs[1], rhs + 1)
+      return inferStaticParam(c, lhs[1], rhs + 1)
 
     of mAddI, mAddU, mInc, mSucc:
       if lhs[1].kind == nkIntLit:
-        return inferStaticParam(lhs[2], rhs - lhs[1].intVal)
+        return inferStaticParam(c, lhs[2], rhs - lhs[1].intVal)
       elif lhs[2].kind == nkIntLit:
-        return inferStaticParam(lhs[1], rhs - lhs[2].intVal)
+        return inferStaticParam(c, lhs[1], rhs - lhs[2].intVal)
 
     of mDec, mSubI, mSubU, mPred:
       if lhs[1].kind == nkIntLit:
-        return inferStaticParam(lhs[2], lhs[1].intVal - rhs)
+        return inferStaticParam(c, lhs[2], lhs[1].intVal - rhs)
       elif lhs[2].kind == nkIntLit:
-        return inferStaticParam(lhs[1], rhs + lhs[2].intVal)
+        return inferStaticParam(c, lhs[1], rhs + lhs[2].intVal)
 
     of mMulI, mMulU:
       if lhs[1].kind == nkIntLit:
         if rhs mod lhs[1].intVal == 0:
-          return inferStaticParam(lhs[2], rhs div lhs[1].intVal)
+          return inferStaticParam(c, lhs[2], rhs div lhs[1].intVal)
       elif lhs[2].kind == nkIntLit:
         if rhs mod lhs[2].intVal == 0:
-          return inferStaticParam(lhs[1], rhs div lhs[2].intVal)
+          return inferStaticParam(c, lhs[1], rhs div lhs[2].intVal)
 
     of mDivI, mDivU:
       if lhs[1].kind == nkIntLit:
         if lhs[1].intVal mod rhs == 0:
-          return inferStaticParam(lhs[2], lhs[1].intVal div rhs)
+          return inferStaticParam(c, lhs[2], lhs[1].intVal div rhs)
       elif lhs[2].kind == nkIntLit:
-        return inferStaticParam(lhs[1], lhs[2].intVal * rhs)
+        return inferStaticParam(c, lhs[1], lhs[2].intVal * rhs)
 
     of mShlI:
       if lhs[2].kind == nkIntLit:
-        return inferStaticParam(lhs[1], rhs shr lhs[2].intVal)
+        return inferStaticParam(c, lhs[1], rhs shr lhs[2].intVal)
 
     of mShrI:
       if lhs[2].kind == nkIntLit:
-        return inferStaticParam(lhs[1], rhs shl lhs[2].intVal)
+        return inferStaticParam(c, lhs[1], rhs shl lhs[2].intVal)
 
     of mUnaryMinusI:
-      return inferStaticParam(lhs[1], -rhs)
+      return inferStaticParam(c, lhs[1], -rhs)
 
     of mUnaryPlusI, mToInt, mToBiggestInt:
-      return inferStaticParam(lhs[1], rhs)
+      return inferStaticParam(c, lhs[1], rhs)
 
     else: discard
 
   elif lhs.kind == nkSym and lhs.typ.kind == tyStatic and lhs.typ.n == nil:
-    lhs.typ.n = newIntNode(nkIntLit, rhs)
-    return lhs.typ
+    var inferred = newTypeWithSons(c.c, tyStatic, lhs.typ.sons)
+    inferred.n = newIntNode(nkIntLit, rhs)
+    put(c, lhs.typ, inferred)
+    return true
 
-  return nil
+  return false
 
 proc failureToInferStaticParam(n: PNode) =
   let staticParam = n.findUnresolvedStatic
@@ -825,13 +826,10 @@ proc inferStaticsInRange(c: var TCandidate,
                                           allowUnresolved = true)
   let upperBound = tryResolvingStaticExpr(c, inferred.n[1],
                                           allowUnresolved = true)
-
   template doInferStatic(e: PNode, r: BiggestInt) =
     var exp = e
     var rhs = r
-    var inferred = inferStaticParam(exp, rhs)
-    if inferred != nil:
-      put(c, inferred, inferred)
+    if inferStaticParam(c, exp, rhs):
       return isGeneric
     else:
       failureToInferStaticParam exp
diff --git a/tests/generics/t5683.nim b/tests/generics/t5683.nim
index 08ec7f30d..38da52ec2 100644
--- a/tests/generics/t5683.nim
+++ b/tests/generics/t5683.nim
@@ -13,3 +13,19 @@ const
       ]
 
 echo "perm: ", a.perm, " det: ", a.det
+
+# This tests multiple instantiations of a generic
+# proc involving static params:
+type
+  Vector64*[N: static[int]] = ref array[N, float64]
+  Array64[N: static[int]] = array[N, float64]
+
+proc vector*[N: static[int]](xs: Array64[N]): Vector64[N] =
+  new result
+  for i in 0 .. < N:
+    result[i] = xs[i]
+
+let v1 = vector([1.0, 2.0, 3.0, 4.0, 5.0])
+let v2 = vector([1.0, 2.0, 3.0, 4.0, 5.0])
+let v3 = vector([1.0, 2.0, 3.0, 4.0])
+