summary refs log tree commit diff stats
path: root/compiler
diff options
context:
space:
mode:
authorYuriy Glukhov <yglukhov@users.noreply.github.com>2024-07-03 22:49:30 +0200
committerGitHub <noreply@github.com>2024-07-03 22:49:30 +0200
commit05df263b84de9008266b3d53e2c28b009890ca61 (patch)
tree3616ac90794a37d4a7840136d29c36524a516dcd /compiler
parent051a536275c8851adec4742f90c518fbbb65d13c (diff)
downloadNim-05df263b84de9008266b3d53e2c28b009890ca61.tar.gz
Optimize closure iterator locals (#23787)
This pr redefines the relation between lambda lifting and closureiter
transformation.

Key takeaways:
- Lambdalifting now has less distinction between closureiters and
regular closures. Namely instead of lifting _all_ closureiter variables,
it lifts only those variables it would also lift for simple closure,
i.e. those not owned by the closure.
- It is now closureiter transformation's responsibility to lift all the
locals that need lifting and are not lifted by lambdalifting. So now we
lift only those locals that appear in more than one state. The rest
remains on stack, yay!
- Closureiter transformation always relies on the closure env param
created by lambdalifting. Special care taken to make lambdalifting
create it even in cases when it's "too early" to lift.
- Environments created by lambdalifting will contain `:state` only for
closureiters, whereas previously any closure env contained it.

IMO this is a more reasonable approach as it simplifies not only
lambdalifting, but transf too (e.g. freshVarsForClosureIters is now gone
for good).

I tried to organize the changes logically by commits, so it might be
easier to review this on per commit basis.

Some ugliness:
- Adding lifting to closureiters transformation I had to repeat this
matching of `return result = value` node. I tried to understand why it
is needed, but that was just another rabbit hole, so I left it for
another time. @Araq your input is welcome.
- In the last commit I've reused currently undocumented `liftLocals`
pragma for symbols so that closureiter transformation will forcefully
lift those even if they don't require lifting otherwise. This is needed
for [yasync](https://github.com/yglukhov/yasync) or else it will be very
sad.

Overall I'm quite happy with the results, I'm seeing some noticeable
code size reductions in my projects. Heavy closureiter/async users,
please give it a go.
Diffstat (limited to 'compiler')
-rw-r--r--compiler/ast.nim1
-rw-r--r--compiler/closureiters.nim183
-rw-r--r--compiler/lambdalifting.nim68
-rw-r--r--compiler/pragmas.nim5
-rw-r--r--compiler/transf.nim16
5 files changed, 150 insertions, 123 deletions
diff --git a/compiler/ast.nim b/compiler/ast.nim
index 8b3c6afcd..7191fd572 100644
--- a/compiler/ast.nim
+++ b/compiler/ast.nim
@@ -127,6 +127,7 @@ type
     sfMember          # proc is a C++ member of a type
     sfCodegenDecl     # type, proc, global or proc param is marked as codegenDecl
     sfWasGenSym       # symbol was 'gensym'ed
+    sfForceLift       # variable has to be lifted into closure environment
 
   TSymFlags* = set[TSymFlag]
 
diff --git a/compiler/closureiters.nim b/compiler/closureiters.nim
index c1f525488..9ee394111 100644
--- a/compiler/closureiters.nim
+++ b/compiler/closureiters.nim
@@ -36,13 +36,6 @@
 #  else:
 #    return
 
-# The transformation should play well with lambdalifting, however depending
-# on situation, it can be called either before or after lambdalifting
-# transformation. As such we behave slightly differently, when accessing
-# iterator state, or using temp variables. If lambdalifting did not happen,
-# we just create local variables, so that they will be lifted further on.
-# Otherwise, we utilize existing env, created by lambdalifting.
-
 # Lambdalifting treats :state variable specially, it should always end up
 # as the first field in env. Currently C codegen depends on this behavior.
 
@@ -151,7 +144,6 @@ type
   Ctx = object
     g: ModuleGraph
     fn: PSym
-    stateVarSym: PSym # :state variable. nil if env already introduced by lambdalifting
     tmpResultSym: PSym # Used when we return, but finally has to interfere
     unrollFinallySym: PSym # Indicates that we're unrolling finally states (either exception happened or premature return)
     curExcSym: PSym # Current exception
@@ -168,18 +160,18 @@ type
     nearestFinally: int # Index of the nearest finally block. For try/except it
                     # is their finally. For finally it is parent finally. Otherwise -1
     idgen: IdGenerator
+    varStates: Table[ItemId, int] # Used to detect if local variable belongs to multiple states
 
 const
   nkSkip = {nkEmpty..nkNilLit, nkTemplateDef, nkTypeSection, nkStaticStmt,
             nkCommentStmt, nkMixinStmt, nkBindStmt} + procDefs
   emptyStateLabel = -1
+  localNotSeen = -1
+  localRequiresLifting = -2
 
 proc newStateAccess(ctx: var Ctx): PNode =
-  if ctx.stateVarSym.isNil:
-    result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)),
-        getStateField(ctx.g, ctx.fn), ctx.fn.info)
-  else:
-    result = newSymNode(ctx.stateVarSym)
+  result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)),
+      getStateField(ctx.g, ctx.fn), ctx.fn.info)
 
 proc newStateAssgn(ctx: var Ctx, toValue: PNode): PNode =
   # Creates state assignment:
