summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorAraq <rumpf_a@web.de>2012-06-18 02:03:08 +0200
committerAraq <rumpf_a@web.de>2012-06-18 02:03:08 +0200
commitd5b01dfb7ce96116eb6184a90be4e902f0a2a649 (patch)
tree0e216a17d6865ca1203b88a39972dc5ecf3762fc
parent7076f07228f65b05312b609f89dbac767b69394f (diff)
downloadNim-d5b01dfb7ce96116eb6184a90be4e902f0a2a649.tar.gz
next steps for full closure support
-rw-r--r--compiler/lambdalifting.nim264
-rw-r--r--tests/run/tclosure2.nim4
-rw-r--r--tests/run/tinterf.nim16
-rwxr-xr-xtodo.txt3
4 files changed, 153 insertions, 134 deletions
diff --git a/compiler/lambdalifting.nim b/compiler/lambdalifting.nim
index 8dd54131f..ea5aea449 100644
--- a/compiler/lambdalifting.nim
+++ b/compiler/lambdalifting.nim
@@ -49,9 +49,10 @@ discard """
         var bcl2: *
         new bcl2
         bcl2.up = bcl
+        bcl2.up2 = cl
         bcl2.x = x
       
-        proc c(cl) = capture cl.up.up.v, cl.up.w, cl.x
+        proc c(cl) = capture cl.up2.v, cl.up.w, cl.x
         c(bcl2)
       
       c(bcl)
@@ -107,68 +108,92 @@ const
   declarativeDefs* = {nkProcDef, nkMethodDef, nkIteratorDef, nkConverterDef}
   procDefs* = nkLambdaKinds + declarativeDefs
   upName* = ":up" # field name for the 'up' reference
+  paramName* = ":env"
   envName* = ":env"
 
 type
-  TLLShared {.final.} = object
-    upField: PSym
-  
   PInnerContext = ref TInnerContext
   POuterContext = ref TOuterContext
-  PLLShared = ref TLLShared
-  PBlock = ref TBlock
-  
-  TBlock {.final.} = object
-    body: PNode
-    closure: PSym
-    used: bool
+
+  PEnv = ref TEnv
+  TDep = tuple[e: PEnv, field: PSym]
+  TEnv {.final.} = object of TObject
+    attachedNode: PNode
+    closure: PSym   # if != nil it is a used environment
+    capturedVars: seq[PSym] # captured variables in this environment
+    deps: seq[TDep] # dependencies
+    up: PEnv
+    tup: PType
   
   TInnerContext {.final.} = object
     fn: PSym
     closureParam: PSym
     localsToAccess: TIdNodeTable
-    up: POuterContext         # used for chaining
-    levelsUp: int             # counts how many "up levels" are accessed
-    tup: PType
     
   TOuterContext {.final.} = object
     fn: PSym
-    currentBlock: PNode
-    capturedVars: TIntSet
-    localsToEnclosingScope: TIdNodeTable
+    currentEnv: PEnv
+    capturedVars, processed: TIntSet
+    localsToEnv: TIdTable # PSym->PEnv mapping
     localsToAccess: TIdNodeTable
-    lambdasToEnclosingScope: TIdNodeTable
-  
-    shared: PLLShared
+    lambdasToEnv: TIdTable # PSym->PEnv mapping
     up: POuterContext
 
-proc newOuterContext(fn: PSym, shared: PLLShared, 
-                     up: POuterContext = nil): POuterContext =
+proc newOuterContext(fn: PSym, up: POuterContext = nil): POuterContext =
   new(result)
   result.fn = fn
-  result.shared = shared
   result.capturedVars = initIntSet()
+  result.processed = initIntSet()
   initIdNodeTable(result.localsToAccess)
-  initIdNodeTable(result.localsToEnclosingScope)
-  initIdNodeTable(result.lambdasToEnclosingScope)
+  initIdTable(result.localsToEnv)
+  initIdTable(result.lambdasToEnv)
   
-proc newInnerContext(fn: PSym, outer: POuterContext): PInnerContext =
+proc newInnerContext(fn: PSym): PInnerContext =
   new(result)
-  result.up = outer
   result.fn = fn
   initIdNodeTable(result.localsToAccess)
   
+proc newEnv(outerProc: PSym, up: PEnv, n: PNode): PEnv =
+  new(result)
+  result.deps = @[]
+  result.capturedVars = @[]
+  result.tup = newType(tyTuple, outerProc)
+  result.tup.n = newNodeI(nkRecList, outerProc.info)
+  result.up = up
+  result.attachedNode = n
+
+proc addField(tup: PType, s: PSym) =
+  var field = newSym(skField, s.name, s.owner)
+  field.typ = s.typ
+  field.position = sonsLen(tup)
+  addSon(tup.n, newSymNode(field))
+  addSon(tup, s.typ)
+  
+proc addCapturedVar(e: PEnv, v: PSym) =
+  for x in e.capturedVars:
+    if x == v: return
+  e.capturedVars.add(v)
+  addField(e.tup, v)
+  
+proc addDep(e, d: PEnv, owner: PSym): PSym =
+  for x, field in items(e.deps):
+    if x == d: return field
+  var pos = sonsLen(e.tup)
+  result = newSym(skField, getIdent(upName & $pos), owner)
+  result.typ = newType(tyRef, owner)
+  result.position = pos
+  assert d.tup != nil
+  addSon(result.typ, d.tup)
+  addField(e.tup, result)
+  e.deps.add((d, result))
+  
 proc indirectAccess(a: PNode, b: PSym, info: TLineInfo): PNode = 
   # returns a[].b as a node
-  let x = a
   var deref = newNodeI(nkHiddenDeref, info)
-  deref.typ = x.typ.sons[0]
-  
+  deref.typ = a.typ.sons[0]
   let field = getSymFromList(deref.typ.n, b.name)
-  if field == nil:
-    echo b.name.s
-    assert false
-  addSon(deref, x)
+  assert field != nil, b.name.s
+  addSon(deref, a)
   result = newNodeI(nkDotExpr, info)
   addSon(result, deref)
   addSon(result, newSymNode(field))
@@ -182,74 +207,49 @@ proc newCall(a, b: PSym): PNode =
   result.add newSymNode(a)
   result.add newSymNode(b)
 
-proc addField(tup: PType, s: PSym) =
-  var field = newSym(skField, s.name, s.owner)
-  field.typ = s.typ
-  field.position = sonsLen(tup)
-  addSon(tup.n, newSymNode(field))
-  addSon(tup, s.typ)
-
 proc addHiddenParam(routine: PSym, param: PSym) =
   var params = routine.ast.sons[paramsPos]
   param.position = params.len
   addSon(params, newSymNode(param))
-  echo "produced environment: ", param.id, " for ", routine.name.s
+  #echo "produced environment: ", param.id, " for ", routine.name.s
 
 proc isInnerProc(s, outerProc: PSym): bool {.inline.} =
-  if s.name.s[0] == ':':
-    debug s
-    debug s.owner
-    debug outerProc
   result = s.kind in {skProc, skIterator, skMethod, skConverter} and
     s.owner == outerProc and not isGenericRoutine(s)
   #s.typ.callConv == ccClosure
 
-proc captureVar(o: POuterContext, i: PInnerContext, local: PSym,
+proc captureVar(o: POuterContext, i: PInnerContext, local: PSym, 
                 info: TLineInfo) =
-  discard """
-    Consider:
-      var x = 0
-      var y = 2
-      capture x, y
-      
-      block:
-        var z = 3
-        capture z
-      
-    We need to merge x, y into a closure, but not z! 
-  """
-  # we need to remember which outer closure belongs to this lambda; we also
-  # use this check to prevent multiple runs over the same inner proc:
-  echo "enter"
-  if IdNodeTableGet(o.lambdasToEnclosingScope, i.fn) != nil: return
-  IdNodeTablePut(o.lambdasToEnclosingScope, i.fn, o.currentBlock)
+  # we need to remember which inner most closure belongs to this lambda:
+  var e = o.currentEnv
+  if IdTableGet(o.lambdasToEnv, i.fn) == nil:
+    IdTablePut(o.lambdasToEnv, i.fn, e)
 
+  # variable already captured:
   if IdNodeTableGet(i.localsToAccess, local) != nil: return
   if i.closureParam == nil:
-    var cp = newSym(skParam, getIdent(upname), i.fn)
+    var cp = newSym(skParam, getIdent(paramname), i.fn)
     cp.info = i.fn.info
     incl(cp.flags, sfFromGeneric)
-    i.tup = newType(tyTuple, i.fn)
-    i.tup.n = newNodeI(nkRecList, i.fn.info)
     cp.typ = newType(tyRef, i.fn)
-    addSon(cp.typ, i.tup)
+    addSon(cp.typ, e.tup)
     i.closureParam = cp
     addHiddenParam(i.fn, i.closureParam)
-  addField(i.tup, local)
-  var it = i.up
+  
+  # check which environment `local` belongs to:
   var access = newSymNode(i.closureParam)
-  var levelsUp = 0
-  while it.fn.id != local.owner.id:
-    assert false
-    access = indirectAccess(access, o.shared.upField, info)
-    it = it.up
-    assert it != nil
-    inc levelsUp
-  i.levelsUp = max(i.levelsUp, levelsUp)
+  var it = PEnv(IdTableGet(o.localsToEnv, local))
+  assert it != nil
+  addCapturedVar(it, local)
+  if it == e:
+    # common case: local directly in current environment:
+    nil
+  else:
+    # it's in some upper environment:
+    access = indirectAccess(access, addDep(e, it, i.fn), info)
   access = indirectAccess(access, local, info)
-  IdNodeTablePut(i.localsToAccess, local, access)
   incl(o.capturedVars, local.id)
-  echo "exit"
+  IdNodeTablePut(i.localsToAccess, local, access)
 
 proc interestingVar(s: PSym): bool {.inline.} =
   result = s.kind in {skVar, skLet, skTemp, skForVar, skParam, skResult} and
@@ -262,6 +262,14 @@ proc gatherVars(o: POuterContext, i: PInnerContext, n: PNode) =
     var s = n.sym
     if interestingVar(s) and i.fn.id != s.owner.id:
       captureVar(o, i, s, n.info)
+    elif isInnerProc(s, o.fn) and s.typ.callConv == ccClosure and s != i.fn:
+      # call to some other inner proc; we need to track the dependencies for
+      # this:
+      let env = PEnv(IdTableGet(o.lambdasToEnv, i.fn))
+      if env == nil: InternalError(n.info, "no environment computed")
+      if o.currentEnv != env:
+        discard addDep(o.currentEnv, env, i.fn)
+        InternalError(n.info, "too complex enviroment handling required")
   of nkEmpty..pred(nkSym), succ(nkSym)..nkNilLit: nil
   else:
     for k in countup(0, sonsLen(n) - 1): 
@@ -279,10 +287,16 @@ proc transformInnerProc(o: POuterContext, i: PInnerContext, n: PNode): PNode =
   case n.kind
   of nkEmpty..pred(nkSym), succ(nkSym)..nkNilLit: nil
   of nkSym:
-    if n.sym == i.fn: 
+    let s = n.sym
+    if s == i.fn: 
       # recursive calls go through (lambda, hiddenParam):
       assert i.closureParam != nil
-      result = makeClosure(n.sym, i.closureParam, n.info)
+      result = makeClosure(s, i.closureParam, n.info)
+    elif isInnerProc(s, o.fn) and s.typ.callConv == ccClosure:
+      # ugh: call to some other inner proc; 
+      assert i.closureParam != nil
+      # XXX this is not correct in general! may also be some 'closure.upval'
+      result = makeClosure(s, i.closureParam, n.info)
     else:
       # captured symbol?
       result = IdNodeTableGet(i.localsToAccess, n.sym)
@@ -307,8 +321,8 @@ proc searchForInnerProcs(o: POuterContext, n: PNode) =
   of nkEmpty..pred(nkSym), succ(nkSym)..nkNilLit: 
     nil
   of nkSym:
-    if isInnerProc(n.sym, o.fn):
-      var inner = newInnerContext(n.sym, o)
+    if isInnerProc(n.sym, o.fn) and not containsOrIncl(o.processed, n.sym.id):
+      var inner = newInnerContext(n.sym)
       let body = n.sym.getBody
       gatherVars(o, inner, body)
       let ti = transformInnerProc(o, inner, body)
@@ -318,18 +332,18 @@ proc searchForInnerProcs(o: POuterContext, n: PNode) =
   of nkWhileStmt, nkForStmt, nkParForStmt, nkBlockStmt:
     # some nodes open a new scope, so they are candidates for the insertion
     # of closure creation; however for simplicity we merge closures between
-    # branches, in fact, only loops bodies are of interest here as only they 
+    # branches, in fact, only loop bodies are of interest here as only they 
     # yield observable changes in semantics. For Zahary we also
     # include ``nkBlock``.
     var body = n.len-1
     for i in countup(0, body - 1): searchForInnerProcs(o, n.sons[i])
     # special handling for the loop body:
-    let oldBlock = o.currentBlock
+    let oldEnv = o.currentEnv
     let ex = closureCreationPoint(n.sons[body])
-    o.currentBlock = ex
+    o.currentEnv = newEnv(o.fn, oldEnv, ex)
     searchForInnerProcs(o, n.sons[body])
     n.sons[body] = ex
-    o.currentBlock = oldBlock
+    o.currentEnv = oldEnv
   of nkVarSection, nkLetSection:
     # we need to compute a mapping var->declaredBlock. Note: The definition
     # counts, not the block where it is captured!
@@ -339,13 +353,12 @@ proc searchForInnerProcs(o: POuterContext, n: PNode) =
       elif it.kind == nkIdentDefs:
         if it.sons[0].kind != nkSym: InternalError(it.info, "transformOuter")
         #echo "set: ", it.sons[0].sym.name.s, " ", o.currentBlock == nil
-        IdNodeTablePut(o.localsToEnclosingScope, it.sons[0].sym, o.currentBlock)
+        IdTablePut(o.localsToEnv, it.sons[0].sym, o.currentEnv)
       elif it.kind == nkVarTuple:
         var L = sonsLen(it)
         for j in countup(0, L-3):
           #echo "set: ", it.sons[j].sym.name.s, " ", o.currentBlock == nil
-          IdNodeTablePut(o.localsToEnclosingScope, it.sons[j].sym, 
-                         o.currentBlock)
+          IdTablePut(o.localsToEnv, it.sons[j].sym, o.currentEnv)
       else:
         InternalError(it.info, "transformOuter")
   of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef, 
@@ -369,15 +382,19 @@ proc addVar*(father, v: PNode) =
   addSon(vpart, ast.emptyNode)
   addSon(father, vpart)
 
-proc generateClosureCreation(o: POuterContext, scope: PNode): PNode =
-  # add assignment if it's a parameter that has been captured:
-  var env = newSym(skVar, getIdent(envName), o.fn)
-  incl(env.flags, sfShadowed)
-  env.info = scope.info
-  env.typ = newType(tyRef, o.fn)
-  var tup = newType(tyTuple, o.fn)
-  tup.n = newNodeI(nkRecList, scope.info)
-  env.typ.addSon(tup)
+proc getClosureVar(o: POuterContext, e: PEnv): PSym =
+  if e.closure == nil:
+    result = newSym(skVar, getIdent(envName), o.fn)
+    incl(result.flags, sfShadowed)
+    result.info = e.attachedNode.info
+    result.typ = newType(tyRef, o.fn)
+    result.typ.addSon(e.tup)
+    e.closure = result
+  else:
+    result = e.closure
+
+proc generateClosureCreation(o: POuterContext, scope: PEnv): PNode =
+  var env = getClosureVar(o, scope)
 
   result = newNodeI(nkStmtList, env.info)
   var v = newNodeI(nkVarSection, env.info)
@@ -387,42 +404,40 @@ proc generateClosureCreation(o: POuterContext, scope: PNode): PNode =
   result.add(newCall(getSysSym"internalNew", env))
   
   # add assignment statements:
-  for v, scope2 in pairs(o.localsToEnclosingScope):
+  for v, scope2 in idTablePairs(o.localsToEnv):
     if scope2 == scope:
       let local = PSym(v)
-      addField(tup, local)
       let fieldAccess = indirectAccess(env, local, env.info)
       if sfByCopy in local.flags or local.kind == skParam:
         # add ``env.param = param``
         result.add(newAsgnStmt(fieldAccess, newSymNode(local)))
       IdNodeTablePut(o.localsToAccess, local, fieldAccess)
-  # XXX add support for 'up' references!
+  # add support for 'up' references:
+  for e, field in items(scope.deps):
+    # add ``env.up = env2``
+    result.add(newAsgnStmt(indirectAccess(env, field, env.info),
+               newSymNode(getClosureVar(o, e))))
 
 proc transformOuterProc(o: POuterContext, n: PNode): PNode =
   case n.kind
   of nkEmpty..pred(nkSym), succ(nkSym)..nkNilLit: nil
   of nkSym:
     var local = n.sym
-    var envBlock = IdNodeTableGet(o.lambdasToEnclosingScope, local)
-    if envBlock != nil:
+    var closure = PEnv(IdTableGet(o.lambdasToEnv, local))
+    if closure != nil:
       # we need to replace the lambda with '(lambda, env)': 
-      let a = envBlock.sons[0]
-      assert a.kind == nkStmtList
-      assert a.sons[0].kind == nkVarSection
-      assert a.sons[0].sons[0].kind == nkIdentDefs
-      var env = a.sons[0].sons[0].sons[0].sym
-      return makeClosure(local, env, n.info)
+      let a = closure.closure
+      assert a != nil
+      return makeClosure(local, a, n.info)
   
-    if not o.capturedVars.contains(local.id): return
-    var scope = IdNodeTableGet(o.localsToEnclosingScope, local)
-    if scope == nil: return
-    
+    if not contains(o.capturedVars, local.id): return
+    var env = PEnv(IdTableGet(o.localsToEnv, local))
+    if env == nil: return
+    var scope = env.attachedNode
     assert scope.kind == nkStmtList
     if scope.sons[0].kind == nkEmpty:
-      # change the empty node to contain the closure construction; we need to
-      # gather all variables here that belong to the closure which is a bit
-      # expensive:
-      scope.sons[0] = generateClosureCreation(o, scope)
+      # change the empty node to contain the closure construction:
+      scope.sons[0] = generateClosureCreation(o, env)
     
     # change 'local' to 'closure.local', unless it's a 'byCopy' variable:
     if sfByCopy notin local.flags:
@@ -441,7 +456,7 @@ proc transformOuterProc(o: POuterContext, n: PNode): PNode =
       let x = transformOuterProc(o, n.sons[i])
       if x != nil: n.sons[i] = x
 
-proc liftLambdas(fn: PSym, shared: PLLShared, body: PNode): PNode =
+proc liftLambdas(fn: PSym, body: PNode): PNode =
   if body.kind == nkEmpty:
     # ignore forward declaration:
     result = body
@@ -449,9 +464,9 @@ proc liftLambdas(fn: PSym, shared: PLLShared, body: PNode): PNode =
     # fast path: no inner procs, so no closure needed:
     result = body
   else:
-    var o = newOuterContext(fn, shared)
+    var o = newOuterContext(fn)
     let ex = closureCreationPoint(body)
-    o.currentBlock = ex
+    o.currentEnv = newEnv(fn, nil, ex)
     searchForInnerProcs(o, body)
     let a = transformOuterProc(o, body)
     result = ex
@@ -462,7 +477,4 @@ proc liftLambdas(fn: PSym, shared: PLLShared, body: PNode): PNode =
 proc liftLambdas*(n: PNode): PNode =
   assert n.kind in procDefs
   var s = n.sons[namePos].sym
-  var shared: ref TLLShared
-  new shared
-  shared.upField = newSym(skField, upName.getIdent, s)
-  result = liftLambdas(s, shared, s.getBody)
+  result = liftLambdas(s, s.getBody)
diff --git a/tests/run/tclosure2.nim b/tests/run/tclosure2.nim
index 47cf8fa11..5a1cb8075 100644
--- a/tests/run/tclosure2.nim
+++ b/tests/run/tclosure2.nim
@@ -22,7 +22,9 @@ when true:
     for xxxx in 0..9:
       var i = 0
       proc bx =
-        if i > 10: return
+        if i > 10: 
+          echo xxxx
+          return
         i += 1
         #for j in 0 .. 0: echo i
         bx()
diff --git a/tests/run/tinterf.nim b/tests/run/tinterf.nim
index b082b1d3f..648873da0 100644
--- a/tests/run/tinterf.nim
+++ b/tests/run/tinterf.nim
@@ -1,20 +1,24 @@
 discard """
