summary refs log tree commit diff stats
path: root/compiler
diff options
context:
space:
mode:
authorZahary Karadjov <zahary@gmail.com>2016-08-12 03:25:59 +0300
committerZahary Karadjov <zahary@gmail.com>2017-03-24 16:58:15 +0200
commit0f2c4be1299fc99aeea2011c57240c8cfabd83c3 (patch)
tree64fa8e58c3c9745b54c127f8fdc3b897ff3896a0 /compiler
parent0b0a3e5f203f6b21f3790a6cd50ceeaa8786badc (diff)
downloadNim-0f2c4be1299fc99aeea2011c57240c8cfabd83c3.tar.gz
infer static parameters even when more complicated arithmetic is involved
Diffstat (limited to 'compiler')
-rw-r--r--compiler/ast.nim12
-rw-r--r--compiler/msgs.nim2
-rw-r--r--compiler/semdata.nim15
-rw-r--r--compiler/semexprs.nim1
-rw-r--r--compiler/semstmts.nim2
-rw-r--r--compiler/semtypes.nim1
-rw-r--r--compiler/semtypinst.nim5
-rw-r--r--compiler/sigmatch.nim143
-rw-r--r--compiler/types.nim2
9 files changed, 137 insertions, 46 deletions
diff --git a/compiler/ast.nim b/compiler/ast.nim
index f13691d54..5adac92df 100644
--- a/compiler/ast.nim
+++ b/compiler/ast.nim
@@ -1581,7 +1581,7 @@ proc hasPattern*(s: PSym): bool {.inline.} =
   result = isRoutine(s) and s.ast.sons[patternPos].kind != nkEmpty
 
 iterator items*(n: PNode): PNode =
-  for i in 0.. <n.len: yield n.sons[i]
+  for i in 0.. <n.safeLen: yield n.sons[i]
 
 iterator pairs*(n: PNode): tuple[i: int, n: PNode] =
   for i in 0.. <n.len: yield (i, n.sons[i])
@@ -1624,6 +1624,16 @@ proc toObject*(typ: PType): PType =
   if result.kind == tyRef:
     result = result.lastSon
 
+proc findUnresolvedStatic*(n: PNode): PNode =
+  if n.kind == nkSym and n.typ.kind == tyStatic and n.typ.n == nil:
+    return n
+
+  for son in n:
+    let n = son.findUnresolvedStatic
+    if n != nil: return n
+
+  return nil
+
 when false:
   proc containsNil*(n: PNode): bool =
     # only for debugging
diff --git a/compiler/msgs.nim b/compiler/msgs.nim
index 2db3646b5..eb9986114 100644
--- a/compiler/msgs.nim
+++ b/compiler/msgs.nim
@@ -109,6 +109,7 @@ type
     errXCannotBeClosure, errXMustBeCompileTime,
     errCannotInferTypeOfTheLiteral,
     errCannotInferReturnType,
+    errCannotInferStaticParam,
     errGenericLambdaNotAllowed,
     errProcHasNoConcreteType,
     errCompilerDoesntSupportTarget,
@@ -373,6 +374,7 @@ const
     errXMustBeCompileTime: "'$1' can only be used in compile-time context",
     errCannotInferTypeOfTheLiteral: "cannot infer the type of the $1",
     errCannotInferReturnType: "cannot infer the return type of the proc",
+    errCannotInferStaticParam: "cannot infer the value of the static param `$1`",
     errGenericLambdaNotAllowed: "A nested proc can have generic parameters only when " &
                                 "it is used as an operand to another routine and the types " &
                                 "of the generic paramers can be inferred from the expected signature.",
diff --git a/compiler/semdata.nim b/compiler/semdata.nim
index 5cd755607..5644a2346 100644
--- a/compiler/semdata.nim
+++ b/compiler/semdata.nim
@@ -318,23 +318,14 @@ proc makeRangeWithStaticExpr*(c: PContext, n: PNode): PType =
   let intType = getSysType(tyInt)
   result = newTypeS(tyRange, c)
   result.sons = @[intType]