@@ -195,24 +187,17 @@ proc newEnvVar(ctx: var Ctx, name: string, typ: PType): PSym =
   result = newSym(skVar, getIdent(ctx.g.cache, name), ctx.idgen, ctx.fn, ctx.fn.info)
   result.typ = typ
   result.flags.incl sfNoInit
-  assert(not typ.isNil)
-
-  if not ctx.stateVarSym.isNil:
-    # We haven't gone through labmda lifting yet, so just create a local var,
-    # it will be lifted later
-    if ctx.tempVars.isNil:
-      ctx.tempVars = newNodeI(nkVarSection, ctx.fn.info)
-      addVar(ctx.tempVars, newSymNode(result))
-  else:
-    let envParam = getEnvParam(ctx.fn)
-    # let obj = envParam.typ.lastSon
-    result = addUniqueField(envParam.typ.elementType, result, ctx.g.cache, ctx.idgen)
+  assert(not typ.isNil, "Env var needs a type")
+
+  let envParam = getEnvParam(ctx.fn)
+  # let obj = envParam.typ.lastSon
+  result = addUniqueField(envParam.typ.elementType, result, ctx.g.cache, ctx.idgen)
 
 proc newEnvVarAccess(ctx: Ctx, s: PSym): PNode =
-  if ctx.stateVarSym.isNil:
-    result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), s, ctx.fn.info)
-  else:
-    result = newSymNode(s)
+  result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), s, ctx.fn.info)
+
+proc newTempVarAccess(ctx: Ctx, s: PSym): PNode =
+  result = newSymNode(s, ctx.fn.info)
 
 proc newTmpResultAccess(ctx: var Ctx): PNode =
   if ctx.tmpResultSym.isNil:
@@ -255,9 +240,18 @@ proc addGotoOut(n: PNode, gotoOut: PNode): PNode =
   if result.len == 0 or result[^1].kind != nkGotoState:
     result.add(gotoOut)
 
-proc newTempVar(ctx: var Ctx, typ: PType): PSym =
-  result = ctx.newEnvVar(":tmpSlLower" & $ctx.tempVarId, typ)
+proc newTempVarDef(ctx: Ctx, s: PSym, initialValue: PNode): PNode =
+  var v = initialValue
+  if v == nil:
+    v = ctx.g.emptyNode
+  newTree(nkVarSection, newTree(nkIdentDefs, newSymNode(s), ctx.g.emptyNode, v))
+
+proc newTempVar(ctx: var Ctx, typ: PType, parent: PNode, initialValue: PNode = nil): PSym =
+  result = newSym(skVar, getIdent(ctx.g.cache, ":tmpSlLower" & $ctx.tempVarId), ctx.idgen, ctx.fn, ctx.fn.info)
   inc ctx.tempVarId
