summary refs log tree commit diff stats
path: root/compiler/closureiters.nim
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/closureiters.nim')
-rw-r--r--compiler/closureiters.nim298
1 files changed, 216 insertions, 82 deletions
diff --git a/compiler/closureiters.nim b/compiler/closureiters.nim
index 4e4523601..8bdd04ca7 100644
--- a/compiler/closureiters.nim
+++ b/compiler/closureiters.nim
@@ -18,7 +18,8 @@
 #    dec a
 #
 # Should be transformed to:
-#  STATE0:
+#  case :state
+#  of 0:
 #    if a > 0:
 #      echo "hi"
 #      :state = 1 # Next state
@@ -26,19 +27,14 @@
 #    else:
 #      :state = 2 # Next state
 #      break :stateLoop # Proceed to the next state
-#  STATE1:
+#  of 1:
 #    dec a
 #    :state = 0 # Next state
 #    break :stateLoop # Proceed to the next state
-#  STATE2:
+#  of 2:
 #    :state = -1 # End of execution
-
-# The transformation should play well with lambdalifting, however depending
-# on situation, it can be called either before or after lambdalifting
-# transformation. As such we behave slightly differently, when accessing
-# iterator state, or using temp variables. If lambdalifting did not happen,
-# we just create local variables, so that they will be lifted further on.
-# Otherwise, we utilize existing env, created by lambdalifting.
+#  else:
+#    return
 
 # Lambdalifting treats :state variable specially, it should always end up
 # as the first field in env. Currently C codegen depends on this behavior.
@@ -104,12 +100,13 @@
 # Is transformed to (yields are left in place for example simplicity,
 #    in reality the code is subdivided even more, as described above):
 #
-# STATE0: # Try
+# case :state
+# of 0: # Try
 #   yield 0
 #   raise ...
 #   :state = 2 # What would happen should we not raise
 #   break :stateLoop
-# STATE1: # Except
+# of 1: # Except
 #   yield 1
 #   :tmpResult = 3           # Return
 #   :unrollFinally = true # Return
@@ -117,7 +114,7 @@
 #   break :stateLoop
 #   :state = 2 # What would happen should we not return
 #   break :stateLoop
-# STATE2: # Finally
+# of 2: # Finally
 #   yield 2
 #   if :unrollFinally: # This node is created by `newEndFinallyNode`
 #     if :curExc.isNil:
@@ -130,6 +127,8 @@
 #       raise
 #   state = -1 # Goto next state. In this case we just exit
 #   break :stateLoop
+# else:
+#   return
 
 import
   ast, msgs, idents,
@@ -145,12 +144,11 @@ type
   Ctx = object
     g: ModuleGraph
     fn: PSym
-    stateVarSym: PSym # :state variable. nil if env already introduced by lambdalifting
     tmpResultSym: PSym # Used when we return, but finally has to interfere
     unrollFinallySym: PSym # Indicates that we're unrolling finally states (either exception happened or premature return)
     curExcSym: PSym # Current exception
 
-    states: seq[PNode] # The resulting states. Every state is an nkState node.
+    states: seq[tuple[label: int, body: PNode]] # The resulting states.
     blockLevel: int # Temp used to transform break and continue stmts
     stateLoopLabel: PSym # Label to break on, when jumping between states.
     exitStateIdx: int # index of the last state
@@ -162,15 +160,23 @@ type
     nearestFinally: int # Index of the nearest finally block. For try/except it
                     # is their finally. For finally it is parent finally. Otherwise -1
     idgen: IdGenerator
+    varStates: Table[ItemId, int] # Used to detect if local variable belongs to multiple states
+    stateVarSym: PSym # :state variable. nil if env already introduced by lambdalifting
+      # remove if -d:nimOptIters is default, treating it as always nil
+    nimOptItersEnabled: bool # tracks if -d:nimOptIters is enabled
+      # should be default when issues are fixed, see #24094
 
 const
   nkSkip = {nkEmpty..nkNilLit, nkTemplateDef, nkTypeSection, nkStaticStmt,
             nkCommentStmt, nkMixinStmt, nkBindStmt} + procDefs
