summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--compiler/ast.nim1
-rw-r--r--compiler/closureiters.nim183
-rw-r--r--compiler/lambdalifting.nim68
-rw-r--r--compiler/pragmas.nim5
-rw-r--r--compiler/transf.nim16
-rw-r--r--tests/destructor/tuse_ownedref_after_move.nim2
-rw-r--r--tests/iter/tyieldintry.nim23
7 files changed, 174 insertions, 124 deletions
diff --git a/compiler/ast.nim b/compiler/ast.nim
index 8b3c6afcd..7191fd572 100644
--- a/compiler/ast.nim
+++ b/compiler/ast.nim
@@ -127,6 +127,7 @@ type
     sfMember          # proc is a C++ member of a type
     sfCodegenDecl     # type, proc, global or proc param is marked as codegenDecl
     sfWasGenSym       # symbol was 'gensym'ed
+    sfForceLift       # variable has to be lifted into closure environment
 
   TSymFlags* = set[TSymFlag]
 
diff --git a/compiler/closureiters.nim b/compiler/closureiters.nim
index c1f525488..9ee394111 100644
--- a/compiler/closureiters.nim
+++ b/compiler/closureiters.nim
@@ -36,13 +36,6 @@
 #  else:
 #    return
 
-# The transformation should play well with lambdalifting, however depending
-# on situation, it can be called either before or after lambdalifting
-# transformation. As such we behave slightly differently, when accessing
-# iterator state, or using temp variables. If lambdalifting did not happen,
-# we just create local variables, so that they will be lifted further on.
-# Otherwise, we utilize existing env, created by lambdalifting.
-
 # Lambdalifting treats :state variable specially, it should always end up
 # as the first field in env. Currently C codegen depends on this behavior.
 
@@ -151,7 +144,6 @@ type
   Ctx = object
     g: ModuleGraph
     fn: PSym
-    stateVarSym: PSym # :state variable. nil if env already introduced by lambdalifting
     tmpResultSym: PSym # Used when we return, but finally has to interfere
     unrollFinallySym: PSym # Indicates that we're unrolling finally states (either exception happened or premature return)
     curExcSym: PSym # Current exception
@@ -168,18 +160,18 @@ type
     nearestFinally: int # Index of the nearest finally block. For try/except it
                     # is their finally. For finally it is parent finally. Otherwise -1
     idgen: IdGenerator
+    varStates: Table[ItemId, int] # Used to detect if local variable belongs to multiple states
 
 const
   nkSkip = {nkEmpty..nkNilLit, nkTemplateDef, nkTypeSection, nkStaticStmt,
             nkCommentStmt, nkMixinStmt, nkBindStmt} + procDefs
   emptyStateLabel = -1
+  localNotSeen = -1
+  localRequiresLifting = -2
 
 proc newStateAccess(ctx: var Ctx): PNode =
-  if ctx.stateVarSym.isNil:
-    result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)),
-        getStateField(ctx.g, ctx.fn), ctx.fn.info)
-  else:
-    result = newSymNode(ctx.stateVarSym)
+  result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)),
+      getStateField(ctx.g, ctx.fn), ctx.fn.info)
 
 proc newStateAssgn(ctx: var Ctx, toValue: PNode): PNode =
   # Creates state assignment:
@@ -195,24 +187,17 @@ proc newEnvVar(ctx: var Ctx, name: string, typ: PType): PSym =
   result = newSym(skVar, getIdent(ctx.g.cache, name), ctx.idgen, ctx.fn, ctx.fn.info)
   result.typ = typ
   result.flags.incl sfNoInit
-  assert(not typ.isNil)
-
-  if not ctx.stateVarSym.isNil:
-    # We haven't gone through labmda lifting yet, so just create a local var,
-    # it will be lifted later
-    if ctx.tempVars.isNil:
-      ctx.tempVars = newNodeI(nkVarSection, ctx.fn.info)
-      addVar(ctx.tempVars, newSymNode(result))
-  else:
-    let envParam = getEnvParam(ctx.fn)
-    # let obj = envParam.typ.lastSon
-    result = addUniqueField(envParam.typ.elementType, result, ctx.g.cache, ctx.idgen)
+  assert(not typ.isNil, "Env var needs a type")
+
+  let envParam = getEnvParam(ctx.fn)
+  # let obj = envParam.typ.lastSon
+  result = addUniqueField(envParam.typ.elementType, result, ctx.g.cache, ctx.idgen)
 
 proc newEnvVarAccess(ctx: Ctx, s: PSym): PNode =