-  if n.typ.n == nil: result.flags.incl tfUnresolved
+  if n.typ != nil and n.typ.n == nil:
+    result.flags.incl tfUnresolved
   result.n = newNode(nkRange, n.info, @[
     newIntTypeNode(nkIntLit, 0, intType),
     makeStaticExpr(c, n.nMinusOne)])
 
 template rangeHasUnresolvedStatic*(t: PType): bool =
-  # this accepts the ranges's node
-  t.n != nil and t.n.len > 1 and t.n[1].kind == nkStaticExpr
-
-proc findUnresolvedStaticInRange*(t: PType): (PType, int) =
-  assert t.kind == tyRange
-  # XXX: This really needs to become more sophisticated
-  let upperBound = t.n[1]
-  if upperBound[0].kind == nkCall:
-    return (upperBound[0][1].typ, 1)
-  else:
-    return (upperBound.typ, 0)
+  tfUnresolved in t.flags
 
 proc errorType*(c: PContext): PType =
   ## creates a type representing an error state
diff --git a/compiler/semexprs.nim b/compiler/semexprs.nim
index 3ec2cd391..930c84368 100644
--- a/compiler/semexprs.nim
+++ b/compiler/semexprs.nim
@@ -627,6 +627,7 @@ proc evalAtCompileTime(c: PContext, n: PNode): PNode =
 
 proc semStaticExpr(c: PContext, n: PNode): PNode =
   let a = semExpr(c, n.sons[0])
+  if a.findUnresolvedStatic != nil: return a
   result = evalStaticExpr(c.module, c.cache, a, c.p.owner)
   if result.isNil:
     localError(n.info, errCannotInterpretNodeX, renderTree(n))
diff --git a/compiler/semstmts.nim b/compiler/semstmts.nim
index 64449dda4..09631c793 100644
--- a/compiler/semstmts.nim
+++ b/compiler/semstmts.nim
@@ -1621,7 +1621,7 @@ proc semStmtList(c: PContext, n: PNode, flags: TExprFlags): PNode =
             elif expr[2].typ.isUnresolvedStatic:
               inferConceptStaticParam(c, expr[2].typ, expr[1])
               continue
-
+            
           let verdict = semConstExpr(c, n[i])
           if verdict.intVal == 0:
             localError(result.info, "type class predicate failed")
diff --git a/compiler/semtypes.nim b/compiler/semtypes.nim
index bf6c24310..eef83c2a7 100644
--- a/compiler/semtypes.nim
+++ b/compiler/semtypes.nim
@@ -195,6 +195,7 @@ proc semRangeAux(c: PContext, n: PNode, prev: PType): PType =
   for i in 0..1:
     if hasGenericArguments(range[i]):
       result.n.addSon makeStaticExpr(c, range[i])
+      result.flags.incl tfUnresolved
     else:
       result.n.addSon semConstExpr(c, range[i])
 
diff --git a/compiler/semtypinst.nim b/compiler/semtypinst.nim
index 7e114afb8..9e72e46f6 100644
--- a/compiler/semtypinst.nim
+++ b/compiler/semtypinst.nim
@@ -445,7 +445,7 @@ proc replaceTypeVarsTAux(cl: var TReplTypeVars, t: PType): PType =
     elif t.sons[0].kind != tyNone:
       result = makeTypeDesc(cl.c, replaceTypeVarsT(cl, t.sons[0]))
 
-  of tyUserTypeClass:
+  of tyUserTypeClass, tyStatic:
     result = t
 
   of tyGenericInst, tyUserTypeClassInst:
@@ -502,8 +502,9 @@ proc initTypeVars*(p: PContext, pt: TIdTable, info: TLineInfo;
   result.owner = owner
 
 proc replaceTypesInBody*(p: PContext, pt: TIdTable, n: PNode;
-                         owner: PSym): PNode =
+                         owner: PSym, allowMetaTypes = false): PNode =
   var cl = initTypeVars(p, pt, n.info, owner)
+  cl.allowMetaTypes = allowMetaTypes
   pushInfoContext(n.info)
   result = replaceTypeVarsN(cl, n)
   popInfoContext()
diff --git a/compiler/sigmatch.nim b/compiler/sigmatch.nim
index 162385e6d..ca9cdcaf8 100644
--- a/compiler/sigmatch.nim
+++ b/compiler/sigmatch.nim
@@ -681,16 +681,125 @@ proc maybeSkipDistinct(t: PType, callee: PSym): PType =
   else:
     result = t
 
