summary refs log tree commit diff stats
path: root/compiler
diff options
context:
space:
mode:
authorAraq <rumpf_a@web.de>2014-01-23 01:41:26 +0100
committerAraq <rumpf_a@web.de>2014-01-23 01:41:26 +0100
commit3f87326247b142df4eff99a92c6529b33bb79b81 (patch)
tree632dc70d2d73e51b97fd9830a9a7ff42014df412 /compiler
parent37229df7fc044fe108d2f4d88f127141cabeb6a6 (diff)
downloadNim-3f87326247b142df4eff99a92c6529b33bb79b81.tar.gz
closure iterators almost work
Diffstat (limited to 'compiler')
-rw-r--r--compiler/lambdalifting.nim368
-rw-r--r--compiler/transf.nim12
2 files changed, 189 insertions, 191 deletions
diff --git a/compiler/lambdalifting.nim b/compiler/lambdalifting.nim
index 2189a1d67..352b40693 100644
--- a/compiler/lambdalifting.nim
+++ b/compiler/lambdalifting.nim
@@ -130,12 +130,99 @@ type
   TOuterContext {.final.} = object
     fn: PSym # may also be a module!
     currentEnv: PEnv
+    isIter: bool   # first class iterator?
     capturedVars, processed: TIntSet
     localsToEnv: TIdTable # PSym->PEnv mapping
     localsToAccess: TIdNodeTable
     lambdasToEnv: TIdTable # PSym->PEnv mapping
     up: POuterContext
 
+    closureParam, state, resultSym: PSym # only if isIter
+    tup: PType # only if isIter
+
+
+proc getStateType(iter: PSym): PType =
+  var n = newNodeI(nkRange, iter.info)
+  addSon(n, newIntNode(nkIntLit, -1))
+  addSon(n, newIntNode(nkIntLit, 0))
+  result = newType(tyRange, iter)
+  result.n = n
+  rawAddSon(result, getSysType(tyInt))
+
+proc createStateField(iter: PSym): PSym =
+  result = newSym(skField, getIdent(":state"), iter, iter.info)
+  result.typ = getStateType(iter)
+
+proc newIterResult(iter: PSym): PSym =
+  if resultPos < iter.ast.len:
+    result = iter.ast.sons[resultPos].sym
+  else:
+    # XXX a bit hacky:
+    result = newSym(skResult, getIdent":result", iter, iter.info)
+    result.typ = iter.typ.sons[0]
+    incl(result.flags, sfUsed)
+    iter.ast.add newSymNode(result)
+
+proc addHiddenParam(routine: PSym, param: PSym) =
+  var params = routine.ast.sons[paramsPos]
+  # -1 is correct here as param.position is 0 based but we have at position 0
+  # some nkEffect node:
+  param.position = params.len-1
+  addSon(params, newSymNode(param))
+  incl(routine.typ.flags, tfCapturesEnv)
+  #echo "produced environment: ", param.id, " for ", routine.name.s
+
+proc getHiddenParam(routine: PSym): PSym =
+  let params = routine.ast.sons[paramsPos]
+  let hidden = lastSon(params)
+  assert hidden.kind == nkSym
+  result = hidden.sym
+
+proc getEnvParam(routine: PSym): PSym =
+  let params = routine.ast.sons[paramsPos]
+  let hidden = lastSon(params)
+  if hidden.kind == nkSym and hidden.sym.name.s == paramName:
+    result = hidden.sym
+    
+proc addField(tup: PType, s: PSym) =
+  var field = newSym(skField, s.name, s.owner, s.info)
+  let t = skipIntLit(s.typ)
+  field.typ = t
+  field.position = sonsLen(tup)
+  addSon(tup.n, newSymNode(field))
+  rawAddSon(tup, t)
+
+proc initIterContext(c: POuterContext, iter: PSym) =
+  c.fn = iter
+  c.capturedVars = initIntSet()
+
+  var cp = getEnvParam(iter)
+  if cp == nil:
+    c.tup = newType(tyTuple, iter)
+    c.tup.n = newNodeI(nkRecList, iter.info)
+
+    cp = newSym(skParam, getIdent(paramName), iter, iter.info)
+    incl(cp.flags, sfFromGeneric)
+    cp.typ = newType(tyRef, iter)
+    rawAddSon(cp.typ, c.tup)
+    addHiddenParam(iter, cp)
+
+    c.state = createStateField(iter)
+    addField(c.tup, c.state)
+  else:
+    c.tup = cp.typ.sons[0]
+    assert c.tup.kind == tyTuple
+    if c.tup.len > 0:
+      c.state = c.tup.n[0].sym
+    else:
+      c.state = createStateField(iter)
+      addField(c.tup, c.state)
+
+  c.closureParam = cp
+  if iter.typ.sons[0] != nil:
+    c.resultSym = newIterResult(iter)
+    #iter.ast.add(newSymNode(c.resultSym))
+
 proc newOuterContext(fn: PSym, up: POuterContext = nil): POuterContext =
   new(result)
   result.fn = fn
