summary refs log tree commit diff stats
path: root/compiler/lambdalifting.nim
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/lambdalifting.nim')
-rw-r--r--compiler/lambdalifting.nim134
1 files changed, 86 insertions, 48 deletions
diff --git a/compiler/lambdalifting.nim b/compiler/lambdalifting.nim
index ed92fefb4..2189a1d67 100644
--- a/compiler/lambdalifting.nim
+++ b/compiler/lambdalifting.nim
@@ -116,9 +116,9 @@ type
   TDep = tuple[e: PEnv, field: PSym]
   TEnv {.final.} = object of TObject
     attachedNode: PNode
-    closure: PSym   # if != nil it is a used environment
+    closure: PSym           # if != nil it is a used environment
     capturedVars: seq[PSym] # captured variables in this environment
-    deps: seq[TDep] # dependencies
+    deps: seq[TDep]         # dependencies
     up: PEnv
     tup: PType
   
@@ -149,7 +149,19 @@ proc newInnerContext(fn: PSym): PInnerContext =
   new(result)
   result.fn = fn
   initIdNodeTable(result.localsToAccess)
-  
+
+proc getStateType(iter: PSym): PType =
+  var n = newNodeI(nkRange, iter.info)
+  addSon(n, newIntNode(nkIntLit, -1))
+  addSon(n, newIntNode(nkIntLit, 0))
+  result = newType(tyRange, iter)
+  result.n = n
+  rawAddSon(result, getSysType(tyInt))
+
+proc createStateField(iter: PSym): PSym =
+  result = newSym(skField, getIdent(":state"), iter, iter.info)
+  result.typ = getStateType(iter)
+
 proc newEnv(outerProc: PSym, up: PEnv, n: PNode): PEnv =
   new(result)
   result.deps = @[]
@@ -170,6 +182,9 @@ proc addField(tup: PType, s: PSym) =
 proc addCapturedVar(e: PEnv, v: PSym) =
   for x in e.capturedVars:
     if x == v: return
+  # XXX meh, just add the state field for every closure for now, it's too
+  # hard to figure out if it comes from a closure iterator:
+  if e.tup.len == 0: addField(e.tup, createStateField(v.owner))
   e.capturedVars.add(v)
   addField(e.tup, v)
   
@@ -189,6 +204,7 @@ proc indirectAccess(a: PNode, b: PSym, info: TLineInfo): PNode =
   # returns a[].b as a node
   var deref = newNodeI(nkHiddenDeref, info)
   deref.typ = a.typ.sons[0]
+  assert deref.typ.kind == tyTuple
   let field = getSymFromList(deref.typ.n, b.name)
   assert field != nil, b.name.s
   addSon(deref, a)
@@ -220,18 +236,30 @@ proc getHiddenParam(routine: PSym): PSym =
   assert hidden.kind == nkSym
   result = hidden.sym
 
+proc getEnvParam(routine: PSym): PSym =
+  let params = routine.ast.sons[paramsPos]
+  let hidden = lastSon(params)
+  if hidden.kind == nkSym and hidden.sym.name.s == paramName:
+    result = hidden.sym
+
 proc isInnerProc(s, outerProc: PSym): bool {.inline.} =
-  result = s.kind in {skProc, skMethod, skConverter} and 
+  result = (s.kind in {skProc, skMethod, skConverter} or
+            s.kind == skIterator and s.typ.callConv == ccClosure) and
            s.skipGenericOwner == outerProc
   #s.typ.callConv == ccClosure
 
 proc addClosureParam(i: PInnerContext, e: PEnv) =
-  var cp = newSym(skParam, getIdent(paramName), i.fn, i.fn.info)
-  incl(cp.flags, sfFromGeneric)
-  cp.typ = newType(tyRef, i.fn)
-  rawAddSon(cp.typ, e.tup)
+  var cp = getEnvParam(i.fn)
+  if cp == nil:
+    cp = newSym(skParam, getIdent(paramName), i.fn, i.fn.info)
+    incl(cp.flags, sfFromGeneric)
+    cp.typ = newType(tyRef, i.fn)
+    rawAddSon(cp.typ, e.tup)
+    addHiddenParam(i.fn, cp)
+  else:
+    e.tup = cp.typ.sons[0]
+    assert e.tup.kind == tyTuple
   i.closureParam = cp