+  result.typ = typ
+  assert(not typ.isNil, "Temp var needs a type")
+  parent.add(ctx.newTempVarDef(result, initialValue))
 
 proc hasYields(n: PNode): bool =
   # TODO: This is very inefficient. It traverses the node, looking for nkYieldStmt.
@@ -429,21 +423,20 @@ proc exprToStmtList(n: PNode): tuple[s, res: PNode] =
 
   result.res = n
 
-
-proc newEnvVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode =
+proc newTempVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode =
   if isEmptyType(v.typ):
     result = v
   else:
-    result = newTree(nkFastAsgn, ctx.newEnvVarAccess(s), v)
+    result = newTree(nkFastAsgn, ctx.newTempVarAccess(s), v)
     result.info = v.info
 
 proc addExprAssgn(ctx: Ctx, output, input: PNode, sym: PSym) =
   if input.kind == nkStmtListExpr:
     let (st, res) = exprToStmtList(input)
     output.add(st)
-    output.add(ctx.newEnvVarAsgn(sym, res))
+    output.add(ctx.newTempVarAsgn(sym, res))
   else:
-    output.add(ctx.newEnvVarAsgn(sym, input))
+    output.add(ctx.newTempVarAsgn(sym, input))
 
 proc convertExprBodyToAsgn(ctx: Ctx, exprBody: PNode, res: PSym): PNode =
   result = newNodeI(nkStmtList, exprBody.info)
@@ -457,6 +450,12 @@ proc boolLit(g: ModuleGraph; info: TLineInfo; value: bool): PNode =
   result = newIntLit(g, info, ord value)
   result.typ = getSysType(g, info, tyBool)
 
+proc captureVar(c: var Ctx, s: PSym) =
+  if c.varStates.getOrDefault(s.itemId) != localRequiresLifting:
+    c.varStates[s.itemId] = localRequiresLifting # Mark this variable for lifting
+    let e = getEnvParam(c.fn)
+    discard addField(e.typ.elementType, s, c.g.cache, c.idgen)
+
 proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
   result = n
   case n.kind
@@ -513,9 +512,9 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
       var tmp: PSym = nil
       let isExpr = not isEmptyType(n.typ)
       if isExpr:
-        tmp = ctx.newTempVar(n.typ)
         result = newNodeI(nkStmtListExpr, n.info)
         result.typ = n.typ
+        tmp = ctx.newTempVar(n.typ, result)
       else:
         result = newNodeI(nkStmtList, n.info)
 
@@ -566,7 +565,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
         else:
           internalError(ctx.g.config, "lowerStmtListExpr(nkIf): " & $branch.kind)
 
-      if isExpr: result.add(ctx.newEnvVarAccess(tmp))
+      if isExpr: result.add(ctx.newTempVarAccess(tmp))
 
   of nkTryStmt, nkHiddenTryStmt:
     var ns = false
@@ -580,7 +579,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
       if isExpr:
         result = newNodeI(nkStmtListExpr, n.info)
         result.typ = n.typ
-        let tmp = ctx.newTempVar(n.typ)
+        let tmp = ctx.newTempVar(n.typ, result)
 
         n[0] = ctx.convertExprBodyToAsgn(n[0], tmp)
         for i in 1..<n.len:
@@ -596,7 +595,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
           else:
             internalError(ctx.g.config, "lowerStmtListExpr(nkTryStmt): " & $branch.kind)
         result.add(n)