@@ -144,24 +231,14 @@ proc newOuterContext(fn: PSym, up: POuterContext = nil): POuterContext =
   initIdNodeTable(result.localsToAccess)
   initIdTable(result.localsToEnv)
   initIdTable(result.lambdasToEnv)
+  result.isIter = fn.kind == skIterator and fn.typ.callConv == ccClosure
+  if result.isIter: initIterContext(result, fn)
   
 proc newInnerContext(fn: PSym): PInnerContext =
   new(result)
   result.fn = fn
   initIdNodeTable(result.localsToAccess)
 
-proc getStateType(iter: PSym): PType =
-  var n = newNodeI(nkRange, iter.info)
-  addSon(n, newIntNode(nkIntLit, -1))
-  addSon(n, newIntNode(nkIntLit, 0))
-  result = newType(tyRange, iter)
-  result.n = n
-  rawAddSon(result, getSysType(tyInt))
-
-proc createStateField(iter: PSym): PSym =
-  result = newSym(skField, getIdent(":state"), iter, iter.info)
-  result.typ = getStateType(iter)
-
 proc newEnv(outerProc: PSym, up: PEnv, n: PNode): PEnv =
   new(result)
   result.deps = @[]
@@ -171,14 +248,6 @@ proc newEnv(outerProc: PSym, up: PEnv, n: PNode): PEnv =
   result.up = up
   result.attachedNode = n
 
-proc addField(tup: PType, s: PSym) =
-  var field = newSym(skField, s.name, s.owner, s.info)
-  let t = skipIntLit(s.typ)
-  field.typ = t
-  field.position = sonsLen(tup)
-  addSon(tup.n, newSymNode(field))
-  rawAddSon(tup, t)
-  
 proc addCapturedVar(e: PEnv, v: PSym) =
   for x in e.capturedVars:
     if x == v: return
@@ -221,27 +290,6 @@ proc newCall(a, b: PSym): PNode =
   result.add newSymNode(a)
   result.add newSymNode(b)
 
-proc addHiddenParam(routine: PSym, param: PSym) =
-  var params = routine.ast.sons[paramsPos]
-  # -1 is correct here as param.position is 0 based but we have at position 0
-  # some nkEffect node:
-  param.position = params.len-1
-  addSon(params, newSymNode(param))
-  incl(routine.typ.flags, tfCapturesEnv)
-  #echo "produced environment: ", param.id, " for ", routine.name.s
-
-proc getHiddenParam(routine: PSym): PSym =
-  let params = routine.ast.sons[paramsPos]
-  let hidden = lastSon(params)
-  assert hidden.kind == nkSym
-  result = hidden.sym
-
-proc getEnvParam(routine: PSym): PSym =
-  let params = routine.ast.sons[paramsPos]
-  let hidden = lastSon(params)
-  if hidden.kind == nkSym and hidden.sym.name.s == paramName:
-    result = hidden.sym
-
 proc isInnerProc(s, outerProc: PSym): bool {.inline.} =
   result = (s.kind in {skProc, skMethod, skConverter} or
             s.kind == skIterator and s.typ.callConv == ccClosure) and