-  addHiddenParam(i.fn, i.closureParam)
   #echo "closure param added for ", i.fn.name.s, " ", i.fn.id
 
 proc dummyClosureParam(o: POuterContext, i: PInnerContext) =
@@ -344,6 +372,7 @@ proc transformOuterConv(n: PNode): PNode =
 proc makeClosure(prc, env: PSym, info: TLineInfo): PNode =
   result = newNodeIT(nkClosure, info, prc.typ)
   result.add(newSymNode(prc))
+  if prc.kind == skIterator: incl(prc.flags, sfClosureCreated)
   if env == nil:
     result.add(newNodeIT(nkNilLit, info, getSysType(tyNil)))
   else:
@@ -366,10 +395,10 @@ proc transformInnerProc(o: POuterContext, i: PInnerContext, n: PNode): PNode =
     else:
       # captured symbol?
       result = idNodeTableGet(i.localsToAccess, n.sym)
-  of nkLambdaKinds:
-    result = transformInnerProc(o, i, n.sons[namePos])
-  of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef,
-     nkIteratorDef:
+  of nkLambdaKinds, nkIteratorDef:
+    if n.typ != nil:
+      result = transformInnerProc(o, i, n.sons[namePos])
+  of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef:
     # don't recurse here:
     discard
   else:
@@ -400,8 +429,9 @@ proc searchForInnerProcs(o: POuterContext, n: PNode) =
       if inner.closureParam != nil:
         let ti = transformInnerProc(o, inner, body)
         if ti != nil: n.sym.ast.sons[bodyPos] = ti
-  of nkLambdaKinds:
-    searchForInnerProcs(o, n.sons[namePos])
+  of nkLambdaKinds, nkIteratorDef:
+    if n.typ != nil:
+      searchForInnerProcs(o, n.sons[namePos])
   of nkWhileStmt, nkForStmt, nkParForStmt, nkBlockStmt:
     # some nodes open a new scope, so they are candidates for the insertion
     # of closure creation; however for simplicity we merge closures between
@@ -437,8 +467,7 @@ proc searchForInnerProcs(o: POuterContext, n: PNode) =
         searchForInnerProcs(o, it.sons[L-1])
       else:
         internalError(it.info, "transformOuter")
-  of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef, 
-     nkIteratorDef:
+  of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef:
     # don't recurse here:
     # XXX recurse here and setup 'up' pointers
     discard
@@ -535,10 +564,10 @@ proc transformOuterProc(o: POuterContext, n: PNode): PNode =
     assert result != nil, "cannot find: " & local.name.s
     # else it is captured by copy and this means that 'outer' should continue
     # to access the local as a local.
-  of nkLambdaKinds:
-    result = transformOuterProc(o, n.sons[namePos])
-  of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef, 
-     nkIteratorDef: 
+  of nkLambdaKinds, nkIteratorDef:
+    if n.typ != nil:
+      result = transformOuterProc(o, n.sons[namePos])
+  of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef:
     # don't recurse here:
     discard
   of nkHiddenStdConv, nkHiddenSubConv, nkConv:
@@ -607,11 +636,14 @@ type
     tup: PType
 
 proc newIterResult(iter: PSym): PSym =
-  result = iter.ast.sons[resultPos].sym
-  when false:
+  if resultPos < iter.ast.len:
+    result = iter.ast.sons[resultPos].sym
+  else:
+    # XXX a bit hacky:
     result = newSym(skResult, getIdent":result", iter, iter.info)
     result.typ = iter.typ.sons[0]
     incl(result.flags, sfUsed)