+  emptyStateLabel = -1
+  localNotSeen = -1
+  localRequiresLifting = -2
 
 proc newStateAccess(ctx: var Ctx): PNode =
   if ctx.stateVarSym.isNil:
     result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)),
-        getStateField(ctx.g, ctx.fn), ctx.fn.info)
+      getStateField(ctx.g, ctx.fn), ctx.fn.info)
   else:
     result = newSymNode(ctx.stateVarSym)
 
@@ -187,7 +193,8 @@ proc newStateAssgn(ctx: var Ctx, stateNo: int = -2): PNode =
 proc newEnvVar(ctx: var Ctx, name: string, typ: PType): PSym =
   result = newSym(skVar, getIdent(ctx.g.cache, name), ctx.idgen, ctx.fn, ctx.fn.info)
   result.typ = typ
-  assert(not typ.isNil)
+  result.flags.incl sfNoInit
+  assert(not typ.isNil, "Env var needs a type")
 
   if not ctx.stateVarSym.isNil:
     # We haven't gone through labmda lifting yet, so just create a local var,
@@ -198,7 +205,7 @@ proc newEnvVar(ctx: var Ctx, name: string, typ: PType): PSym =
   else:
     let envParam = getEnvParam(ctx.fn)
     # let obj = envParam.typ.lastSon
-    result = addUniqueField(envParam.typ.lastSon, result, ctx.g.cache, ctx.idgen)
+    result = addUniqueField(envParam.typ.elementType, result, ctx.g.cache, ctx.idgen)
 
 proc newEnvVarAccess(ctx: Ctx, s: PSym): PNode =
   if ctx.stateVarSym.isNil:
@@ -206,9 +213,12 @@ proc newEnvVarAccess(ctx: Ctx, s: PSym): PNode =
   else:
     result = newSymNode(s)
 
+proc newTempVarAccess(ctx: Ctx, s: PSym): PNode =
+  result = newSymNode(s, ctx.fn.info)
+
 proc newTmpResultAccess(ctx: var Ctx): PNode =
   if ctx.tmpResultSym.isNil:
-    ctx.tmpResultSym = ctx.newEnvVar(":tmpResult", ctx.fn.typ[0])
+    ctx.tmpResultSym = ctx.newEnvVar(":tmpResult", ctx.fn.typ.returnType)
   ctx.newEnvVarAccess(ctx.tmpResultSym)
 
 proc newUnrollFinallyAccess(ctx: var Ctx, info: TLineInfo): PNode =
@@ -228,10 +238,7 @@ proc newState(ctx: var Ctx, n, gotoOut: PNode): int =
 
   result = ctx.states.len
   let resLit = ctx.g.newIntLit(n.info, result)
-  let s = newNodeI(nkState, n.info)
-  s.add(resLit)
-  s.add(n)
-  ctx.states.add(s)
+  ctx.states.add((result, n))
   ctx.exceptionTable.add(ctx.curExcHandlingState)
 
   if not gotoOut.isNil:
@@ -250,9 +257,26 @@ proc addGotoOut(n: PNode, gotoOut: PNode): PNode =
   if result.len == 0 or result[^1].kind != nkGotoState:
     result.add(gotoOut)
 
-proc newTempVar(ctx: var Ctx, typ: PType): PSym =
-  result = ctx.newEnvVar(":tmpSlLower" & $ctx.tempVarId, typ)
+proc newTempVarDef(ctx: Ctx, s: PSym, initialValue: PNode): PNode =
+  var v = initialValue
+  if v == nil:
+    v = ctx.g.emptyNode
+  newTree(nkVarSection, newTree(nkIdentDefs, newSymNode(s), ctx.g.emptyNode, v))
+
+proc newEnvVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode
+
+proc newTempVar(ctx: var Ctx, typ: PType, parent: PNode, initialValue: PNode = nil): PSym =
+  if ctx.nimOptItersEnabled:
+    result = newSym(skVar, getIdent(ctx.g.cache, ":tmpSlLower" & $ctx.tempVarId), ctx.idgen, ctx.fn, ctx.fn.info)
+  else:
+    result = ctx.newEnvVar(":tmpSlLower" & $ctx.tempVarId, typ)
   inc ctx.tempVarId