-        result.add(ctx.newEnvVarAccess(tmp))
+        result.add(ctx.newTempVarAccess(tmp))
 
   of nkCaseStmt:
     var ns = false
@@ -609,9 +608,9 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
       let isExpr = not isEmptyType(n.typ)
 
       if isExpr:
-        let tmp = ctx.newTempVar(n.typ)
         result = newNodeI(nkStmtListExpr, n.info)
         result.typ = n.typ
+        let tmp = ctx.newTempVar(n.typ, result)
 
         if n[0].kind == nkStmtListExpr:
           let (st, ex) = exprToStmtList(n[0])
@@ -628,7 +627,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
           else:
             internalError(ctx.g.config, "lowerStmtListExpr(nkCaseStmt): " & $branch.kind)
         result.add(n)
-        result.add(ctx.newEnvVarAccess(tmp))
+        result.add(ctx.newTempVarAccess(tmp))
       elif n[0].kind == nkStmtListExpr:
         result = newNodeI(nkStmtList, n.info)
         let (st, ex) = exprToStmtList(n[0])
@@ -658,10 +657,10 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
           result.add(st)
           cond = ex
 
-        let tmp = ctx.newTempVar(cond.typ)
-        result.add(ctx.newEnvVarAsgn(tmp, cond))
+        let tmp = ctx.newTempVar(cond.typ, result, cond)
+        # result.add(ctx.newTempVarAsgn(tmp, cond))
 
-        var check = ctx.newEnvVarAccess(tmp)
+        var check = ctx.newTempVarAccess(tmp)
         if n[0].sym.magic == mOr:
           check = ctx.g.newNotCall(check)
 
@@ -671,12 +670,12 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
           let (st, ex) = exprToStmtList(cond)
           ifBody.add(st)
           cond = ex
-        ifBody.add(ctx.newEnvVarAsgn(tmp, cond))
+        ifBody.add(ctx.newTempVarAsgn(tmp, cond))
 
         let ifBranch = newTree(nkElifBranch, check, ifBody)
         let ifNode = newTree(nkIfStmt, ifBranch)
         result.add(ifNode)
-        result.add(ctx.newEnvVarAccess(tmp))
+        result.add(ctx.newTempVarAccess(tmp))
       else:
         for i in 0..<n.len:
           if n[i].kind == nkStmtListExpr:
@@ -685,9 +684,9 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
             n[i] = ex
 
           if n[i].kind in nkCallKinds: # XXX: This should better be some sort of side effect tracking
-            let tmp = ctx.newTempVar(n[i].typ)
-            result.add(ctx.newEnvVarAsgn(tmp, n[i]))
-            n[i] = ctx.newEnvVarAccess(tmp)
+            let tmp = ctx.newTempVar(n[i].typ, result, n[i])
+            # result.add(ctx.newTempVarAsgn(tmp, n[i]))
+            n[i] = ctx.newTempVarAccess(tmp)
 
         result.add(n)
 
@@ -703,6 +702,12 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
         let (st, ex) = exprToStmtList(c[^1])
         result.add(st)
         c[^1] = ex
+      for i in 0 .. c.len - 3:
+        if c[i].kind == nkSym:
+          let s = c[i].sym
+          if sfForceLift in s.flags:
+            ctx.captureVar(s)
+
       result.add(varSect)
 
   of nkDiscardStmt, nkReturnStmt, nkRaiseStmt:
@@ -1279,13 +1284,6 @@ proc wrapIntoStateLoop(ctx: var Ctx, n: PNode): PNode =
   result.info = n.info
 
   let localVars = newNodeI(nkStmtList, n.info)
-  if not ctx.stateVarSym.isNil:
-    let varSect = newNodeI(nkVarSection, n.info)
-    addVar(varSect, newSymNode(ctx.stateVarSym))
-    localVars.add(varSect)
-
-    if not ctx.tempVars.isNil:
-      localVars.add(ctx.tempVars)
 
   let blockStmt = newNodeI(nkBlockStmt, n.info)
   blockStmt.add(newSymNode(ctx.stateLoopLabel))