@@ -334,7 +382,9 @@ proc gatherVars(o: POuterContext, i: PInnerContext, n: PNode) =
     var s = n.sym
     if interestingVar(s) and i.fn.id != s.owner.id:
       captureVar(o, i, s, n.info)
-    elif isInnerProc(s, o.fn) and tfCapturesEnv in s.typ.flags and s != i.fn:
+    elif s.kind in {skProc, skMethod, skConverter} and
+            s.skipGenericOwner == o.fn and 
+            tfCapturesEnv in s.typ.flags and s != i.fn:
       # call to some other inner proc; we need to track the dependencies for
       # this:
       let env = PEnv(idTableGet(o.lambdasToEnv, i.fn))
@@ -342,7 +392,7 @@ proc gatherVars(o: POuterContext, i: PInnerContext, n: PNode) =
       if o.currentEnv != env:
         discard addDep(o.currentEnv, env, i.fn)
         internalError(n.info, "too complex environment handling required")
-  of nkEmpty..pred(nkSym), succ(nkSym)..nkNilLit: discard
+  of nkEmpty..pred(nkSym), succ(nkSym)..nkNilLit, nkClosure: discard
   else:
     for k in countup(0, sonsLen(n) - 1): 
       gatherVars(o, i, n.sons[k])
@@ -398,7 +448,8 @@ proc transformInnerProc(o: POuterContext, i: PInnerContext, n: PNode): PNode =
   of nkLambdaKinds, nkIteratorDef:
     if n.typ != nil:
       result = transformInnerProc(o, i, n.sons[namePos])
-  of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef:
+  of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef,
+      nkClosure:
     # don't recurse here:
     discard
   else:
@@ -467,7 +518,8 @@ proc searchForInnerProcs(o: POuterContext, n: PNode) =
         searchForInnerProcs(o, it.sons[L-1])
       else:
         internalError(it.info, "transformOuter")
-  of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef:
+  of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef, 
+     nkClosure:
     # don't recurse here:
     # XXX recurse here and setup 'up' pointers
     discard
@@ -526,12 +578,61 @@ proc generateClosureCreation(o: POuterContext, scope: PEnv): PNode =
     result.add(newAsgnStmt(indirectAccess(env, field, env.info),
                newSymNode(getClosureVar(o, e)), env.info))
 
+proc interestingIterVar(s: PSym): bool {.inline.} =
+  result = s.kind in {skVar, skLet, skTemp, skForVar} and sfGlobal notin s.flags
+
+proc transformOuterProc(o: POuterContext, n: PNode): PNode
+
+proc transformYield(c: POuterContext, n: PNode): PNode =
+  inc c.state.typ.n.sons[1].intVal
+  let stateNo = c.state.typ.n.sons[1].intVal
+
+  var stateAsgnStmt = newNodeI(nkAsgn, n.info)
+  stateAsgnStmt.add(indirectAccess(newSymNode(c.closureParam),c.state,n.info))
+  stateAsgnStmt.add(newIntTypeNode(nkIntLit, stateNo, getSysType(tyInt)))
+
+  var retStmt = newNodeI(nkReturnStmt, n.info)
+  if n.sons[0].kind != nkEmpty:
+    var a = newNodeI(nkAsgn, n.sons[0].info)
+    var retVal = transformOuterProc(c, n.sons[0])
+    addSon(a, newSymNode(c.resultSym))
+    addSon(a, if retVal.isNil: n.sons[0] else: retVal)
+    retStmt.add(a)
+  else:
+    retStmt.add(emptyNode)
+  
+  var stateLabelStmt = newNodeI(nkState, n.info)
+  stateLabelStmt.add(newIntTypeNode(nkIntLit, stateNo, getSysType(tyInt)))
+  
+  result = newNodeI(nkStmtList, n.info)
+  result.add(stateAsgnStmt)
+  result.add(retStmt)
+  result.add(stateLabelStmt)
+
+proc transformReturn(c: POuterContext, n: PNode): PNode =
+  result = newNodeI(nkStmtList, n.info)
+  var stateAsgnStmt = newNodeI(nkAsgn, n.info)
+  stateAsgnStmt.add(indirectAccess(newSymNode(c.closureParam),c.state,n.info))
+  stateAsgnStmt.add(newIntTypeNode(nkIntLit, -1, getSysType(tyInt)))
+  result.add(stateAsgnStmt)
+  result.add(n)
+
+proc outerProcSons(o: POuterContext, n: PNode) =
+  for i in countup(0, sonsLen(n) - 1):
+    let x = transformOuterProc(o, n.sons[i])
+    if x != nil: n.sons[i] = x
+
 proc transformOuterProc(o: POuterContext, n: PNode): PNode =
   if n == nil: return nil
   case n.kind
   of nkEmpty..pred(nkSym), succ(nkSym)..nkNilLit: discard
   of nkSym:
     var local = n.sym