-proc tryResolvingStaticExpr(c: var TCandidate, n: PNode): PNode =
+proc tryResolvingStaticExpr(c: var TCandidate, n: PNode,
+                            allowUnresolved = false): PNode =
   # Consider this example:
   #   type Value[N: static[int]] = object
   #   proc foo[N](a: Value[N], r: range[0..(N-1)])
   # Here, N-1 will be initially nkStaticExpr that can be evaluated only after
   # N is bound to a concrete value during the matching of the first param.
   # This proc is used to evaluate such static expressions.
-  let instantiated = replaceTypesInBody(c.c, c.bindings, n, nil)
+  let instantiated = replaceTypesInBody(c.c, c.bindings, n, nil,
+                                        allowMetaTypes = allowUnresolved)
   result = c.c.semExpr(c.c, instantiated)
 
+proc inferStaticParam*(lhs: PNode, rhs: BiggestInt): PType =
+  # This is a simple integer arithimetic equation solver,
+  # capable of deriving the value of a static parameter in
+  # expressions such as (N + 5) / 2 = rhs
+  #
+  # Preconditions:
+  #
+  #   * The input of this proc must be semantized 
+  #     - all templates should be expanded
+  #     - aby constant folding possible should already be performed
+  #
+  #   * There must be exactly one unresolved static parameter
+  #
+  # Result:
+  #
+  #   The proc will return the inferred static type with the `n` field
+  #   populated with the inferred value.
+  #
+  #   `nil` will be returned if the inference was not possible
+  #
+  if lhs.kind in nkCallKinds and lhs[0].kind == nkSym:
+    case lhs[0].sym.magic
+    of mUnaryLt:
+      return inferStaticParam(lhs[1], rhs + 1)
+
+    of mAddI, mAddU, mInc, mSucc:
+      if lhs[1].kind == nkIntLit:
+        return inferStaticParam(lhs[2], rhs - lhs[1].intVal)
+      elif lhs[2].kind == nkIntLit:
+        return inferStaticParam(lhs[1], rhs - lhs[2].intVal)
+    
+    of mDec, mSubI, mSubU, mPred:
+      if lhs[1].kind == nkIntLit:
+        return inferStaticParam(lhs[2], lhs[1].intVal - rhs)
+      elif lhs[2].kind == nkIntLit:
+        return inferStaticParam(lhs[1], rhs + lhs[2].intVal)
+    
+    of mMulI, mMulU:
+      if lhs[1].kind == nkIntLit:
+        if rhs mod lhs[1].intVal == 0:
+          return inferStaticParam(lhs[2], rhs div lhs[1].intVal)
+      elif lhs[2].kind == nkIntLit:
+        if rhs mod lhs[2].intVal == 0:
+          return inferStaticParam(lhs[1], rhs div lhs[2].intVal)
+    
+    of mDivI, mDivU:
+      if lhs[1].kind == nkIntLit:
+        if lhs[1].intVal mod rhs == 0:
+          return inferStaticParam(lhs[2], lhs[1].intVal div rhs)
+      elif lhs[2].kind == nkIntLit:
+        return inferStaticParam(lhs[1], lhs[2].intVal * rhs)
+    
+    of mShlI:
+      if lhs[2].kind == nkIntLit:
+        return inferStaticParam(lhs[1], rhs shr lhs[2].intVal)
+    
+    of mShrI:
+      if lhs[2].kind == nkIntLit:
+        return inferStaticParam(lhs[1], rhs shl lhs[2].intVal)
+    
+    of mUnaryMinusI:
+      return inferStaticParam(lhs[1], -rhs)
+    
+    of mUnaryPlusI, mToInt, mToBiggestInt:
+      return inferStaticParam(lhs[1], rhs)
+    
+    else: discard
+  
+  elif lhs.kind == nkSym and lhs.typ.kind == tyStatic and lhs.typ.n == nil:
+    lhs.typ.n = newIntNode(nkIntLit, rhs)
+    return lhs.typ
+  
+  return nil
+
+proc failureToInferStaticParam(n: PNode) =
+  let staticParam = n.findUnresolvedStatic
+  let name = if staticParam != nil: staticParam.sym.name.s
+             else: "unknown"
+  localError(n.info, errCannotInferStaticParam, name)
+
+proc inferStaticsInRange(c: var TCandidate,
+                         inferred, concrete: PType): TTypeRelation =
+  let lowerBound = tryResolvingStaticExpr(c, inferred.n[0],
+                                          allowUnresolved = true)
+  let upperBound = tryResolvingStaticExpr(c, inferred.n[1],
+                                          allowUnresolved = true)
+  
+  template doInferStatic(c: var TCandidate, e: PNode, r: BiggestInt) =
+    var exp = e
+    var rhs = r
+    var inferred = inferStaticParam(exp, rhs)
+    if inferred != nil:
+      put(c.bindings, inferred, inferred)
+      return isGeneric
+    else:
+      failureToInferStaticParam exp
+
+  if lowerBound.kind == nkIntLit:
+    if upperBound.kind == nkIntLit:
+      if lengthOrd(concrete) == upperBound.intVal - lowerBound.intVal + 1:
+        return isGeneric
+      else:
+        return isNone
+    doInferStatic(c, upperBound, lengthOrd(concrete) + lowerBound.intVal - 1)
+  elif upperBound.kind == nkIntLit:
+    doInferStatic(c, lowerBound, upperBound.intVal + 1 - lengthOrd(concrete))
+
 template subtypeCheck() =
   if result <= isSubrange and f.lastSon.skipTypes(abstractInst).kind in {tyRef, tyPtr, tyVar}:
     result = isNone
