diff options
Diffstat (limited to 'compiler/closureiters.nim')
-rw-r--r-- | compiler/closureiters.nim | 298 |
1 files changed, 216 insertions, 82 deletions
diff --git a/compiler/closureiters.nim b/compiler/closureiters.nim index 4e4523601..8bdd04ca7 100644 --- a/compiler/closureiters.nim +++ b/compiler/closureiters.nim @@ -18,7 +18,8 @@ # dec a # # Should be transformed to: -# STATE0: +# case :state +# of 0: # if a > 0: # echo "hi" # :state = 1 # Next state @@ -26,19 +27,14 @@ # else: # :state = 2 # Next state # break :stateLoop # Proceed to the next state -# STATE1: +# of 1: # dec a # :state = 0 # Next state # break :stateLoop # Proceed to the next state -# STATE2: +# of 2: # :state = -1 # End of execution - -# The transformation should play well with lambdalifting, however depending -# on situation, it can be called either before or after lambdalifting -# transformation. As such we behave slightly differently, when accessing -# iterator state, or using temp variables. If lambdalifting did not happen, -# we just create local variables, so that they will be lifted further on. -# Otherwise, we utilize existing env, created by lambdalifting. +# else: +# return # Lambdalifting treats :state variable specially, it should always end up # as the first field in env. Currently C codegen depends on this behavior. @@ -104,12 +100,13 @@ # Is transformed to (yields are left in place for example simplicity, # in reality the code is subdivided even more, as described above): # -# STATE0: # Try +# case :state +# of 0: # Try # yield 0 # raise ... # :state = 2 # What would happen should we not raise # break :stateLoop -# STATE1: # Except +# of 1: # Except # yield 1 # :tmpResult = 3 # Return # :unrollFinally = true # Return @@ -117,7 +114,7 @@ # break :stateLoop # :state = 2 # What would happen should we not return # break :stateLoop -# STATE2: # Finally +# of 2: # Finally # yield 2 # if :unrollFinally: # This node is created by `newEndFinallyNode` # if :curExc.isNil: @@ -130,6 +127,8 @@ # raise # state = -1 # Goto next state. In this case we just exit # break :stateLoop +# else: +# return import ast, msgs, idents, @@ -145,12 +144,11 @@ type Ctx = object g: ModuleGraph fn: PSym - stateVarSym: PSym # :state variable. nil if env already introduced by lambdalifting tmpResultSym: PSym # Used when we return, but finally has to interfere unrollFinallySym: PSym # Indicates that we're unrolling finally states (either exception happened or premature return) curExcSym: PSym # Current exception - states: seq[PNode] # The resulting states. Every state is an nkState node. + states: seq[tuple[label: int, body: PNode]] # The resulting states. blockLevel: int # Temp used to transform break and continue stmts stateLoopLabel: PSym # Label to break on, when jumping between states. exitStateIdx: int # index of the last state @@ -162,15 +160,23 @@ type nearestFinally: int # Index of the nearest finally block. For try/except it # is their finally. For finally it is parent finally. Otherwise -1 idgen: IdGenerator + varStates: Table[ItemId, int] # Used to detect if local variable belongs to multiple states + stateVarSym: PSym # :state variable. nil if env already introduced by lambdalifting + # remove if -d:nimOptIters is default, treating it as always nil + nimOptItersEnabled: bool # tracks if -d:nimOptIters is enabled + # should be default when issues are fixed, see #24094 const nkSkip = {nkEmpty..nkNilLit, nkTemplateDef, nkTypeSection, nkStaticStmt, nkCommentStmt, nkMixinStmt, nkBindStmt} + procDefs + emptyStateLabel = -1 + localNotSeen = -1 + localRequiresLifting = -2 proc newStateAccess(ctx: var Ctx): PNode = if ctx.stateVarSym.isNil: result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), - getStateField(ctx.g, ctx.fn), ctx.fn.info) + getStateField(ctx.g, ctx.fn), ctx.fn.info) else: result = newSymNode(ctx.stateVarSym) @@ -187,7 +193,8 @@ proc newStateAssgn(ctx: var Ctx, stateNo: int = -2): PNode = proc newEnvVar(ctx: var Ctx, name: string, typ: PType): PSym = result = newSym(skVar, getIdent(ctx.g.cache, name), ctx.idgen, ctx.fn, ctx.fn.info) result.typ = typ - assert(not typ.isNil) + result.flags.incl sfNoInit + assert(not typ.isNil, "Env var needs a type") if not ctx.stateVarSym.isNil: # We haven't gone through labmda lifting yet, so just create a local var, @@ -198,7 +205,7 @@ proc newEnvVar(ctx: var Ctx, name: string, typ: PType): PSym = else: let envParam = getEnvParam(ctx.fn) # let obj = envParam.typ.lastSon - result = addUniqueField(envParam.typ.lastSon, result, ctx.g.cache, ctx.idgen) + result = addUniqueField(envParam.typ.elementType, result, ctx.g.cache, ctx.idgen) proc newEnvVarAccess(ctx: Ctx, s: PSym): PNode = if ctx.stateVarSym.isNil: @@ -206,9 +213,12 @@ proc newEnvVarAccess(ctx: Ctx, s: PSym): PNode = else: result = newSymNode(s) +proc newTempVarAccess(ctx: Ctx, s: PSym): PNode = + result = newSymNode(s, ctx.fn.info) + proc newTmpResultAccess(ctx: var Ctx): PNode = if ctx.tmpResultSym.isNil: - ctx.tmpResultSym = ctx.newEnvVar(":tmpResult", ctx.fn.typ[0]) + ctx.tmpResultSym = ctx.newEnvVar(":tmpResult", ctx.fn.typ.returnType) ctx.newEnvVarAccess(ctx.tmpResultSym) proc newUnrollFinallyAccess(ctx: var Ctx, info: TLineInfo): PNode = @@ -228,10 +238,7 @@ proc newState(ctx: var Ctx, n, gotoOut: PNode): int = result = ctx.states.len let resLit = ctx.g.newIntLit(n.info, result) - let s = newNodeI(nkState, n.info) - s.add(resLit) - s.add(n) - ctx.states.add(s) + ctx.states.add((result, n)) ctx.exceptionTable.add(ctx.curExcHandlingState) if not gotoOut.isNil: @@ -250,9 +257,26 @@ proc addGotoOut(n: PNode, gotoOut: PNode): PNode = if result.len == 0 or result[^1].kind != nkGotoState: result.add(gotoOut) -proc newTempVar(ctx: var Ctx, typ: PType): PSym = - result = ctx.newEnvVar(":tmpSlLower" & $ctx.tempVarId, typ) +proc newTempVarDef(ctx: Ctx, s: PSym, initialValue: PNode): PNode = + var v = initialValue + if v == nil: + v = ctx.g.emptyNode + newTree(nkVarSection, newTree(nkIdentDefs, newSymNode(s), ctx.g.emptyNode, v)) + +proc newEnvVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode + +proc newTempVar(ctx: var Ctx, typ: PType, parent: PNode, initialValue: PNode = nil): PSym = + if ctx.nimOptItersEnabled: + result = newSym(skVar, getIdent(ctx.g.cache, ":tmpSlLower" & $ctx.tempVarId), ctx.idgen, ctx.fn, ctx.fn.info) + else: + result = ctx.newEnvVar(":tmpSlLower" & $ctx.tempVarId, typ) inc ctx.tempVarId + result.typ = typ + assert(not typ.isNil, "Temp var needs a type") + if ctx.nimOptItersEnabled: + parent.add(ctx.newTempVarDef(result, initialValue)) + elif initialValue != nil: + parent.add(ctx.newEnvVarAsgn(result, initialValue)) proc hasYields(n: PNode): bool = # TODO: This is very inefficient. It traverses the node, looking for nkYieldStmt. @@ -263,8 +287,8 @@ proc hasYields(n: PNode): bool = result = false else: result = false - for c in n: - if c.hasYields: + for i in ord(n.kind == nkCast)..<n.len: + if n[i].hasYields: result = true break @@ -413,7 +437,7 @@ proc hasYieldsInExpressions(n: PNode): bool = proc exprToStmtList(n: PNode): tuple[s, res: PNode] = assert(n.kind == nkStmtListExpr) - result.s = newNodeI(nkStmtList, n.info) + result = (newNodeI(nkStmtList, n.info), nil) result.s.sons = @[] var n = n @@ -424,8 +448,15 @@ proc exprToStmtList(n: PNode): tuple[s, res: PNode] = result.res = n +proc newTempVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode = + if isEmptyType(v.typ): + result = v + else: + result = newTree(nkFastAsgn, ctx.newTempVarAccess(s), v) + result.info = v.info proc newEnvVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode = + # unused with -d:nimOptIters if isEmptyType(v.typ): result = v else: @@ -433,10 +464,13 @@ proc newEnvVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode = result.info = v.info proc addExprAssgn(ctx: Ctx, output, input: PNode, sym: PSym) = + var input = input if input.kind == nkStmtListExpr: let (st, res) = exprToStmtList(input) output.add(st) - output.add(ctx.newEnvVarAsgn(sym, res)) + input = res + if ctx.nimOptItersEnabled: + output.add(ctx.newTempVarAsgn(sym, input)) else: output.add(ctx.newEnvVarAsgn(sym, input)) @@ -448,6 +482,16 @@ proc newNotCall(g: ModuleGraph; e: PNode): PNode = result = newTree(nkCall, newSymNode(g.getSysMagic(e.info, "not", mNot), e.info), e) result.typ = g.getSysType(e.info, tyBool) +proc boolLit(g: ModuleGraph; info: TLineInfo; value: bool): PNode = + result = newIntLit(g, info, ord value) + result.typ = getSysType(g, info, tyBool) + +proc captureVar(c: var Ctx, s: PSym) = + if c.varStates.getOrDefault(s.itemId) != localRequiresLifting: + c.varStates[s.itemId] = localRequiresLifting # Mark this variable for lifting + let e = getEnvParam(c.fn) + discard addField(e.typ.elementType, s, c.g.cache, c.idgen) + proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = result = n case n.kind @@ -504,9 +548,9 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = var tmp: PSym = nil let isExpr = not isEmptyType(n.typ) if isExpr: - tmp = ctx.newTempVar(n.typ) result = newNodeI(nkStmtListExpr, n.info) result.typ = n.typ + tmp = ctx.newTempVar(n.typ, result) else: result = newNodeI(nkStmtList, n.info) @@ -557,7 +601,11 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = else: internalError(ctx.g.config, "lowerStmtListExpr(nkIf): " & $branch.kind) - if isExpr: result.add(ctx.newEnvVarAccess(tmp)) + if isExpr: + if ctx.nimOptItersEnabled: + result.add(ctx.newTempVarAccess(tmp)) + else: + result.add(ctx.newEnvVarAccess(tmp)) of nkTryStmt, nkHiddenTryStmt: var ns = false @@ -571,7 +619,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = if isExpr: result = newNodeI(nkStmtListExpr, n.info) result.typ = n.typ - let tmp = ctx.newTempVar(n.typ) + let tmp = ctx.newTempVar(n.typ, result) n[0] = ctx.convertExprBodyToAsgn(n[0], tmp) for i in 1..<n.len: @@ -587,7 +635,10 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = else: internalError(ctx.g.config, "lowerStmtListExpr(nkTryStmt): " & $branch.kind) result.add(n) - result.add(ctx.newEnvVarAccess(tmp)) + if ctx.nimOptItersEnabled: + result.add(ctx.newTempVarAccess(tmp)) + else: + result.add(ctx.newEnvVarAccess(tmp)) of nkCaseStmt: var ns = false @@ -600,9 +651,9 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = let isExpr = not isEmptyType(n.typ) if isExpr: - let tmp = ctx.newTempVar(n.typ) result = newNodeI(nkStmtListExpr, n.info) result.typ = n.typ + let tmp = ctx.newTempVar(n.typ, result) if n[0].kind == nkStmtListExpr: let (st, ex) = exprToStmtList(n[0]) @@ -619,7 +670,10 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = else: internalError(ctx.g.config, "lowerStmtListExpr(nkCaseStmt): " & $branch.kind) result.add(n) - result.add(ctx.newEnvVarAccess(tmp)) + if ctx.nimOptItersEnabled: + result.add(ctx.newTempVarAccess(tmp)) + else: + result.add(ctx.newEnvVarAccess(tmp)) elif n[0].kind == nkStmtListExpr: result = newNodeI(nkStmtList, n.info) let (st, ex) = exprToStmtList(n[0]) @@ -649,10 +703,14 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = result.add(st) cond = ex - let tmp = ctx.newTempVar(cond.typ) - result.add(ctx.newEnvVarAsgn(tmp, cond)) + let tmp = ctx.newTempVar(cond.typ, result, cond) + # result.add(ctx.newTempVarAsgn(tmp, cond)) - var check = ctx.newEnvVarAccess(tmp) + var check: PNode + if ctx.nimOptItersEnabled: + check = ctx.newTempVarAccess(tmp) + else: + check = ctx.newEnvVarAccess(tmp) if n[0].sym.magic == mOr: check = ctx.g.newNotCall(check) @@ -662,12 +720,18 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = let (st, ex) = exprToStmtList(cond) ifBody.add(st) cond = ex - ifBody.add(ctx.newEnvVarAsgn(tmp, cond)) + if ctx.nimOptItersEnabled: + ifBody.add(ctx.newTempVarAsgn(tmp, cond)) + else: + ifBody.add(ctx.newEnvVarAsgn(tmp, cond)) let ifBranch = newTree(nkElifBranch, check, ifBody) let ifNode = newTree(nkIfStmt, ifBranch) result.add(ifNode) - result.add(ctx.newEnvVarAccess(tmp)) + if ctx.nimOptItersEnabled: + result.add(ctx.newTempVarAccess(tmp)) + else: + result.add(ctx.newEnvVarAccess(tmp)) else: for i in 0..<n.len: if n[i].kind == nkStmtListExpr: @@ -676,9 +740,12 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = n[i] = ex if n[i].kind in nkCallKinds: # XXX: This should better be some sort of side effect tracking - let tmp = ctx.newTempVar(n[i].typ) - result.add(ctx.newEnvVarAsgn(tmp, n[i])) - n[i] = ctx.newEnvVarAccess(tmp) + let tmp = ctx.newTempVar(n[i].typ, result, n[i]) + # result.add(ctx.newTempVarAsgn(tmp, n[i])) + if ctx.nimOptItersEnabled: + n[i] = ctx.newTempVarAccess(tmp) + else: + n[i] = ctx.newEnvVarAccess(tmp) result.add(n) @@ -694,6 +761,12 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = let (st, ex) = exprToStmtList(c[^1]) result.add(st) c[^1] = ex + for i in 0 .. c.len - 3: + if c[i].kind == nkSym: + let s = c[i].sym + if sfForceLift in s.flags: + ctx.captureVar(s) + result.add(varSect) of nkDiscardStmt, nkReturnStmt, nkRaiseStmt: @@ -779,7 +852,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = let check = newTree(nkIfStmt, branch) let newBody = newTree(nkStmtList, st, check, n[1]) - n[0] = newSymNode(ctx.g.getSysSym(n[0].info, "true")) + n[0] = ctx.g.boolLit(n[0].info, true) n[1] = newBody of nkDotExpr, nkCheckedFieldExpr: @@ -831,7 +904,7 @@ proc newEndFinallyNode(ctx: var Ctx, info: TLineInfo): PNode = let retStmt = if ctx.nearestFinally == 0: # last finally, we can return - let retValue = if ctx.fn.typ[0].isNil: + let retValue = if ctx.fn.typ.returnType.isNil: ctx.g.emptyNode else: newTree(nkFastAsgn, @@ -1133,10 +1206,10 @@ proc skipEmptyStates(ctx: Ctx, stateIdx: int): int = let label = stateIdx if label == ctx.exitStateIdx: break var newLabel = label - if label == -1: + if label == emptyStateLabel: newLabel = ctx.exitStateIdx else: - let fs = skipStmtList(ctx, ctx.states[label][1]) + let fs = skipStmtList(ctx, ctx.states[label].body) if fs.kind == nkGotoState: newLabel = fs[0].intVal.int if label == newLabel: break @@ -1145,7 +1218,7 @@ proc skipEmptyStates(ctx: Ctx, stateIdx: int): int = if maxJumps == 0: assert(false, "Internal error") - result = ctx.states[stateIdx][0].intVal.int + result = ctx.states[stateIdx].label proc skipThroughEmptyStates(ctx: var Ctx, n: PNode): PNode= result = n @@ -1263,11 +1336,10 @@ proc wrapIntoTryExcept(ctx: var Ctx, n: PNode): PNode {.inline.} = proc wrapIntoStateLoop(ctx: var Ctx, n: PNode): PNode = # while true: # block :stateLoop: - # gotoState :state # local vars decl (if needed) # body # Might get wrapped in try-except let loopBody = newNodeI(nkStmtList, n.info) - result = newTree(nkWhileStmt, newSymNode(ctx.g.getSysSym(n.info, "true")), loopBody) + result = newTree(nkWhileStmt, ctx.g.boolLit(n.info, true), loopBody) result.info = n.info let localVars = newNodeI(nkStmtList, n.info) @@ -1282,11 +1354,7 @@ proc wrapIntoStateLoop(ctx: var Ctx, n: PNode): PNode = let blockStmt = newNodeI(nkBlockStmt, n.info) blockStmt.add(newSymNode(ctx.stateLoopLabel)) - let gs = newNodeI(nkGotoState, n.info) - gs.add(ctx.newStateAccess()) - gs.add(ctx.g.newIntLit(n.info, ctx.states.len - 1)) - - var blockBody = newTree(nkStmtList, gs, localVars, n) + var blockBody = newTree(nkStmtList, localVars, n) if ctx.hasExceptions: blockBody = ctx.wrapIntoTryExcept(blockBody) @@ -1299,29 +1367,28 @@ proc deleteEmptyStates(ctx: var Ctx) = # Apply new state indexes and mark unused states with -1 var iValid = 0 - for i, s in ctx.states: - let body = skipStmtList(ctx, s[1]) + for i, s in ctx.states.mpairs: + let body = skipStmtList(ctx, s.body) if body.kind == nkGotoState and i != ctx.states.len - 1 and i != 0: # This is an empty state. Mark with -1. - s[0].intVal = -1 + s.label = emptyStateLabel else: - s[0].intVal = iValid + s.label = iValid inc iValid for i, s in ctx.states: - let body = skipStmtList(ctx, s[1]) + let body = skipStmtList(ctx, s.body) if body.kind != nkGotoState or i == 0: - discard ctx.skipThroughEmptyStates(s) + discard ctx.skipThroughEmptyStates(s.body) let excHandlState = ctx.exceptionTable[i] if excHandlState < 0: ctx.exceptionTable[i] = -ctx.skipEmptyStates(-excHandlState) elif excHandlState != 0: ctx.exceptionTable[i] = ctx.skipEmptyStates(excHandlState) - var i = 0 + var i = 1 # ignore the entry and the exit while i < ctx.states.len - 1: - let fs = skipStmtList(ctx, ctx.states[i][1]) - if fs.kind == nkGotoState and i != 0: + if ctx.states[i].label == emptyStateLabel: ctx.states.delete(i) ctx.exceptionTable.delete(i) else: @@ -1430,18 +1497,77 @@ proc preprocess(c: var PreprocessContext; n: PNode): PNode = for i in 0 ..< n.len: result[i] = preprocess(c, n[i]) +proc detectCapturedVars(c: var Ctx, n: PNode, stateIdx: int) = + case n.kind + of nkSym: + let s = n.sym + if s.kind in {skResult, skVar, skLet, skForVar, skTemp} and sfGlobal notin s.flags and s.owner == c.fn: + let vs = c.varStates.getOrDefault(s.itemId, localNotSeen) + if vs == localNotSeen: # First seing this variable + c.varStates[s.itemId] = stateIdx + elif vs == localRequiresLifting: + discard # Sym already marked + elif vs != stateIdx: + c.captureVar(s) + of nkReturnStmt: + if n[0].kind in {nkAsgn, nkFastAsgn, nkSinkAsgn}: + # we have a `result = result` expression produced by the closure + # transform, let's not touch the LHS in order to make the lifting pass + # correct when `result` is lifted + detectCapturedVars(c, n[0][1], stateIdx) + else: + detectCapturedVars(c, n[0], stateIdx) + else: + for i in 0 ..< n.safeLen: + detectCapturedVars(c, n[i], stateIdx) + +proc detectCapturedVars(c: var Ctx) = + for i, s in c.states: + detectCapturedVars(c, s.body, i) + +proc liftLocals(c: var Ctx, n: PNode): PNode = + result = n + case n.kind + of nkSym: + let s = n.sym + if c.varStates.getOrDefault(s.itemId) == localRequiresLifting: + # lift + let e = getEnvParam(c.fn) + let field = getFieldFromObj(e.typ.elementType, s) + assert(field != nil) + result = rawIndirectAccess(newSymNode(e), field, n.info) + # elif c.varStates.getOrDefault(s.itemId, localNotSeen) != localNotSeen: + # echo "Not lifting ", s.name.s + + of nkReturnStmt: + if n[0].kind in {nkAsgn, nkFastAsgn, nkSinkAsgn}: + # we have a `result = result` expression produced by the closure + # transform, let's not touch the LHS in order to make the lifting pass + # correct when `result` is lifted + n[0][1] = liftLocals(c, n[0][1]) + else: + n[0] = liftLocals(c, n[0]) + else: + for i in 0 ..< n.safeLen: + n[i] = liftLocals(c, n[i]) + proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n: PNode): PNode = - var ctx: Ctx - ctx.g = g - ctx.fn = fn - ctx.idgen = idgen + var ctx = Ctx(g: g, fn: fn, idgen: idgen, + # should be default when issues are fixed, see #24094: + nimOptItersEnabled: isDefined(g.config, "nimOptIters")) if getEnvParam(fn).isNil: - # Lambda lifting was not done yet. Use temporary :state sym, which will - # be handled specially by lambda lifting. Local temp vars (if needed) - # should follow the same logic. - ctx.stateVarSym = newSym(skVar, getIdent(ctx.g.cache, ":state"), idgen, fn, fn.info) - ctx.stateVarSym.typ = g.createClosureIterStateType(fn, idgen) + if ctx.nimOptItersEnabled: + # The transformation should always happen after at least partial lambdalifting + # is performed, so that the closure iter environment is always created upfront. + doAssert(false, "Env param not created before iter transformation") + else: + # Lambda lifting was not done yet. Use temporary :state sym, which will + # be handled specially by lambda lifting. Local temp vars (if needed) + # should follow the same logic. + ctx.stateVarSym = newSym(skVar, getIdent(ctx.g.cache, ":state"), idgen, fn, fn.info) + ctx.stateVarSym.typ = g.createClosureIterStateType(fn, idgen) + ctx.stateLoopLabel = newSym(skLabel, getIdent(ctx.g.cache, ":stateLoop"), idgen, fn, fn.info) var pc = PreprocessContext(finallys: @[], config: g.config, idgen: idgen) var n = preprocess(pc, n.toStmtList) @@ -1463,17 +1589,23 @@ proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n: # Optimize empty states away ctx.deleteEmptyStates() - # Make new body by concatenating the list of states - result = newNodeI(nkStmtList, n.info) + let caseDispatcher = newTreeI(nkCaseStmt, n.info, + ctx.newStateAccess()) + + if ctx.nimOptItersEnabled: + # Lamdalifting will not touch our locals, it is our responsibility to lift those that + # need it. + detectCapturedVars(ctx) + for s in ctx.states: - assert(s.len == 2) - let body = s[1] - s.sons.del(1) - result.add(s) - result.add(body) + let body = ctx.transformStateAssignments(s.body) + caseDispatcher.add newTreeI(nkOfBranch, body.info, g.newIntLit(body.info, s.label), body) + + caseDispatcher.add newTreeI(nkElse, n.info, newTreeI(nkReturnStmt, n.info, g.emptyNode)) - result = ctx.transformStateAssignments(result) - result = ctx.wrapIntoStateLoop(result) + result = wrapIntoStateLoop(ctx, caseDispatcher) + if ctx.nimOptItersEnabled: + result = liftLocals(ctx, result) when false: echo "TRANSFORM TO STATES: " @@ -1482,3 +1614,5 @@ proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n: echo "exception table:" for i, e in ctx.exceptionTable: echo i, " -> ", e + + echo "ENV: ", renderTree(getEnvParam(fn).typ.elementType.n) |