+  result.typ = typ
+  assert(not typ.isNil, "Temp var needs a type")
+  if ctx.nimOptItersEnabled:
+    parent.add(ctx.newTempVarDef(result, initialValue))
+  elif initialValue != nil:
+    parent.add(ctx.newEnvVarAsgn(result, initialValue))
 
 proc hasYields(n: PNode): bool =
   # TODO: This is very inefficient. It traverses the node, looking for nkYieldStmt.
@@ -263,8 +287,8 @@ proc hasYields(n: PNode): bool =
     result = false
   else:
     result = false
-    for c in n:
-      if c.hasYields:
+    for i in ord(n.kind == nkCast)..<n.len:
+      if n[i].hasYields:
         result = true
         break
 
@@ -413,7 +437,7 @@ proc hasYieldsInExpressions(n: PNode): bool =
 
 proc exprToStmtList(n: PNode): tuple[s, res: PNode] =
   assert(n.kind == nkStmtListExpr)
-  result.s = newNodeI(nkStmtList, n.info)
+  result = (newNodeI(nkStmtList, n.info), nil)
   result.s.sons = @[]
 
   var n = n
@@ -424,8 +448,15 @@ proc exprToStmtList(n: PNode): tuple[s, res: PNode] =
 
   result.res = n
 
+proc newTempVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode =
+  if isEmptyType(v.typ):
+    result = v
+  else:
+    result = newTree(nkFastAsgn, ctx.newTempVarAccess(s), v)
+    result.info = v.info
 
 proc newEnvVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode =
+  # unused with -d:nimOptIters
   if isEmptyType(v.typ):
     result = v
   else:
@@ -433,10 +464,13 @@ proc newEnvVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode =
     result.info = v.info
 
 proc addExprAssgn(ctx: Ctx, output, input: PNode, sym: PSym) =
+  var input = input
   if input.kind == nkStmtListExpr:
     let (st, res) = exprToStmtList(input)
     output.add(st)
-    output.add(ctx.newEnvVarAsgn(sym, res))
+    input = res
+  if ctx.nimOptItersEnabled:
+    output.add(ctx.newTempVarAsgn(sym, input))
   else:
     output.add(ctx.newEnvVarAsgn(sym, input))
 
@@ -448,6 +482,16 @@ proc newNotCall(g: ModuleGraph; e: PNode): PNode =
   result = newTree(nkCall, newSymNode(g.getSysMagic(e.info, "not", mNot), e.info), e)
   result.typ = g.getSysType(e.info, tyBool)
 
+proc boolLit(g: ModuleGraph; info: TLineInfo; value: bool): PNode =
+  result = newIntLit(g, info, ord value)
+  result.typ = getSysType(g, info, tyBool)
+
+proc captureVar(c: var Ctx, s: PSym) =
+  if c.varStates.getOrDefault(s.itemId) != localRequiresLifting:
+    c.varStates[s.itemId] = localRequiresLifting # Mark this variable for lifting
+    let e = getEnvParam(c.fn)
+    discard addField(e.typ.elementType, s, c.g.cache, c.idgen)
+
 proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
   result = n
   case n.kind
@@ -504,9 +548,9 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
       var tmp: PSym = nil
       let isExpr = not isEmptyType(n.typ)
       if isExpr:
-        tmp = ctx.newTempVar(n.typ)
         result = newNodeI(nkStmtListExpr, n.info)
         result.typ = n.typ
+        tmp = ctx.newTempVar(n.typ, result)
       else:
         result = newNodeI(nkStmtList, n.info)
 
@@ -557,7 +601,11 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
         else:
           internalError(ctx.g.config, "lowerStmtListExpr(nkIf): " & $branch.kind)
 