@@ -894,34 +1003,10 @@ proc typeRel(c: var TCandidate, f, aOrig: PType, doBind = true): TTypeRelation =
                           a.sons[1].skipTypes({tyTypeDesc}))
       if result < isGeneric: return isNone
       
-      proc inferStaticRange(c: var TCandidate, inferred, concrete: PType) =
-        var (staticT, offset) = inferred.findUnresolvedStaticInRange
-        var
-          replacementT = newTypeWithSons(c.c, tyStatic, @[tyInt.getSysType])
-          concreteUpperBound = concrete.n[1].intVal
-        # we must correct for the off-by-one discrepancy between
-        # ranges and static params:
-        replacementT.n = newIntNode(nkIntLit, concreteUpperBound + offset)
-        if tfInferrableStatic in staticT.flags:
-          staticT.n = replacementT.n
-        put(c.bindings, staticT, replacementT)
-
-      if rangeHasUnresolvedStatic(fRange):
-        if tfUnresolved in fRange.flags:
-          # This is a range from an array instantiated with a generic
-          # static param. We must extract the static param here and bind
-          # it to the size of the currently supplied array.
-          inferStaticRange(c, fRange, aRange)
-          return isGeneric
-
-        let len = tryResolvingStaticExpr(c, fRange.n[1])
-        if len.kind == nkIntLit and len.intVal+1 == lengthOrd(a):
-          return # if we get this far, the result is already good
-        else:
-          return isNone
+      if fRange.rangeHasUnresolvedStatic:
+        return inferStaticsInRange(c, fRange, a)
       elif c.c.inTypeClass > 0 and aRange.rangeHasUnresolvedStatic:
-        inferStaticRange(c, aRange, fRange)
-        return isGeneric
+        return inferStaticsInRange(c, aRange, f)
       elif lengthOrd(fRange) != lengthOrd(a):
         result = isNone
     else: discard
diff --git a/compiler/types.nim b/compiler/types.nim
index be7028f9c..65eb6de61 100644
--- a/compiler/types.nim
+++ b/compiler/types.nim
@@ -63,7 +63,7 @@ const
   abstractVarRange* = {tyGenericInst, tyRange, tyVar, tyDistinct, tyOrdinal,
                        tyTypeDesc, tyAlias, tyInferred}
   abstractInst* = {tyGenericInst, tyDistinct, tyOrdinal, tyTypeDesc, tyAlias,
-                   tyInferred} + tyTypeClasses
+                   tyInferred}
   skipPtrs* = {tyVar, tyPtr, tyRef, tyGenericInst, tyTypeDesc, tyAlias,
                tyInferred}
   # typedescX is used if we're sure tyTypeDesc should be included (or skipped)
re>
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297