@@ -1433,15 +1431,67 @@ proc preprocess(c: var PreprocessContext; n: PNode): PNode =
     for i in 0 ..< n.len:
       result[i] = preprocess(c, n[i])
 
+proc detectCapturedVars(c: var Ctx, n: PNode, stateIdx: int) =
+  case n.kind
+  of nkSym:
+    let s = n.sym
+    if s.kind in {skResult, skVar, skLet, skForVar, skTemp} and sfGlobal notin s.flags and s.owner == c.fn:
+      let vs = c.varStates.getOrDefault(s.itemId, localNotSeen)
+      if vs == localNotSeen: # First seing this variable
+        c.varStates[s.itemId] = stateIdx
+      elif vs == localRequiresLifting:
+        discard # Sym already marked
+      elif vs != stateIdx:
+        c.captureVar(s)
+  of nkReturnStmt:
+    if n[0].kind in {nkAsgn, nkFastAsgn, nkSinkAsgn}:
+      # we have a `result = result` expression produced by the closure
+      # transform, let's not touch the LHS in order to make the lifting pass
+      # correct when `result` is lifted
+      detectCapturedVars(c, n[0][1], stateIdx)
+    else:
+      detectCapturedVars(c, n[0], stateIdx)
+  else:
+    for i in 0 ..< n.safeLen:
+      detectCapturedVars(c, n[i], stateIdx)
+
+proc detectCapturedVars(c: var Ctx) =
+  for i, s in c.states:
+    detectCapturedVars(c, s.body, i)
+
+proc liftLocals(c: var Ctx, n: PNode): PNode =
+  result = n
+  case n.kind
+  of nkSym:
+    let s = n.sym
+    if c.varStates.getOrDefault(s.itemId) == localRequiresLifting:
+      # lift
+      let e = getEnvParam(c.fn)
+      let field = getFieldFromObj(e.typ.elementType, s)
+      assert(field != nil)
+      result = rawIndirectAccess(newSymNode(e), field, n.info)
+    # elif c.varStates.getOrDefault(s.itemId, localNotSeen) != localNotSeen:
+    #   echo "Not lifting ", s.name.s
+
+  of nkReturnStmt:
+    if n[0].kind in {nkAsgn, nkFastAsgn, nkSinkAsgn}:
+      # we have a `result = result` expression produced by the closure
+      # transform, let's not touch the LHS in order to make the lifting pass
+      # correct when `result` is lifted
+      n[0][1] = liftLocals(c, n[0][1])
+    else:
+      n[0] = liftLocals(c, n[0])
+  else:
+    for i in 0 ..< n.safeLen:
+      n[i] = liftLocals(c, n[i])
+
 proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n: PNode): PNode =
   var ctx = Ctx(g: g, fn: fn, idgen: idgen)
 
-  if getEnvParam(fn).isNil:
-    # Lambda lifting was not done yet. Use temporary :state sym, which will
-    # be handled specially by lambda lifting. Local temp vars (if needed)
-    # should follow the same logic.
-    ctx.stateVarSym = newSym(skVar, getIdent(ctx.g.cache, ":state"), idgen, fn, fn.info)
-    ctx.stateVarSym.typ = g.createClosureIterStateType(fn, idgen)
+  # The transformation should always happen after at least partial lambdalifting
+  # is performed, so that the closure iter environment is always created upfront.
+  doAssert(getEnvParam(fn) != nil, "Env param not created before iter transformation")
+
   ctx.stateLoopLabel = newSym(skLabel, getIdent(ctx.g.cache, ":stateLoop"), idgen, fn, fn.info)
   var pc = PreprocessContext(finallys: @[], config: g.config, idgen: idgen)
   var n = preprocess(pc, n.toStmtList)