-      if isExpr: result.add(ctx.newEnvVarAccess(tmp))
+      if isExpr:
+        if ctx.nimOptItersEnabled:
+          result.add(ctx.newTempVarAccess(tmp))
+        else:
+          result.add(ctx.newEnvVarAccess(tmp))
 
   of nkTryStmt, nkHiddenTryStmt:
     var ns = false
@@ -571,7 +619,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
       if isExpr:
         result = newNodeI(nkStmtListExpr, n.info)
         result.typ = n.typ
-        let tmp = ctx.newTempVar(n.typ)
+        let tmp = ctx.newTempVar(n.typ, result)
 
         n[0] = ctx.convertExprBodyToAsgn(n[0], tmp)
         for i in 1..<n.len:
@@ -587,7 +635,10 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
           else:
             internalError(ctx.g.config, "lowerStmtListExpr(nkTryStmt): " & $branch.kind)
         result.add(n)
-        result.add(ctx.newEnvVarAccess(tmp))
+        if ctx.nimOptItersEnabled:
+          result.add(ctx.newTempVarAccess(tmp))
+        else:
+          result.add(ctx.newEnvVarAccess(tmp))
 
   of nkCaseStmt:
     var ns = false
@@ -600,9 +651,9 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
       let isExpr = not isEmptyType(n.typ)
 
       if isExpr:
-        let tmp = ctx.newTempVar(n.typ)
         result = newNodeI(nkStmtListExpr, n.info)
         result.typ = n.typ
+        let tmp = ctx.newTempVar(n.typ, result)
 
         if n[0].kind == nkStmtListExpr:
           let (st, ex) = exprToStmtList(n[0])
@@ -619,7 +670,10 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
           else:
             internalError(ctx.g.config, "lowerStmtListExpr(nkCaseStmt): " & $branch.kind)
         result.add(n)
-        result.add(ctx.newEnvVarAccess(tmp))
+        if ctx.nimOptItersEnabled:
+          result.add(ctx.newTempVarAccess(tmp))
+        else:
+          result.add(ctx.newEnvVarAccess(tmp))
       elif n[0].kind == nkStmtListExpr:
         result = newNodeI(nkStmtList, n.info)
         let (st, ex) = exprToStmtList(n[0])
@@ -649,10 +703,14 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
           result.add(st)
           cond = ex
 
-        let tmp = ctx.newTempVar(cond.typ)
-        result.add(ctx.newEnvVarAsgn(tmp, cond))
+        let tmp = ctx.newTempVar(cond.typ, result, cond)
+        # result.add(ctx.newTempVarAsgn(tmp, cond))
 
-        var check = ctx.newEnvVarAccess(tmp)
+        var check: PNode
+        if ctx.nimOptItersEnabled:
+          check = ctx.newTempVarAccess(tmp)
+        else:
+          check = ctx.newEnvVarAccess(tmp)
         if n[0].sym.magic == mOr:
           check = ctx.g.newNotCall(check)
 
@@ -662,12 +720,18 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
           let (st, ex) = exprToStmtList(cond)
           ifBody.add(st)
           cond = ex
-        ifBody.add(ctx.newEnvVarAsgn(tmp, cond))
+        if ctx.nimOptItersEnabled:
+          ifBody.add(ctx.newTempVarAsgn(tmp, cond))
+        else:
+          ifBody.add(ctx.newEnvVarAsgn(tmp, cond))
 
         let ifBranch = newTree(nkElifBranch, check, ifBody)
         let ifNode = newTree(nkIfStmt, ifBranch)
         result.add(ifNode)
-        result.add(ctx.newEnvVarAccess(tmp))
+        if ctx.nimOptItersEnabled:
+          result.add(ctx.newTempVarAccess(tmp))
+        else:
+          result.add(ctx.newEnvVarAccess(tmp))
       else:
         for i in 0..<n.len:
           if n[i].kind == nkStmtListExpr:
@@ -676,9 +740,12 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
             n[i] = ex
 
           if n[i].kind in nkCallKinds: # XXX: This should better be some sort of side effect tracking