+    iter.ast.add newSymNode(result)
 
 proc interestingIterVar(s: PSym): bool {.inline.} =
   result = s.kind in {skVar, skLet, skTemp, skForVar} and sfGlobal notin s.flags
@@ -663,36 +695,40 @@ proc transfIterBody(c: var TIterContext, n: PNode): PNode =
       let x = transfIterBody(c, n.sons[i])
       if x != nil: n.sons[i] = x
 
-proc getStateType(iter: PSym): PType =
-  var n = newNodeI(nkRange, iter.info)
-  addSon(n, newIntNode(nkIntLit, -1))
-  addSon(n, newIntNode(nkIntLit, 0))
-  result = newType(tyRange, iter)
-  result.n = n
-  rawAddSon(result, getSysType(tyInt))
-
-proc liftIterator*(iter: PSym, body: PNode): PNode =
-  var c: TIterContext
+proc initIterContext(c: var TIterContext, iter: PSym) =
   c.iter = iter
   c.capturedVars = initIntSet()
 
-  c.tup = newType(tyTuple, iter)
-  c.tup.n = newNodeI(nkRecList, iter.info)
+  var cp = getEnvParam(iter)
+  if cp == nil:
+    c.tup = newType(tyTuple, iter)
+    c.tup.n = newNodeI(nkRecList, iter.info)
 
-  var cp = newSym(skParam, getIdent(paramName), iter, iter.info)
-  incl(cp.flags, sfFromGeneric)
-  cp.typ = newType(tyRef, iter)
-  rawAddSon(cp.typ, c.tup)
-  c.closureParam = cp
-  addHiddenParam(iter, cp)
+    cp = newSym(skParam, getIdent(paramName), iter, iter.info)
+    incl(cp.flags, sfFromGeneric)
+    cp.typ = newType(tyRef, iter)
+    rawAddSon(cp.typ, c.tup)
+    addHiddenParam(iter, cp)
 
-  c.state = newSym(skField, getIdent(":state"), iter, iter.info)
-  c.state.typ = getStateType(iter)
-  addField(c.tup, c.state)
+    c.state = createStateField(iter)
+    addField(c.tup, c.state)
+  else:
+    c.tup = cp.typ.sons[0]
+    assert c.tup.kind == tyTuple
+    if c.tup.len > 0:
+      c.state = c.tup.n[0].sym
+    else:
+      c.state = createStateField(iter)
+      addField(c.tup, c.state)
 
+  c.closureParam = cp
   if iter.typ.sons[0] != nil:
     c.resultSym = newIterResult(iter)
-    iter.ast.add(newSymNode(c.resultSym))
+    #iter.ast.add(newSymNode(c.resultSym))
+
+proc liftIterator*(iter: PSym, body: PNode): PNode =
+  var c: TIterContext
+  initIterContext c, iter
 
   result = newNodeI(nkStmtList, iter.info)
   var gs = newNodeI(nkGotoState, iter.info)
@@ -716,12 +752,14 @@ proc liftIterator*(iter: PSym, body: PNode): PNode =
 
 proc liftIterSym*(n: PNode): PNode =
   # transforms  (iter)  to  (let env = newClosure[iter](); (iter, env)) 
-  result = newNodeIT(nkStmtListExpr, n.info, n.typ)
   let iter = n.sym
   assert iter.kind == skIterator
+  if sfClosureCreated in iter.flags: return n
+  
+  result = newNodeIT(nkStmtListExpr, n.info, n.typ)
+  
   var env = copySym(getHiddenParam(iter))
   env.kind = skLet
-
   var v = newNodeI(nkVarSection, n.info)
   addVar(v, newSymNode(env))
   result.add(v)
@@ -766,7 +804,7 @@ proc liftForLoop*(body: PNode): PNode =
   # static binding?
   var env: PSym
   if call[0].kind == nkSym and call[0].sym.kind == skIterator:
-    # createClose()
+    # createClosure()
     let iter = call[0].sym
     assert iter.kind == skIterator
     env = copySym(getHiddenParam(iter))