+
+    if o.isIter and interestingIterVar(local) and o.fn.id == local.owner.id:
+      if not containsOrIncl(o.capturedVars, local.id): addField(o.tup, local)
+      return indirectAccess(newSymNode(o.closureParam), local, n.info)
+
     var closure = PEnv(idTableGet(o.lambdasToEnv, local))
     if closure != nil:
       # we need to replace the lambda with '(lambda, env)': 
@@ -567,17 +668,44 @@ proc transformOuterProc(o: POuterContext, n: PNode): PNode =
   of nkLambdaKinds, nkIteratorDef:
     if n.typ != nil:
       result = transformOuterProc(o, n.sons[namePos])
-  of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef:
+  of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef,
+      nkClosure:
     # don't recurse here:
     discard
   of nkHiddenStdConv, nkHiddenSubConv, nkConv:
     let x = transformOuterProc(o, n.sons[1])
     if x != nil: n.sons[1] = x
     result = transformOuterConv(n)
+  of nkYieldStmt:
+    if o.isIter: result = transformYield(o, n)
+    else: outerProcSons(o, n)
+  of nkReturnStmt:
+    if o.isIter: result = transformReturn(o, n)
+    else: outerProcSons(o, n)
   else:
-    for i in countup(0, sonsLen(n) - 1):
-      let x = transformOuterProc(o, n.sons[i])
-      if x != nil: n.sons[i] = x
+    outerProcSons(o, n)
+
+proc liftIterator(c: POuterContext, body: PNode): PNode =
+  let iter = c.fn
+  result = newNodeI(nkStmtList, iter.info)
+  var gs = newNodeI(nkGotoState, iter.info)
+  gs.add(indirectAccess(newSymNode(c.closureParam), c.state, iter.info))
+  result.add(gs)
+  var state0 = newNodeI(nkState, iter.info)
+  state0.add(newIntNode(nkIntLit, 0))
+  result.add(state0)
+  
+  let newBody = transformOuterProc(c, body)
+  if newBody != nil:
+    result.add(newBody)
+  else:
+    result.add(body)
+
+  var stateAsgnStmt = newNodeI(nkAsgn, iter.info)
+  stateAsgnStmt.add(indirectAccess(newSymNode(c.closureParam),
+                    c.state,iter.info))
+  stateAsgnStmt.add(newIntTypeNode(nkIntLit, -1, getSysType(tyInt)))
+  result.add(stateAsgnStmt)
 
 proc liftLambdas*(fn: PSym, body: PNode): PNode =
   # XXX gCmd == cmdCompileToJS does not suffice! The compiletime stuff needs
@@ -601,8 +729,11 @@ proc liftLambdas*(fn: PSym, body: PNode): PNode =
     if resultPos < sonsLen(ast) and ast.sons[resultPos].kind == nkSym:
       idTablePut(o.localsToEnv, ast.sons[resultPos].sym, o.currentEnv)
     searchForInnerProcs(o, body)
-    discard transformOuterProc(o, body)
-    result = ex
+    if o.isIter:
+      result = liftIterator(o, ex)
+    else:
+      discard transformOuterProc(o, body)
+      result = ex
 
 proc liftLambdasForTopLevel*(module: PSym, body: PNode): PNode =
   if body.kind == nkEmpty or gCmd == cmdCompileToJS:
@@ -617,144 +748,11 @@ proc liftLambdasForTopLevel*(module: PSym, body: PNode): PNode =
 
 # ------------------- iterator transformation --------------------------------
 