-            let tmp = ctx.newTempVar(n[i].typ)
-            result.add(ctx.newEnvVarAsgn(tmp, n[i]))
-            n[i] = ctx.newEnvVarAccess(tmp)
+            let tmp = ctx.newTempVar(n[i].typ, result, n[i])
+            # result.add(ctx.newTempVarAsgn(tmp, n[i]))
+            if ctx.nimOptItersEnabled:
+              n[i] = ctx.newTempVarAccess(tmp)
+            else:
+              n[i] = ctx.newEnvVarAccess(tmp)
 
         result.add(n)
 
@@ -694,6 +761,12 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
         let (st, ex) = exprToStmtList(c[^1])
         result.add(st)
         c[^1] = ex
+      for i in 0 .. c.len - 3:
+        if c[i].kind == nkSym:
+          let s = c[i].sym
+          if sfForceLift in s.flags:
+            ctx.captureVar(s)
+
       result.add(varSect)
 
   of nkDiscardStmt, nkReturnStmt, nkRaiseStmt:
@@ -779,7 +852,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
         let check = newTree(nkIfStmt, branch)
         let newBody = newTree(nkStmtList, st, check, n[1])
 
-        n[0] = newSymNode(ctx.g.getSysSym(n[0].info, "true"))
+        n[0] = ctx.g.boolLit(n[0].info, true)
         n[1] = newBody
 
   of nkDotExpr, nkCheckedFieldExpr:
@@ -831,7 +904,7 @@ proc newEndFinallyNode(ctx: var Ctx, info: TLineInfo): PNode =
   let retStmt =
     if ctx.nearestFinally == 0:
       # last finally, we can return