-  if ctx.stateVarSym.isNil:
-    result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), s, ctx.fn.info)
-  else:
-    result = newSymNode(s)
+  result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), s, ctx.fn.info)
+
+proc newTempVarAccess(ctx: Ctx, s: PSym): PNode =
+  result = newSymNode(s, ctx.fn.info)
 
 proc newTmpResultAccess(ctx: var Ctx): PNode =
   if ctx.tmpResultSym.isNil:
@@ -255,9 +240,18 @@ proc addGotoOut(n: PNode, gotoOut: PNode): PNode =
   if result.len == 0 or result[^1].kind != nkGotoState:
     result.add(gotoOut)
 
-proc newTempVar(ctx: var Ctx, typ: PType): PSym =
-  result = ctx.newEnvVar(":tmpSlLower" & $ctx.tempVarId, typ)
+proc newTempVarDef(ctx: Ctx, s: PSym, initialValue: PNode): PNode =
+  var v = initialValue
+  if v == nil:
+    v = ctx.g.emptyNode
+  newTree(nkVarSection, newTree(nkIdentDefs, newSymNode(s), ctx.g.emptyNode, v))
+
+proc newTempVar(ctx: var Ctx, typ: PType, parent: PNode, initialValue: PNode = nil): PSym =
+  result = newSym(skVar, getIdent(ctx.g.cache, ":tmpSlLower" & $ctx.tempVarId), ctx.idgen, ctx.fn, ctx.fn.info)
   inc ctx.tempVarId
+  result.typ = typ
+  assert(not typ.isNil, "Temp var needs a type")
+  parent.add(ctx.newTempVarDef(result, initialValue))
 
 proc hasYields(n: PNode): bool =
   # TODO: This is very inefficient. It traverses the node, looking for nkYieldStmt.
@@ -429,21 +423,20 @@ proc exprToStmtList(n: PNode): tuple[s, res: PNode] =
 
   result.res = n
 
-
-proc newEnvVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode =
+proc newTempVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode =
   if isEmptyType(v.typ):
     result = v
   else:
-    result = newTree(nkFastAsgn, ctx.newEnvVarAccess(s), v)
+    result = newTree(nkFastAsgn, ctx.newTempVarAccess(s), v)
     result.info = v.info
 
 proc addExprAssgn(ctx: Ctx, output, input: PNode, sym: PSym) =
   if input.kind == nkStmtListExpr:
     let (st, res) = exprToStmtList(input)
     output.add(st)
-    output.add(ctx.newEnvVarAsgn(sym, res))
+    output.add(ctx.newTempVarAsgn(sym, res))
   else:
-    output.add(ctx.newEnvVarAsgn(sym, input))
+    output.add(ctx.newTempVarAsgn(sym, input))
 
 proc convertExprBodyToAsgn(ctx: Ctx, exprBody: PNode, res: PSym): PNode =
   result = newNodeI(nkStmtList, exprBody.info)
@@ -457,6 +450,12 @@ proc boolLit(g: ModuleGraph; info: TLineInfo; value: bool): PNode =
   result = newIntLit(g, info, ord value)
   result.typ = getSysType(g, info, tyBool)
 
+proc captureVar(c: var Ctx, s: PSym) =
+  if c.varStates.getOrDefault(s.itemId) != localRequiresLifting:
+    c.varStates[s.itemId] = localRequiresLifting # Mark this variable for lifting
+    let e = getEnvParam(c.fn)
+    discard addField(e.typ.elementType, s, c.g.cache, c.idgen)
+
 proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
   result = n
   case n.kind
@@ -513,9 +512,9 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
       var tmp: PSym = nil
       let isExpr = not isEmptyType(n.typ)
       if isExpr:
-        tmp = ctx.newTempVar(n.typ)
         result = newNodeI(nkStmtListExpr, n.info)
         result.typ = n.typ
+        tmp = ctx.newTempVar(n.typ, result)
       else:
         result = newNodeI(nkStmtList, n.info)
 
@@ -566,7 +565,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
         else:
           internalError(ctx.g.config, "lowerStmtListExpr(nkIf): " & $branch.kind)
 
-      if isExpr: result.add(ctx.newEnvVarAccess(tmp))
+      if isExpr: result.add(ctx.newTempVarAccess(tmp))
 
   of nkTryStmt, nkHiddenTryStmt:
     var ns = false