-discard """
-  iterator chain[S, T](a, b: *S->T, args: *S): T =
-    for x in a(args): yield x
-    for x in b(args): yield x
-
-  let c = chain(f, g)
-  for x in c: echo x
-  
-  # translated to:
-  let c = chain( (f, newClosure(f)), (g, newClosure(g)), newClosure(chain))
-"""
-
-type
-  TIterContext {.final, pure.} = object
-    iter, closureParam, state, resultSym: PSym
-    capturedVars: TIntSet
-    tup: PType
-
-proc newIterResult(iter: PSym): PSym =
-  if resultPos < iter.ast.len:
-    result = iter.ast.sons[resultPos].sym
-  else:
-    # XXX a bit hacky:
-    result = newSym(skResult, getIdent":result", iter, iter.info)
-    result.typ = iter.typ.sons[0]
-    incl(result.flags, sfUsed)
-    iter.ast.add newSymNode(result)
-
-proc interestingIterVar(s: PSym): bool {.inline.} =
-  result = s.kind in {skVar, skLet, skTemp, skForVar} and sfGlobal notin s.flags
-
-proc transfIterBody(c: var TIterContext, n: PNode): PNode =
-  # gather used vars for closure generation
-  if n == nil: return nil
-  case n.kind
-  of nkSym:
-    var s = n.sym
-    if interestingIterVar(s) and c.iter.id == s.owner.id:
-      if not containsOrIncl(c.capturedVars, s.id): addField(c.tup, s)
-      result = indirectAccess(newSymNode(c.closureParam), s, n.info)
-  of nkEmpty..pred(nkSym), succ(nkSym)..nkNilLit: discard
-  of nkYieldStmt:
-    inc c.state.typ.n.sons[1].intVal
-    let stateNo = c.state.typ.n.sons[1].intVal
-
-    var stateAsgnStmt = newNodeI(nkAsgn, n.info)
-    stateAsgnStmt.add(indirectAccess(newSymNode(c.closureParam),c.state,n.info))
-    stateAsgnStmt.add(newIntTypeNode(nkIntLit, stateNo, getSysType(tyInt)))
-
-    var retStmt = newNodeI(nkReturnStmt, n.info)
-    if n.sons[0].kind != nkEmpty:
-      var a = newNodeI(nkAsgn, n.sons[0].info)
-      var retVal = transfIterBody(c, n.sons[0])
-      addSon(a, newSymNode(c.resultSym))
-      addSon(a, if retVal.isNil: n.sons[0] else: retVal)
-      retStmt.add(a)
-    else:
-      retStmt.add(emptyNode)
-    
-    var stateLabelStmt = newNodeI(nkState, n.info)
-    stateLabelStmt.add(newIntTypeNode(nkIntLit, stateNo, getSysType(tyInt)))
-    
-    result = newNodeI(nkStmtList, n.info)
-    result.add(stateAsgnStmt)
-    result.add(retStmt)
-    result.add(stateLabelStmt)
-  of nkReturnStmt:
-    result = newNodeI(nkStmtList, n.info)
-    var stateAsgnStmt = newNodeI(nkAsgn, n.info)
-    stateAsgnStmt.add(indirectAccess(newSymNode(c.closureParam),c.state,n.info))
-    stateAsgnStmt.add(newIntTypeNode(nkIntLit, -1, getSysType(tyInt)))
-    result.add(stateAsgnStmt)
-    result.add(n)
-  else:
-    for i in countup(0, sonsLen(n)-1):
-      let x = transfIterBody(c, n.sons[i])
-      if x != nil: n.sons[i] = x
-
-proc initIterContext(c: var TIterContext, iter: PSym) =
-  c.iter = iter
-  c.capturedVars = initIntSet()
-
-  var cp = getEnvParam(iter)
-  if cp == nil:
-    c.tup = newType(tyTuple, iter)
-    c.tup.n = newNodeI(nkRecList, iter.info)
-
-    cp = newSym(skParam, getIdent(paramName), iter, iter.info)
-    incl(cp.flags, sfFromGeneric)
-    cp.typ = newType(tyRef, iter)
-    rawAddSon(cp.typ, c.tup)
-    addHiddenParam(iter, cp)
-
-    c.state = createStateField(iter)
-    addField(c.tup, c.state)
-  else:
-    c.tup = cp.typ.sons[0]
-    assert c.tup.kind == tyTuple
-    if c.tup.len > 0:
-      c.state = c.tup.n[0].sym
-    else:
-      c.state = createStateField(iter)
-      addField(c.tup, c.state)
-
-  c.closureParam = cp
-  if iter.typ.sons[0] != nil:
-    c.resultSym = newIterResult(iter)
-    #iter.ast.add(newSymNode(c.resultSym))
-
-proc liftIterator*(iter: PSym, body: PNode): PNode =
-  var c: TIterContext
-  initIterContext c, iter
-
-  result = newNodeI(nkStmtList, iter.info)
-  var gs = newNodeI(nkGotoState, iter.info)
-  gs.add(indirectAccess(newSymNode(c.closureParam), c.state, iter.info))
-  result.add(gs)
-  var state0 = newNodeI(nkState, iter.info)
-  state0.add(newIntNode(nkIntLit, 0))
-  result.add(state0)
-  
-  let newBody = transfIterBody(c, body)
-  if newBody != nil:
-    result.add(newBody)
-  else:
-    result.add(body)
-
-  var stateAsgnStmt = newNodeI(nkAsgn, iter.info)
-  stateAsgnStmt.add(indirectAccess(newSymNode(c.closureParam),
-                    c.state,iter.info))
-  stateAsgnStmt.add(newIntTypeNode(nkIntLit, -1, getSysType(tyInt)))
-  result.add(stateAsgnStmt)
-
 proc liftIterSym*(n: PNode): PNode =
   # transforms  (iter)  to  (let env = newClosure[iter](); (iter, env)) 
   let iter = n.sym
   assert iter.kind == skIterator