-      let retValue = if ctx.fn.typ[0].isNil:
+      let retValue = if ctx.fn.typ.returnType.isNil:
                    ctx.g.emptyNode
                  else:
                    newTree(nkFastAsgn,
@@ -1133,10 +1206,10 @@ proc skipEmptyStates(ctx: Ctx, stateIdx: int): int =
     let label = stateIdx
     if label == ctx.exitStateIdx: break
     var newLabel = label
-    if label == -1:
+    if label == emptyStateLabel:
       newLabel = ctx.exitStateIdx
     else:
-      let fs = skipStmtList(ctx, ctx.states[label][1])
+      let fs = skipStmtList(ctx, ctx.states[label].body)
       if fs.kind == nkGotoState:
         newLabel = fs[0].intVal.int
     if label == newLabel: break
@@ -1145,7 +1218,7 @@ proc skipEmptyStates(ctx: Ctx, stateIdx: int): int =
     if maxJumps == 0:
       assert(false, "Internal error")
 
-  result = ctx.states[stateIdx][0].intVal.int
+  result = ctx.states[stateIdx].label
 
 proc skipThroughEmptyStates(ctx: var Ctx, n: PNode): PNode=
   result = n
@@ -1263,11 +1336,10 @@ proc wrapIntoTryExcept(ctx: var Ctx, n: PNode): PNode {.inline.} =
 proc wrapIntoStateLoop(ctx: var Ctx, n: PNode): PNode =
   # while true:
   #   block :stateLoop:
-  #     gotoState :state
   #     local vars decl (if needed)
   #     body # Might get wrapped in try-except
   let loopBody = newNodeI(nkStmtList, n.info)
-  result = newTree(nkWhileStmt, newSymNode(ctx.g.getSysSym(n.info, "true")), loopBody)
+  result = newTree(nkWhileStmt, ctx.g.boolLit(n.info, true), loopBody)
   result.info = n.info
 
   let localVars = newNodeI(nkStmtList, n.info)
@@ -1282,11 +1354,7 @@ proc wrapIntoStateLoop(ctx: var Ctx, n: PNode): PNode =
   let blockStmt = newNodeI(nkBlockStmt, n.info)
   blockStmt.add(newSymNode(ctx.stateLoopLabel))
 
-  let gs = newNodeI(nkGotoState, n.info)
-  gs.add(ctx.newStateAccess())
-  gs.add(ctx.g.newIntLit(n.info, ctx.states.len - 1))
-
-  var blockBody = newTree(nkStmtList, gs, localVars, n)
+  var blockBody = newTree(nkStmtList, localVars, n)
   if ctx.hasExceptions:
     blockBody = ctx.wrapIntoTryExcept(blockBody)
 
@@ -1299,29 +1367,28 @@ proc deleteEmptyStates(ctx: var Ctx) =
 
   # Apply new state indexes and mark unused states with -1
   var iValid = 0
-  for i, s in ctx.states:
-    let body = skipStmtList(ctx, s[1])
+  for i, s in ctx.states.mpairs:
+    let body = skipStmtList(ctx, s.body)
     if body.kind == nkGotoState and i != ctx.states.len - 1 and i != 0:
       # This is an empty state. Mark with -1.
-      s[0].intVal = -1
+      s.label = emptyStateLabel
     else:
-      s[0].intVal = iValid
+      s.label = iValid
       inc iValid
 
   for i, s in ctx.states:
-    let body = skipStmtList(ctx, s[1])
+    let body = skipStmtList(ctx, s.body)
     if body.kind != nkGotoState or i == 0:
-      discard ctx.skipThroughEmptyStates(s)
+      discard ctx.skipThroughEmptyStates(s.body)
       let excHandlState = ctx.exceptionTable[i]
       if excHandlState < 0:
         ctx.exceptionTable[i] = -ctx.skipEmptyStates(-excHandlState)
       elif excHandlState != 0:
         ctx.exceptionTable[i] = ctx.skipEmptyStates(excHandlState)
 
-  var i = 0
+  var i = 1 # ignore the entry and the exit
   while i < ctx.states.len - 1:
-    let fs = skipStmtList(ctx, ctx.states[i][1])
-    if fs.kind == nkGotoState and i != 0:
+    if ctx.states[i].label == emptyStateLabel:
       ctx.states.delete(i)
       ctx.exceptionTable.delete(i)
     else:
@@ -1430,18 +1497,77 @@ proc preprocess(c: var PreprocessContext; n: PNode): PNode =
     for i in 0 ..< n.len:
       result[i] = preprocess(c, n[i])
 
+proc detectCapturedVars(c: var Ctx, n: PNode, stateIdx: int) =
+  case n.kind
+  of nkSym:
+    let s = n.sym
+    if s.kind in {skResult, skVar, skLet, skForVar, skTemp} and sfGlobal notin s.flags and s.owner == c.fn:
+      let vs = c.varStates.getOrDefault(s.itemId, localNotSeen)
+      if vs == localNotSeen: # First seing this variable
+        c.varStates[s.itemId] = stateIdx
+      elif vs == localRequiresLifting:
+        discard # Sym already marked
+      elif vs != stateIdx:
+        c.captureVar(s)
+  of nkReturnStmt:
+    if n[0].kind in {nkAsgn, nkFastAsgn, nkSinkAsgn}:
+      # we have a `result = result` expression produced by the closure
+      # transform, let's not touch the LHS in order to make the lifting pass
+      # correct when `result` is lifted
+      detectCapturedVars(c, n[0][1], stateIdx)
+    else:
+      detectCapturedVars(c, n[0], stateIdx)
+  else:
+    for i in 0 ..< n.safeLen:
+      detectCapturedVars(c, n[i], stateIdx)
+
+proc detectCapturedVars(c: var Ctx) =
+  for i, s in c.states:
+    detectCapturedVars(c, s.body, i)
+
+proc liftLocals(c: var Ctx, n: PNode): PNode =
+  result = n
+  case n.kind
+  of nkSym:
+    let s = n.sym
+    if c.varStates.getOrDefault(s.itemId) == localRequiresLifting:
+      # lift
+      let e = getEnvParam(c.fn)
+      let field = getFieldFromObj(e.typ.elementType, s)
+      assert(field != nil)
+      result = rawIndirectAccess(newSymNode(e), field, n.info)
+    # elif c.varStates.getOrDefault(s.itemId, localNotSeen) != localNotSeen:
+    #   echo "Not lifting ", s.name.s
+
+  of nkReturnStmt:
+    if n[0].kind in {nkAsgn, nkFastAsgn, nkSinkAsgn}:
+      # we have a `result = result` expression produced by the closure
+      # transform, let's not touch the LHS in order to make the lifting pass
+      # correct when `result` is lifted
+      n[0][1] = liftLocals(c, n[0][1])
+    else:
+      n[0] = liftLocals(c, n[0])
+  else:
+    for i in 0 ..< n.safeLen:
+      n[i] = liftLocals(c, n[i])
+
 proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n: PNode): PNode =