@@ -580,7 +579,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
       if isExpr:
         result = newNodeI(nkStmtListExpr, n.info)
         result.typ = n.typ
-        let tmp = ctx.newTempVar(n.typ)
+        let tmp = ctx.newTempVar(n.typ, result)
 
         n[0] = ctx.convertExprBodyToAsgn(n[0], tmp)
         for i in 1..<n.len:
@@ -596,7 +595,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
           else:
             internalError(ctx.g.config, "lowerStmtListExpr(nkTryStmt): " & $branch.kind)
         result.add(n)
-        result.add(ctx.newEnvVarAccess(tmp))
+        result.add(ctx.newTempVarAccess(tmp))
 
   of nkCaseStmt:
     var ns = false
@@ -609,9 +608,9 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
       let isExpr = not isEmptyType(n.typ)
 
       if isExpr:
-        let tmp = ctx.newTempVar(n.typ)
         result = newNodeI(nkStmtListExpr, n.info)
         result.typ = n.typ
+        let tmp = ctx.newTempVar(n.typ, result)
 
         if n[0].kind == nkStmtListExpr:
           let (st, ex) = exprToStmtList(n[0])
@@ -628,7 +627,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
           else:
             internalError(ctx.g.config, "lowerStmtListExpr(nkCaseStmt): " & $branch.kind)
         result.add(n)
-        result.add(ctx.newEnvVarAccess(tmp))
+        result.add(ctx.newTempVarAccess(tmp))
       elif n[0].kind == nkStmtListExpr:
         result = newNodeI(nkStmtList, n.info)
         let (st, ex) = exprToStmtList(n[0])
@@ -658,10 +657,10 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
           result.add(st)
           cond = ex
 
-        let tmp = ctx.newTempVar(cond.typ)
-        result.add(ctx.newEnvVarAsgn(tmp, cond))
+        let tmp = ctx.newTempVar(cond.typ, result, cond)
+        # result.add(ctx.newTempVarAsgn(tmp, cond))
 
-        var check = ctx.newEnvVarAccess(tmp)
+        var check = ctx.newTempVarAccess(tmp)
         if n[0].sym.magic == mOr:
           check = ctx.g.newNotCall(check)
 
@@ -671,12 +670,12 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
           let (st, ex) = exprToStmtList(cond)
           ifBody.add(st)
           cond = ex
-        ifBody.add(ctx.newEnvVarAsgn(tmp, cond))
+        ifBody.add(ctx.newTempVarAsgn(tmp, cond))
 
         let ifBranch = newTree(nkElifBranch, check, ifBody)
         let ifNode = newTree(nkIfStmt, ifBranch)
         result.add(ifNode)
-        result.add(ctx.newEnvVarAccess(tmp))
+        result.add(ctx.newTempVarAccess(tmp))
       else:
         for i in 0..<n.len:
           if n[i].kind == nkStmtListExpr:
@@ -685,9 +684,9 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
             n[i] = ex
 
           if n[i].kind in nkCallKinds: # XXX: This should better be some sort of side effect tracking
-            let tmp = ctx.newTempVar(n[i].typ)
-            result.add(ctx.newEnvVarAsgn(tmp, n[i]))
-            n[i] = ctx.newEnvVarAccess(tmp)
+            let tmp = ctx.newTempVar(n[i].typ, result, n[i])
+            # result.add(ctx.newTempVarAsgn(tmp, n[i]))
+            n[i] = ctx.newTempVarAccess(tmp)
 
         result.add(n)
 
@@ -703,6 +702,12 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
         let (st, ex) = exprToStmtList(c[^1])
         result.add(st)
         c[^1] = ex
+      for i in 0 .. c.len - 3:
+        if c[i].kind == nkSym:
+          let s = c[i].sym
+          if sfForceLift in s.flags:
+            ctx.captureVar(s)
+
       result.add(varSect)
 
   of nkDiscardStmt, nkReturnStmt, nkRaiseStmt:
@@ -1279,13 +1284,6 @@ proc wrapIntoStateLoop(ctx: var Ctx, n: PNode): PNode =
   result.info = n.info
 
   let localVars = newNodeI(nkStmtList, n.info)