-  if sfClosureCreated in iter.flags: return n
+  #if sfClosureCreated in iter.flags: return n
   
   result = newNodeIT(nkStmtListExpr, n.info, n.typ)
   
diff --git a/compiler/transf.nim b/compiler/transf.nim
index 973e8848a..cda611005 100644
--- a/compiler/transf.nim
+++ b/compiler/transf.nim
@@ -113,8 +113,8 @@ proc newAsgnStmt(c: PTransf, le: PNode, ri: PTransNode): PTransNode =
   result[1] = ri
 
 proc transformSymAux(c: PTransf, n: PNode): PNode =
-  if n.sym.kind == skIterator and n.sym.typ.callConv == ccClosure:
-    return liftIterSym(n)
+  #if n.sym.kind == skIterator and n.sym.typ.callConv == ccClosure:
+  #  return liftIterSym(n)
   var b: PNode
   var tc = c.transCon
   if sfBorrow in n.sym.flags: 
@@ -636,8 +636,8 @@ proc transform(c: PTransf, n: PNode): PTransNode =
           s.ast.sons[bodyPos] = n.sons[bodyPos]
         #n.sons[bodyPos] = liftLambdas(s, n)
         #if n.kind == nkMethodDef: methodDef(s, false)
-    if n.kind == nkIteratorDef and n.typ != nil:
-      return liftIterSym(n.sons[namePos]).PTransNode
+    #if n.kind == nkIteratorDef and n.typ != nil:
+    #  return liftIterSym(n.sons[namePos]).PTransNode
     result = PTransNode(n)
   of nkMacroDef:
     # XXX no proper closure support yet:
@@ -741,8 +741,8 @@ proc transformBody*(module: PSym, n: PNode, prc: PSym): PNode =
     var c = openTransf(module, "")
     result = processTransf(c, n, prc)
     result = liftLambdas(prc, result)
-    if prc.kind == skIterator and prc.typ.callConv == ccClosure:
-      result = lambdalifting.liftIterator(prc, result)
+    #if prc.kind == skIterator and prc.typ.callConv == ccClosure:
+    #  result = lambdalifting.liftIterator(prc, result)
     incl(result.flags, nfTransf)
     when useEffectSystem: trackProc(prc, result)