@@ -1466,6 +1516,10 @@ proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n:
   let caseDispatcher = newTreeI(nkCaseStmt, n.info,
       ctx.newStateAccess())
 
+  # Lamdalifting will not touch our locals, it is our responsibility to lift those that
+  # need it.
+  detectCapturedVars(ctx)
+
   for s in ctx.states:
     let body = ctx.transformStateAssignments(s.body)
     caseDispatcher.add newTreeI(nkOfBranch, body.info, g.newIntLit(body.info, s.label), body)
@@ -1473,6 +1527,7 @@ proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n:
   caseDispatcher.add newTreeI(nkElse, n.info, newTreeI(nkReturnStmt, n.info, g.emptyNode))
 
   result = wrapIntoStateLoop(ctx, caseDispatcher)
+  result = liftLocals(ctx, result)
 
   when false:
     echo "TRANSFORM TO STATES: "
@@ -1481,3 +1536,5 @@ proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n:
     echo "exception table:"
     for i, e in ctx.exceptionTable:
       echo i, " -> ", e
+
+    echo "ENV: ", renderTree(getEnvParam(fn).typ.elementType.n)
diff --git a/compiler/lambdalifting.nim b/compiler/lambdalifting.nim
index e51fb3ae0..2a72c3846 100644
--- a/compiler/lambdalifting.nim
+++ b/compiler/lambdalifting.nim
@@ -145,11 +145,13 @@ proc createStateField(g: ModuleGraph; iter: PSym; idgen: IdGenerator): PSym =
   result = newSym(skField, getIdent(g.cache, ":state"), idgen, iter, iter.info)
   result.typ = createClosureIterStateType(g, iter, idgen)
 
+template isIterator*(owner: PSym): bool =
+  owner.kind == skIterator and owner.typ.callConv == ccClosure
+
 proc createEnvObj(g: ModuleGraph; idgen: IdGenerator; owner: PSym; info: TLineInfo): PType =
-  # YYY meh, just add the state field for every closure for now, it's too
-  # hard to figure out if it comes from a closure iterator:
   result = createObj(g, idgen, owner, info, final=false)
-  rawAddField(result, createStateField(g, owner, idgen))
+  if owner.isIterator:
+    rawAddField(result, createStateField(g, owner, idgen))
 
 proc getClosureIterResult*(g: ModuleGraph; iter: PSym; idgen: IdGenerator): PSym =
   if resultPos < iter.ast.len:
@@ -172,26 +174,22 @@ proc addHiddenParam(routine: PSym, param: PSym) =
   assert sfFromGeneric in param.flags
   #echo "produced environment: ", param.id, " for ", routine.id
 
-proc getHiddenParam(g: ModuleGraph; routine: PSym): PSym =
+proc getEnvParam*(routine: PSym): PSym =
   let params = routine.ast[paramsPos]
   let hidden = lastSon(params)
   if hidden.kind == nkSym and hidden.sym.kind == skParam and hidden.sym.name.s == paramName:
     result = hidden.sym
     assert sfFromGeneric in result.flags
   else:
+    result = nil
+
+proc getHiddenParam(g: ModuleGraph; routine: PSym): PSym =
+  result = getEnvParam(routine)
+  if result.isNil:
     # writeStackTrace()
     localError(g.config, routine.info, "internal error: could not find env param for " & routine.name.s)
     result = routine
 
-proc getEnvParam*(routine: PSym): PSym =
-  let params = routine.ast[paramsPos]
-  let hidden = lastSon(params)
-  if hidden.kind == nkSym and hidden.sym.name.s == paramName:
-    result = hidden.sym
-    assert sfFromGeneric in result.flags
-  else:
-    result = nil
-
 proc interestingVar(s: PSym): bool {.inline.} =
   result = s.kind in {skVar, skLet, skTemp, skForVar, skParam, skResult} and
     sfGlobal notin s.flags and