-  if not ctx.stateVarSym.isNil:
-    let varSect = newNodeI(nkVarSection, n.info)
-    addVar(varSect, newSymNode(ctx.stateVarSym))
-    localVars.add(varSect)
-
-    if not ctx.tempVars.isNil:
-      localVars.add(ctx.tempVars)
 
   let blockStmt = newNodeI(nkBlockStmt, n.info)
   blockStmt.add(newSymNode(ctx.stateLoopLabel))
@@ -1433,15 +1431,67 @@ proc preprocess(c: var PreprocessContext; n: PNode): PNode =
     for i in 0 ..< n.len:
       result[i] = preprocess(c, n[i])
 
+proc detectCapturedVars(c: var Ctx, n: PNode, stateIdx: int) =
+  case n.kind
+  of nkSym:
+    let s = n.sym
+    if s.kind in {skResult, skVar, skLet, skForVar, skTemp} and sfGlobal notin s.flags and s.owner == c.fn:
+      let vs = c.varStates.getOrDefault(s.itemId, localNotSeen)
+      if vs == localNotSeen: # First seing this variable
+        c.varStates[s.itemId] = stateIdx
+      elif vs == localRequiresLifting:
+        discard # Sym already marked
+      elif vs != stateIdx:
+        c.captureVar(s)
+  of nkReturnStmt:
+    if n[0].kind in {nkAsgn, nkFastAsgn, nkSinkAsgn}:
+      # we have a `result = result` expression produced by the closure
+      # transform, let's not touch the LHS in order to make the lifting pass
+      # correct when `result` is lifted
+      detectCapturedVars(c, n[0][1], stateIdx)
+    else:
+      detectCapturedVars(c, n[0], stateIdx)
+  else:
+    for i in 0 ..< n.safeLen:
+      detectCapturedVars(c, n[i], stateIdx)
+
+proc detectCapturedVars(c: var Ctx) =
+  for i, s in c.states:
+    detectCapturedVars(c, s.body, i)
+
+proc liftLocals(c: var Ctx, n: PNode): PNode =
+  result = n
+  case n.kind
+  of nkSym:
+    let s = n.sym
+    if c.varStates.getOrDefault(s.itemId) == localRequiresLifting:
+      # lift
+      let e = getEnvParam(c.fn)
+      let field = getFieldFromObj(e.typ.elementType, s)
+      assert(field != nil)
+      result = rawIndirectAccess(newSymNode(e), field, n.info)
+    # elif c.varStates.getOrDefault(s.itemId, localNotSeen) != localNotSeen:
+    #   echo "Not lifting ", s.name.s
+
+  of nkReturnStmt:
+    if n[0].kind in {nkAsgn, nkFastAsgn, nkSinkAsgn}:
+      # we have a `result = result` expression produced by the closure
+      # transform, let's not touch the LHS in order to make the lifting pass
+      # correct when `result` is lifted
+      n[0][1] = liftLocals(c, n[0][1])
+    else:
+      n[0] = liftLocals(c, n[0])
+  else:
+    for i in 0 ..< n.safeLen:
+      n[i] = liftLocals(c, n[i])
+
 proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n: PNode): PNode =
   var ctx = Ctx(g: g, fn: fn, idgen: idgen)
 
-  if getEnvParam(fn).isNil:
-    # Lambda lifting was not done yet. Use temporary :state sym, which will
-    # be handled specially by lambda lifting. Local temp vars (if needed)
-    # should follow the same logic.
-    ctx.stateVarSym = newSym(skVar, getIdent(ctx.g.cache, ":state"), idgen, fn, fn.info)
-    ctx.stateVarSym.typ = g.createClosureIterStateType(fn, idgen)
+  # The transformation should always happen after at least partial lambdalifting
+  # is performed, so that the closure iter environment is always created upfront.
+  doAssert(getEnvParam(fn) != nil, "Env param not created before iter transformation")
+
   ctx.stateLoopLabel = newSym(skLabel, getIdent(ctx.g.cache, ":stateLoop"), idgen, fn, fn.info)
   var pc = PreprocessContext(finallys: @[], config: g.config, idgen: idgen)
   var n = preprocess(pc, n.toStmtList)
@@ -1466,6 +1516,10 @@ proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n:
   let caseDispatcher = newTreeI(nkCaseStmt, n.info,
       ctx.newStateAccess())
 
+  # Lamdalifting will not touch our locals, it is our responsibility to lift those that
+  # need it.
+  detectCapturedVars(ctx)
+
   for s in ctx.states:
     let body = ctx.transformStateAssignments(s.body)
     caseDispatcher.add newTreeI(nkOfBranch, body.info, g.newIntLit(body.info, s.label), body)