-  output: '''56'''
+  output: '''56 66'''
 """
 
 type
   ITest = tuple[
     setter: proc(v: int) {.closure.},
-    getter: proc(): int {.closure.}]
+    getter1: proc(): int {.closure.},
+    getter2: proc(): int {.closure.}]
 
 proc getInterf(): ITest =
-  var shared: int
+  var shared, shared2: int
   
-  return (setter: proc (x: int) = shared = x,
-          getter: proc (): int = return shared)
+  return (setter: proc (x: int) = 
+            shared = x
+            shared2 = x + 10,
+          getter1: proc (): int = result = shared,
+          getter2: proc (): int = return shared2)
 
 var i = getInterf()
 i.setter(56)
 
-echo i.getter()
+echo i.getter1(), " ", i.getter2()
 
diff --git a/todo.txt b/todo.txt
index 62e4d85ae..aa914bf84 100755
--- a/todo.txt
+++ b/todo.txt
@@ -12,7 +12,8 @@ version 0.9.0
   - deactivate lambda lifting for JS backend
   - Test capture of for loop vars; test generics;
   - test constant closures
-  - implement closures that support nesting > 1
+  - implement closures that support nesting of blocks > 1
+  - implement closures that support nesting of *procs* > 1
 - implement proper coroutines
 
 - document 'do' notation