summary refs log tree commit diff stats
path: root/compiler
diff options
context:
space:
mode:
authorZahary Karadjov <zahary@gmail.com>2014-03-08 22:57:06 +0200
committerZahary Karadjov <zahary@gmail.com>2014-03-08 22:57:06 +0200
commit085b339b8b12267cb8ed555979db368e151c9ca4 (patch)
treeca2b34fe9cb94a72319892144a922221b66219d5 /compiler
parent2cbe46daff73987d819ea0ca4bc6ada919d531d4 (diff)
downloadNim-085b339b8b12267cb8ed555979db368e151c9ca4.tar.gz
implements higher-order inline iterators and return type inference for iterators
Diffstat (limited to 'compiler')
-rw-r--r--compiler/msgs.nim6
-rw-r--r--compiler/semexprs.nim16
-rw-r--r--compiler/seminst.nim2
-rw-r--r--compiler/semstmts.nim18
-rw-r--r--compiler/semtypes.nim28
-rw-r--r--compiler/semtypinst.nim7
-rw-r--r--compiler/sigmatch.nim19
-rw-r--r--compiler/transf.nim12
8 files changed, 81 insertions, 27 deletions
diff --git a/compiler/msgs.nim b/compiler/msgs.nim
index 66763e7f5..5bc490d14 100644
--- a/compiler/msgs.nim
+++ b/compiler/msgs.nim
@@ -96,8 +96,8 @@ type
     errOnlyACallOpCanBeDelegator, errUsingNoSymbol,
     errMacroBodyDependsOnGenericTypes,
     errDestructorNotGenericEnough,
-    
-    errXExpectsTwoArguments, 
+    errInlineIteratorsAsProcParams,
+    errXExpectsTwoArguments,
     errXExpectsObjectTypes, errXcanNeverBeOfThisSubtype, errTooManyIterations, 
     errCannotInterpretNodeX, errFieldXNotFound, errInvalidConversionFromTypeX, 
     errAssertionFailed, errCannotGenerateCodeForX, errXRequiresOneArgument, 
@@ -331,6 +331,8 @@ const
                                        "because the parameter '$1' has a generic type",
     errDestructorNotGenericEnough: "Destructor signarue is too specific. " &
                                    "A destructor must be associated will all instantiations of a generic type",
+    errInlineIteratorsAsProcParams: "inline iterators can be used as parameters only for " &
+                                    "templates, macros and other inline iterators",
     errXExpectsTwoArguments: "\'$1\' expects two arguments", 
     errXExpectsObjectTypes: "\'$1\' expects object types",
     errXcanNeverBeOfThisSubtype: "\'$1\' can never be of this subtype", 
diff --git a/compiler/semexprs.nim b/compiler/semexprs.nim
index 203a51816..c16ab9b87 100644
--- a/compiler/semexprs.nim
+++ b/compiler/semexprs.nim
@@ -1119,6 +1119,9 @@ proc asgnToResultVar(c: PContext, n, le, ri: PNode) {.inline.} =
       n.sons[0] = x # 'result[]' --> 'result'
       n.sons[1] = takeImplicitAddr(c, ri)
 
+template resultTypeIsInferrable(typ: PType): expr =
+  typ.isMetaType and typ.kind != tyTypeDesc
+
 proc semAsgn(c: PContext, n: PNode): PNode =
   checkSonsLen(n, 2)
   var a = n.sons[0]
@@ -1170,7 +1173,7 @@ proc semAsgn(c: PContext, n: PNode): PNode =
         if lhsIsResult: {efAllowDestructor} else: {})
     if lhsIsResult:
       n.typ = enforceVoidContext
-      if lhs.sym.typ.isMetaType and lhs.sym.typ.kind != tyTypeDesc:
+      if resultTypeIsInferrable(lhs.sym.typ):
         if cmpTypes(c, lhs.typ, rhs.typ) == isGeneric:
           internalAssert c.p.resultSym != nil
           lhs.typ = rhs.typ
@@ -1259,12 +1262,21 @@ proc semYield(c: PContext, n: PNode): PNode =
     localError(n.info, errYieldNotAllowedInTryStmt)
   elif n.sons[0].kind != nkEmpty:
     n.sons[0] = semExprWithType(c, n.sons[0]) # check for type compatibility:
-    var restype = c.p.owner.typ.sons[0]
+    var iterType = c.p.owner.typ
+    var restype = iterType.sons[0]
     if restype != nil:
       let adjustedRes = if c.p.owner.kind == skIterator: restype.base
                         else: restype
       n.sons[0] = fitNode(c, adjustedRes, n.sons[0])
       if n.sons[0].typ == nil: internalError(n.info, "semYield")
+      
+      if resultTypeIsInferrable(adjustedRes):
+        let inferred = n.sons[0].typ
+        if c.p.owner.kind == skIterator:
+          iterType.sons[0].sons[0] = inferred
+        else:
+          iterType.sons[0] = inferred
+      
       semYieldVarResult(c, n, adjustedRes)
     else:
       localError(n.info, errCannotReturnExpr)
diff --git a/compiler/seminst.nim b/compiler/seminst.nim
index 8faf1d21a..4bcfa7f15 100644
--- a/compiler/seminst.nim
+++ b/compiler/seminst.nim
@@ -20,7 +20,7 @@ proc instantiateGenericParamList(c: PContext, n: PNode, pt: TIdTable,
     if a.kind != nkSym: 
       internalError(a.info, "instantiateGenericParamList; no symbol")
     var q = a.sym
-    if q.typ.kind notin {tyTypeDesc, tyGenericParam, tyStatic}+tyTypeClasses:
+    if q.typ.kind notin {tyTypeDesc, tyGenericParam, tyStatic, tyIter}+tyTypeClasses:
       continue
     var s = newSym(skType, q.name, getCurrOwner(), q.info)
     s.flags = s.flags + {sfUsed, sfFromGeneric}
diff --git a/compiler/semstmts.nim b/compiler/semstmts.nim
index a11386966..edce7c9bd 100644
--- a/compiler/semstmts.nim
+++ b/compiler/semstmts.nim
@@ -661,10 +661,18 @@ proc semFor(c: PContext, n: PNode): PNode =
   openScope(c)
   n.sons[length-2] = semExprNoDeref(c, n.sons[length-2], {efWantIterator})
   var call = n.sons[length-2]
-  if call.kind in nkCallKinds and call.sons[0].typ.callConv == ccClosure:
+  let isCallExpr = call.kind in nkCallKinds
+  if isCallExpr and call.sons[0].sym.magic != mNone:
+    if call.sons[0].sym.magic == mOmpParFor:
+      result = semForVars(c, n)
+      result.kind = nkParForStmt
+    else:
+      result = semForFields(c, n, call.sons[0].sym.magic)
+  elif (isCallExpr and call.sons[0].typ.callConv == ccClosure) or
+      call.typ.kind == tyIter:
     # first class iterator:
     result = semForVars(c, n)
-  elif call.kind notin nkCallKinds or call.sons[0].kind != nkSym or
+  elif not isCallExpr or call.sons[0].kind != nkSym or
       call.sons[0].sym.kind notin skIterators:
     if length == 3:
       n.sons[length-2] = implicitIterator(c, "items", n.sons[length-2])
@@ -673,12 +681,6 @@ proc semFor(c: PContext, n: PNode): PNode =
     else:
       localError(n.sons[length-2].info, errIteratorExpected)
     result = semForVars(c, n)
-  elif call.sons[0].sym.magic != mNone:
-    if call.sons[0].sym.magic == mOmpParFor:
-      result = semForVars(c, n)
-      result.kind = nkParForStmt
-    else:
-      result = semForFields(c, n, call.sons[0].sym.magic)
   else:
     result = semForVars(c, n)
   # propagate any enforced VoidContext:
diff --git a/compiler/semtypes.nim b/compiler/semtypes.nim
index a619de7ff..875631505 100644
--- a/compiler/semtypes.nim
+++ b/compiler/semtypes.nim
@@ -721,7 +721,16 @@ proc liftParamType(c: PContext, procKind: TSymKind, genericParams: PNode,
                                   allowMetaTypes = true)
     result = newTypeWithSons(c, tyCompositeTypeClass, @[paramType, result])
     result = addImplicitGeneric(result)
-  
+
+  of tyIter:
+    if paramType.callConv == ccInline:
+      if procKind notin {skTemplate, skMacro, skIterator}:
+        localError(info, errInlineIteratorsAsProcParams)
+      if paramType.len == 1:
+        let lifted = liftingWalk(paramType.base)
+        if lifted != nil: paramType.sons[0] = lifted
+      result = addImplicitGeneric(paramType)
+
   of tyGenericInst:
     if paramType.lastSon.kind == tyUserTypeClass:
       var cp = copyType(paramType, getCurrOwner(), false)
@@ -852,7 +861,11 @@ proc semProcTypeNode(c: PContext, n, genericParams: PNode,
         if lifted != nil: r = lifted
         r.flags.incl tfRetType
       r = skipIntLit(r)
-      if kind == skIterator: r = newTypeWithSons(c, tyIter, @[r])
+      if kind == skIterator:
+        # see tchainediterators
+        # in cases like iterator foo(it: iterator): type(it)
+        # we don't need to change the return type to iter[T]
+        if not r.isInlineIterator: r = newTypeWithSons(c, tyIter, @[r])
       result.sons[0] = r
       res.typ = r
 
@@ -984,7 +997,8 @@ proc semTypeNode(c: PContext, n: PNode, prev: PType): PType =
   of nkTypeOfExpr:
     # for ``type(countup(1,3))``, see ``tests/ttoseq``.
     checkSonsLen(n, 1)
-    result = semExprWithType(c, n.sons[0], {efInTypeof}).typ.skipTypes({tyIter})
+    let typExpr = semExprWithType(c, n.sons[0], {efInTypeof})
+    result = typExpr.typ.skipTypes({tyIter})
   of nkPar: 
     if sonsLen(n) == 1: result = semTypeNode(c, n.sons[0], prev)
     else:
@@ -1103,8 +1117,12 @@ proc semTypeNode(c: PContext, n: PNode, prev: PType): PType =
       result = newConstraint(c, tyIter)
     else:
       result = semProcTypeWithScope(c, n, prev, skClosureIterator)
-      result.flags.incl(tfIterator)
-      result.callConv = ccClosure
+      if n.lastSon.kind == nkPragma and hasPragma(n.lastSon, wInline):
+        result.kind = tyIter
+        result.callConv = ccInline
+      else:
+        result.flags.incl(tfIterator)
+        result.callConv = ccClosure
   of nkProcTy:
     if n.sonsLen == 0:
       result = newConstraint(c, tyProc)
diff --git a/compiler/semtypinst.nim b/compiler/semtypinst.nim
index 22edc6e32..46ec31d90 100644
--- a/compiler/semtypinst.nim
+++ b/compiler/semtypinst.nim
@@ -305,6 +305,11 @@ proc skipIntLiteralParams(t: PType) =
     if skipped != p:
       t.sons[i] = skipped
       if i > 0: t.n.sons[i].sym.typ = skipped
+  
+  # when the typeof operator is used on a static input
+  # param, the results gets infected with static as well:
+  if t.sons[0] != nil and t.sons[0].kind == tyStatic:
+    t.sons[0] = t.sons[0].base
 
 proc propagateFieldFlags(t: PType, n: PNode) =
   # This is meant for objects and tuples
@@ -323,7 +328,7 @@ proc replaceTypeVarsTAux(cl: var TReplTypeVars, t: PType): PType =
   result = t
   if t == nil: return
 
-  if t.kind in {tyStatic, tyGenericParam} + tyTypeClasses:
+  if t.kind in {tyStatic, tyGenericParam, tyIter} + tyTypeClasses:
     let lookup = PType(idTableGet(cl.typeMap, t))
     if lookup != nil: return lookup
   
diff --git a/compiler/sigmatch.nim b/compiler/sigmatch.nim
index c0898ef26..19f10def8 100644
--- a/compiler/sigmatch.nim
+++ b/compiler/sigmatch.nim
@@ -1014,6 +1014,10 @@ proc localConvMatch(c: PContext, m: var TCandidate, f, a: PType,
       result.typ = getInstantiatedType(c, arg, m, base(f))
     m.baseTypeMatch = true
 
+proc isInlineIterator*(t: PType): bool =
+  result = t.kind == tyIter or
+          (t.kind == tyBuiltInTypeClass and t.base.kind == tyIter)
+
 proc paramTypesMatchAux(m: var TCandidate, f, argType: PType,
                         argSemantized, argOrig: PNode): PNode =
   var
@@ -1021,7 +1025,7 @@ proc paramTypesMatchAux(m: var TCandidate, f, argType: PType,
     arg = argSemantized
     argType = argType
     c = m.c
-    
+   
   if tfHasStatic in fMaybeStatic.flags:
     # XXX: When implicit statics are the default
     # this will be done earlier - we just have to
@@ -1060,7 +1064,14 @@ proc paramTypesMatchAux(m: var TCandidate, f, argType: PType,
       return arg.typ.n
     else:
       return argOrig
-  
+
+  if r != isNone and f.isInlineIterator:
+    var inlined = newTypeS(tyStatic, c)
+    inlined.sons = @[argType]
+    inlined.n = argSemantized
+    put(m.bindings, f, inlined)
+    return argSemantized
+
   case r
   of isConvertible:
     inc(m.convMatches)
@@ -1188,7 +1199,9 @@ proc prepareOperand(c: PContext; formal: PType; a: PNode): PNode =
     # a.typ == nil is valid
     result = a
   elif a.typ.isNil:
-    result = c.semOperand(c, a, {efDetermineType})
+    let flags = if formal.kind == tyIter: {efDetermineType, efWantIterator}
+                else: {efDetermineType}
+    result = c.semOperand(c, a, flags)
   else:
     result = a
 
diff --git a/compiler/transf.nim b/compiler/transf.nim
index f4b716c5b..9586398c9 100644
--- a/compiler/transf.nim
+++ b/compiler/transf.nim
@@ -425,7 +425,7 @@ proc findWrongOwners(c: PTransf, n: PNode) =
         x.sym.owner.name.s & " " & getCurrOwner(c).name.s)
   else:
     for i in 0 .. <safeLen(n): findWrongOwners(c, n.sons[i])
-  
+
 proc transformFor(c: PTransf, n: PNode): PTransNode = 
   # generate access statements for the parameters (unless they are constant)
   # put mapping from formal parameters to actual parameters
@@ -433,12 +433,13 @@ proc transformFor(c: PTransf, n: PNode): PTransNode =
 
   var length = sonsLen(n)
   var call = n.sons[length - 2]
-  if call.kind notin nkCallKinds or call.sons[0].kind != nkSym or 
-      call.sons[0].sym.kind != skIterator:
+  if call.typ.kind != tyIter and
+    (call.kind notin nkCallKinds or call.sons[0].kind != nkSym or 
+      call.sons[0].sym.kind != skIterator):
     n.sons[length-1] = transformLoopBody(c, n.sons[length-1]).PNode
     return lambdalifting.liftForLoop(n).PTransNode
     #InternalError(call.info, "transformFor")
-
+  
   #echo "transforming: ", renderTree(n)
   result = newTransNode(nkStmtList, n.info, 0)
   var loopBody = transformLoopBody(c, n.sons[length-1])
@@ -459,6 +460,7 @@ proc transformFor(c: PTransf, n: PNode): PTransNode =
   for i in countup(1, sonsLen(call) - 1): 
     var arg = transform(c, call.sons[i]).PNode
     var formal = skipTypes(iter.typ, abstractInst).n.sons[i].sym 
+    if arg.typ.kind == tyIter: continue
     case putArgInto(arg, formal.typ)
     of paDirectMapping: 
       idNodeTablePut(newC.mapping, formal, arg)
@@ -480,7 +482,7 @@ proc transformFor(c: PTransf, n: PNode): PTransNode =
   dec(c.inlining)
   popInfoContext()
   popTransCon(c)
-  #echo "transformed: ", renderTree(n)
+  # echo "transformed: ", result.PNode.renderTree
   
 proc getMagicOp(call: PNode): TMagic = 
   if call.sons[0].kind == nkSym and