@@ -1473,6 +1527,7 @@ proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n:
   caseDispatcher.add newTreeI(nkElse, n.info, newTreeI(nkReturnStmt, n.info, g.emptyNode))
 
   result = wrapIntoStateLoop(ctx, caseDispatcher)
+  result = liftLocals(ctx, result)
 
   when false:
     echo "TRANSFORM TO STATES: "
@@ -1481,3 +1536,5 @@ proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n:
     echo "exception table:"
     for i, e in ctx.exceptionTable:
       echo i, " -> ", e
+
+    echo "ENV: ", renderTree(getEnvParam(fn).typ.elementType.n)
diff --git a/compiler/lambdalifting.nim b/compiler/lambdalifting.nim
index e51fb3ae0..2a72c3846 100644
--- a/compiler/lambdalifting.nim
+++ b/compiler/lambdalifting.nim
@@ -145,11 +145,13 @@ proc createStateField(g: ModuleGraph; iter: PSym; idgen: IdGenerator): PSym =
   result = newSym(skField, getIdent(g.cache, ":state"), idgen, iter, iter.info)
   result.typ = createClosureIterStateType(g, iter, idgen)
 
+template isIterator*(owner: PSym): bool =
+  owner.kind == skIterator and owner.typ.callConv == ccClosure
+
 proc createEnvObj(g: ModuleGraph; idgen: IdGenerator; owner: PSym; info: TLineInfo): PType =
-  # YYY meh, just add the state field for every closure for now, it's too
-  # hard to figure out if it comes from a closure iterator:
   result = createObj(g, idgen, owner, info, final=false)
-  rawAddField(result, createStateField(g, owner, idgen))
+  if owner.isIterator:
+    rawAddField(result, createStateField(g, owner, idgen))
 
 proc getClosureIterResult*(g: ModuleGraph; iter: PSym; idgen: IdGenerator): PSym =
   if resultPos < iter.ast.len:
@@ -172,26 +174,22 @@ proc addHiddenParam(routine: PSym, param: PSym) =
   assert sfFromGeneric in param.flags
   #echo "produced environment: ", param.id, " for ", routine.id
 
-proc getHiddenParam(g: ModuleGraph; routine: PSym): PSym =
+proc getEnvParam*(routine: PSym): PSym =
   let params = routine.ast[paramsPos]
   let hidden = lastSon(params)
   if hidden.kind == nkSym and hidden.sym.kind == skParam and hidden.sym.name.s == paramName:
     result = hidden.sym
     assert sfFromGeneric in result.flags
   else:
+    result = nil
+
+proc getHiddenParam(g: ModuleGraph; routine: PSym): PSym =
+  result = getEnvParam(routine)
+  if result.isNil:
     # writeStackTrace()
     localError(g.config, routine.info, "internal error: could not find env param for " & routine.name.s)
     result = routine
 
-proc getEnvParam*(routine: PSym): PSym =
-  let params = routine.ast[paramsPos]
-  let hidden = lastSon(params)
-  if hidden.kind == nkSym and hidden.sym.name.s == paramName:
-    result = hidden.sym
-    assert sfFromGeneric in result.flags
-  else:
-    result = nil
-
 proc interestingVar(s: PSym): bool {.inline.} =
   result = s.kind in {skVar, skLet, skTemp, skForVar, skParam, skResult} and
     sfGlobal notin s.flags and