@@ -230,15 +228,6 @@ proc makeClosure*(g: ModuleGraph; idgen: IdGenerator; prc: PSym; env: PNode; inf
   if tfHasAsgn in result.typ.flags or optSeqDestructors in g.config.globalOptions:
     prc.flags.incl sfInjectDestructors
 
-proc interestingIterVar(s: PSym): bool {.inline.} =
-  # XXX optimization: Only lift the variable if it lives across
-  # yield/return boundaries! This can potentially speed up
-  # closure iterators quite a bit.
-  result = s.kind in {skResult, skVar, skLet, skTemp, skForVar} and sfGlobal notin s.flags
-
-template isIterator*(owner: PSym): bool =
-  owner.kind == skIterator and owner.typ.callConv == ccClosure
-
 template liftingHarmful(conf: ConfigRef; owner: PSym): bool =
   ## lambda lifting can be harmful for JS-like code generators.
   let isCompileTime = sfCompileTime in owner.flags or owner.kind == skMacro
@@ -285,15 +274,6 @@ proc liftIterSym*(g: ModuleGraph; n: PNode; idgen: IdGenerator; owner: PSym): PN
   createTypeBoundOpsLL(g, env.typ, n.info, idgen, owner)
   result.add makeClosure(g, idgen, iter, env, n.info)
 
-proc freshVarForClosureIter*(g: ModuleGraph; s: PSym; idgen: IdGenerator; owner: PSym): PNode =
-  let envParam = getHiddenParam(g, owner)
-  let obj = envParam.typ.skipTypes({tyOwned, tyRef, tyPtr})
-  let field = addField(obj, s, g.cache, idgen)
-
-  var access = newSymNode(envParam)
-  assert obj.kind == tyObject
-  result = rawIndirectAccess(access, field, s.info)
-
 # ------------------ new stuff -------------------------------------------
 
 proc markAsClosure(g: ModuleGraph; owner: PSym; n: PNode) =
@@ -341,9 +321,13 @@ proc getEnvTypeForOwner(c: var DetectionPass; owner: PSym;
                         info: TLineInfo): PType =
   result = c.ownerToType.getOrDefault(owner.id)
   if result.isNil:
-    result = newType(tyRef, c.idgen, owner)
-    let obj = createEnvObj(c.graph, c.idgen, owner, info)
-    rawAddSon(result, obj)
+    let env = getEnvParam(owner)
+    if env.isNil or not owner.isIterator:
+      result = newType(tyRef, c.idgen, owner)
+      let obj = createEnvObj(c.graph, c.idgen, owner, info)
+      rawAddSon(result, obj)
+    else:
+      result = env.typ
     c.ownerToType[owner.id] = result
 
 proc asOwnedRef(c: var DetectionPass; t: PType): PType =
@@ -459,17 +443,6 @@ proc detectCapturedVars(n: PNode; owner: PSym; c: var DetectionPass) =
       if owner.isIterator:
         c.somethingToDo = true
         addClosureParam(c, owner, n.info)
-        if interestingIterVar(s):
-          if not c.capturedVars.contains(s.id):
-            if not c.inTypeOf: c.capturedVars.incl(s.id)
-            let obj = getHiddenParam(c.graph, owner).typ.skipTypes({tyOwned, tyRef, tyPtr})
-            #let obj = c.getEnvTypeForOwner(s.owner).skipTypes({tyOwned, tyRef, tyPtr})
-
-            if s.name.id == getIdent(c.graph.cache, ":state").id:
-              obj.n[0].sym.flags.incl sfNoInit
-              obj.n[0].sym.itemId = ItemId(module: s.itemId.module, item: -s.itemId.item)
-            else:
-              discard addField(obj, s, c.graph.cache, c.idgen)
     # direct or indirect dependency:
     elif innerClosure or interested:
       discard """
@@ -773,8 +746,6 @@ proc liftCapturedVars(n: PNode; owner: PSym; d: var DetectionPass;
     elif s.id in d.capturedVars:
       if s.owner != owner:
         result = accessViaEnvParam(d.graph, n, owner)
-      elif owner.isIterator and interestingIterVar(s):
-        result = accessViaEnvParam(d.graph, n, owner)
       else:
         result = accessViaEnvVar(n, owner, d, c)
   of nkEmpty..pred(nkSym), succ(nkSym)..nkNilLit, nkComesFrom,
@@ -892,6 +863,9 @@ proc liftLambdas*(g: ModuleGraph; fn: PSym, body: PNode; tooEarly: var bool;
     # ignore forward declaration:
     result = body
     tooEarly = true
+    if fn.isIterator:
+      var d = initDetectionPass(g, fn, idgen)
+      addClosureParam(d, fn, body.info)
   else:
     var d = initDetectionPass(g, fn, idgen)
     detectCapturedVars(body, fn, d)
diff --git a/compiler/pragmas.nim b/compiler/pragmas.nim
index ed1e067c0..c724aa025 100644
--- a/compiler/pragmas.nim
+++ b/compiler/pragmas.nim
@@ -81,7 +81,7 @@ const
     wRequiresInit, wNoalias, wAlign, wNoInit} - {wExportNims, wNodecl} # why exclude these?
   varPragmas* = declPragmas + {wVolatile, wRegister, wThreadVar,
     wMagic, wHeader, wCompilerProc, wCore, wDynlib,
-    wNoInit, wCompileTime, wGlobal,
+    wNoInit, wCompileTime, wGlobal, wLiftLocals,
     wGensym, wInject, wCodegenDecl,
     wGuard, wGoto, wCursor, wNoalias, wAlign}
   constPragmas* = declPragmas + {wHeader, wMagic,
@@ -1308,7 +1308,8 @@ proc singlePragma(c: PContext, sym: PSym, n: PNode, i: var int,
         noVal(c, it)
         if sym == nil: invalidPragma(c, it)
         else: sym.flags.incl sfUsed
-      of wLiftLocals: discard
+      of wLiftLocals:
+        sym.flags.incl(sfForceLift)
       of wRequires, wInvariant, wAssume, wAssert:
         pragmaProposition(c, it)
       of wEnsures:
diff --git a/compiler/transf.nim b/compiler/transf.nim
index 8dcb74729..ce0157f22 100644
--- a/compiler/transf.nim
+++ b/compiler/transf.nim
@@ -97,10 +97,7 @@ proc newTemp(c: PTransf, typ: PType, info: TLineInfo): PNode =
   r.typ = typ #skipTypes(typ, {tyGenericInst, tyAlias, tySink})
   incl(r.flags, sfFromGeneric)
   let owner = getCurrOwner(c)
-  if owner.isIterator and not c.tooEarly:
-    result = freshVarForClosureIter(c.graph, r, c.idgen, owner)
-  else:
-    result = newSymNode(r)
+  result = newSymNode(r)
 
 proc transform(c: PTransf, n: PNode): PNode
 
@@ -176,13 +173,10 @@ proc transformSym(c: PTransf, n: PNode): PNode =
 
 proc freshVar(c: PTransf; v: PSym): PNode =
   let owner = getCurrOwner(c)
-  if owner.isIterator and not c.tooEarly:
-    result = freshVarForClosureIter(c.graph, v, c.idgen, owner)
-  else:
-    var newVar = copySym(v, c.idgen)
-    incl(newVar.flags, sfFromGeneric)
-    newVar.owner = owner
-    result = newSymNode(newVar)
+  var newVar = copySym(v, c.idgen)
+  incl(newVar.flags, sfFromGeneric)
+  newVar.owner = owner
+  result = newSymNode(newVar)
 
 proc transformVarSection(c: PTransf, v: PNode): PNode =
   result = newTransNode(v)