-  var ctx: Ctx
-  ctx.g = g
-  ctx.fn = fn
-  ctx.idgen = idgen
+  var ctx = Ctx(g: g, fn: fn, idgen: idgen,
+    # should be default when issues are fixed, see #24094:
+    nimOptItersEnabled: isDefined(g.config, "nimOptIters"))
 
   if getEnvParam(fn).isNil:
-    # Lambda lifting was not done yet. Use temporary :state sym, which will
-    # be handled specially by lambda lifting. Local temp vars (if needed)
-    # should follow the same logic.
-    ctx.stateVarSym = newSym(skVar, getIdent(ctx.g.cache, ":state"), idgen, fn, fn.info)
-    ctx.stateVarSym.typ = g.createClosureIterStateType(fn, idgen)
+    if ctx.nimOptItersEnabled:
+      # The transformation should always happen after at least partial lambdalifting
+      # is performed, so that the closure iter environment is always created upfront.
+      doAssert(false, "Env param not created before iter transformation")
+    else:
+      # Lambda lifting was not done yet. Use temporary :state sym, which will
+      # be handled specially by lambda lifting. Local temp vars (if needed)
+      # should follow the same logic.
+      ctx.stateVarSym = newSym(skVar, getIdent(ctx.g.cache, ":state"), idgen, fn, fn.info)
+      ctx.stateVarSym.typ = g.createClosureIterStateType(fn, idgen)
+
   ctx.stateLoopLabel = newSym(skLabel, getIdent(ctx.g.cache, ":stateLoop"), idgen, fn, fn.info)
   var pc = PreprocessContext(finallys: @[], config: g.config, idgen: idgen)
   var n = preprocess(pc, n.toStmtList)
@@ -1463,17 +1589,23 @@ proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n:
   # Optimize empty states away
   ctx.deleteEmptyStates()
 
-  # Make new body by concatenating the list of states
-  result = newNodeI(nkStmtList, n.info)
+  let caseDispatcher = newTreeI(nkCaseStmt, n.info,
+      ctx.newStateAccess())
+
+  if ctx.nimOptItersEnabled:
+    # Lamdalifting will not touch our locals, it is our responsibility to lift those that
+    # need it.
+    detectCapturedVars(ctx)
+
   for s in ctx.states:
-    assert(s.len == 2)
-    let body = s[1]
-    s.sons.del(1)
-    result.add(s)
-    result.add(body)
+    let body = ctx.transformStateAssignments(s.body)
+    caseDispatcher.add newTreeI(nkOfBranch, body.info, g.newIntLit(body.info, s.label), body)
+
+  caseDispatcher.add newTreeI(nkElse, n.info, newTreeI(nkReturnStmt, n.info, g.emptyNode))
 
-  result = ctx.transformStateAssignments(result)
-  result = ctx.wrapIntoStateLoop(result)
+  result = wrapIntoStateLoop(ctx, caseDispatcher)
+  if ctx.nimOptItersEnabled:
+    result = liftLocals(ctx, result)
 
   when false:
     echo "TRANSFORM TO STATES: "
@@ -1482,3 +1614,5 @@ proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n:
     echo "exception table:"
     for i, e in ctx.exceptionTable:
       echo i, " -> ", e
+
+    echo "ENV: ", renderTree(getEnvParam(fn).typ.elementType.n)