@@ -230,15 +228,6 @@ proc makeClosure*(g: ModuleGraph; idgen: IdGenerator; prc: PSym; env: PNode; inf
   if tfHasAsgn in result.typ.flags or optSeqDestructors in g.config.globalOptions:
     prc.flags.incl sfInjectDestructors
 
-proc interestingIterVar(s: PSym): bool {.inline.} =
-  # XXX optimization: Only lift the variable if it lives across
-  # yield/return boundaries! This can potentially speed up
-  # closure iterators quite a bit.
-  result = s.kind in {skResult, skVar, skLet, skTemp, skForVar} and sfGlobal notin s.flags
-
-template isIterator*(owner: PSym): bool =
-  owner.kind == skIterator and owner.typ.callConv == ccClosure
-
 template liftingHarmful(conf: ConfigRef; owner: PSym): bool =
   ## lambda lifting can be harmful for JS-like code generators.
   let isCompileTime = sfCompileTime in owner.flags or owner.kind == skMacro
@@ -285,15 +274,6 @@ proc liftIterSym*(g: ModuleGraph; n: PNode; idgen: IdGenerator; owner: PSym): PN
   createTypeBoundOpsLL(g, env.typ, n.info, idgen, owner)
   result.add makeClosure(g, idgen, iter, env, n.info)
 
-proc freshVarForClosureIter*(g: ModuleGraph; s: PSym; idgen: IdGenerator; owner: PSym): PNode =
-  let envParam = getHiddenParam(g, owner)
-  let obj = envParam.typ.skipTypes({tyOwned, tyRef, tyPtr})
-  let field = addField(obj, s, g.cache, idgen)
-
-  var access = newSymNode(envParam)
-  assert obj.kind == tyObject
-  result = rawIndirectAccess(access, field, s.info)
-
 # ------------------ new stuff -------------------------------------------
 
 proc markAsClosure(g: ModuleGraph; owner: PSym; n: PNode) =
@@ -341,9 +321,13 @@ proc getEnvTypeForOwner(c: var DetectionPass; owner: PSym;
                         info: TLineInfo): PType =
   result = c.ownerToType.getOrDefault(owner.id)
   if result.isNil:
-    result = newType(tyRef, c.idgen, owner)
-    let obj = createEnvObj(c.graph, c.idgen, owner, info)
-    rawAddSon(result, obj)
+    let env = getEnvParam(owner)
+    if env.isNil or not owner.isIterator:
+      result = newType(tyRef, c.idgen, owner)
+      let obj = createEnvObj(c.graph, c.idgen, owner, info)
+      rawAddSon(result, obj)
+    else:
+      result = env.typ
     c.ownerToType[owner.id] = result
 
 proc asOwnedRef(c: var DetectionPass; t: PType): PType =
@@ -459,17 +443,6 @@ proc detectCapturedVars(n: PNode; owner: PSym; c: var DetectionPass) =
       if owner.isIterator:
         c.somethingToDo = true
         addClosureParam(c, owner, n.info)
-        if interestingIterVar(s):
-          if not c.capturedVars.contains(s.id):
-            if not c.inTypeOf: c.capturedVars.incl(s.id)
-            let obj = getHiddenParam(c.graph, owner).typ.skipTypes({tyOwned, tyRef, tyPtr})
-            #let obj = c.getEnvTypeForOwner(s.owner).skipTypes({tyOwned, tyRef, tyPtr})
-
-            if s.name.id == getIdent(c.graph.cache, ":state").id:
-              obj.n[0].sym.flags.incl sfNoInit
-              obj.n[0].sym.itemId = ItemId(module: s.itemId.module, item: -s.itemId.item)
-            else:
-              discard addField(obj, s, c.graph.cache, c.idgen)
     # direct or indirect dependency:
     elif innerClosure or interested:
       discard """
@@ -773,8 +746,6 @@ proc liftCapturedVars(n: PNode; owner: PSym; d: var DetectionPass;
     elif s.id in d.capturedVars:
       if s.owner != owner:
         result = accessViaEnvParam(d.graph, n, owner)
-      elif owner.isIterator and interestingIterVar(s):
-        result = accessViaEnvParam(d.graph, n, owner)
       else:
         result = accessViaEnvVar(n, owner, d, c)
   of nkEmpty..pred(nkSym), succ(nkSym)..nkNilLit, nkComesFrom,
@@ -892,6 +863,9 @@ proc liftLambdas*(g: ModuleGraph; fn: PSym, body: PNode; tooEarly: var bool;
     # ignore forward declaration:
     result = body
     tooEarly = true
+    if fn.isIterator:
+      var d = initDetectionPass(g, fn, idgen)
+      addClosureParam(d, fn, body.info)
   else:
     var d = initDetectionPass(g, fn, idgen)
     detectCapturedVars(body, fn, d)
diff --git a/compiler/pragmas.nim b/compiler/pragmas.nim
index ed1e067c0..c724aa025 100644
--- a/compiler/pragmas.nim
+++ b/compiler/pragmas.nim
@@ -81,7 +81,7 @@ const
     wRequiresInit, wNoalias, wAlign, wNoInit} - {wExportNims, wNodecl} # why exclude these?
   varPragmas* = declPragmas + {wVolatile, wRegister, wThreadVar,
     wMagic, wHeader, wCompilerProc, wCore, wDynlib,
-    wNoInit, wCompileTime, wGlobal,
+    wNoInit, wCompileTime, wGlobal, wLiftLocals,
     wGensym, wInject, wCodegenDecl,
     wGuard, wGoto, wCursor, wNoalias, wAlign}
   constPragmas* = declPragmas + {wHeader, wMagic,
@@ -1308,7 +1308,8 @@ proc singlePragma(c: PContext, sym: PSym, n: PNode, i: var int,
         noVal(c, it)
         if sym == nil: invalidPragma(c, it)
         else: sym.flags.incl sfUsed
-      of wLiftLocals: discard
+      of wLiftLocals:
+        sym.flags.incl(sfForceLift)
       of wRequires, wInvariant, wAssume, wAssert:
         pragmaProposition(c, it)
       of wEnsures:
diff --git a/compiler/transf.nim b/compiler/transf.nim
index 8dcb74729..ce0157f22 100644
--- a/compiler/transf.nim
+++ b/compiler/transf.nim
@@ -97,10 +97,7 @@ proc newTemp(c: PTransf, typ: PType, info: TLineInfo): PNode =
   r.typ = typ #skipTypes(typ, {tyGenericInst, tyAlias, tySink})
   incl(r.flags, sfFromGeneric)
   let owner = getCurrOwner(c)
-  if owner.isIterator and not c.tooEarly:
-    result = freshVarForClosureIter(c.graph, r, c.idgen, owner)
-  else:
-    result = newSymNode(r)
+  result = newSymNode(r)
 
 proc transform(c: PTransf, n: PNode): PNode
 
@@ -176,13 +173,10 @@ proc transformSym(c: PTransf, n: PNode): PNode =
 
 proc freshVar(c: PTransf; v: PSym): PNode =
   let owner = getCurrOwner(c)
-  if owner.isIterator and not c.tooEarly:
-    result = freshVarForClosureIter(c.graph, v, c.idgen, owner)
-  else:
-    var newVar = copySym(v, c.idgen)
-    incl(newVar.flags, sfFromGeneric)
-    newVar.owner = owner
-    result = newSymNode(newVar)
+  var newVar = copySym(v, c.idgen)
+  incl(newVar.flags, sfFromGeneric)
+  newVar.owner = owner
+  result = newSymNode(newVar)
 
 proc transformVarSection(c: PTransf, v: PNode): PNode =
   result = newTransNode(v)
diff --git a/tests/destructor/tuse_ownedref_after_move.nim b/tests/destructor/tuse_ownedref_after_move.nim
index 69348d530..34d5be65a 100644
--- a/tests/destructor/tuse_ownedref_after_move.nim
+++ b/tests/destructor/tuse_ownedref_after_move.nim
@@ -1,6 +1,6 @@
 discard """
   cmd: '''nim c --newruntime $file'''
-  errormsg: "'=copy' is not available for type <owned Button>; requires a copy because it's not the last read of ':envAlt.b1'; routine: main"
+  errormsg: "'=copy' is not available for type <owned Button>; requires a copy because it's not the last read of ':envAlt.b0'; routine: main"
   line: 48
 """
 
diff --git a/tests/iter/tyieldintry.nim b/tests/iter/tyieldintry.nim
index 47ba6efa6..04409795b 100644
--- a/tests/iter/tyieldintry.nim
+++ b/tests/iter/tyieldintry.nim
@@ -504,3 +504,26 @@ block: # void iterator
     except:
       discard
   var a = it
+
+block: # Locals present in only 1 state should be on the stack
+  proc checkOnStack(a: pointer, shouldBeOnStack: bool) =
+    # Quick and dirty way to check if a points to stack
+    var dummy = 0
+    let dummyAddr = addr dummy
+    let distance = abs(cast[int](dummyAddr) - cast[int](a))
+    const requiredDistance = 300
+    if shouldBeOnStack:
+      doAssert(distance <= requiredDistance, "a is not on stack, but should")
+    else:
+      doAssert(distance > requiredDistance, "a is on stack, but should not")
+
+  iterator it(): int {.closure.} =
+    var a = 1
+    var b = 2
+    var c {.liftLocals.} = 3
+    checkOnStack(addr a, true)
+    checkOnStack(addr b, false)
+    checkOnStack(addr c, false)
+    yield a
+    yield b
+  test(it, 1, 2)