diff options
Diffstat (limited to 'compiler/closureiters.nim')
-rw-r--r-- | compiler/closureiters.nim | 405 |
1 files changed, 293 insertions, 112 deletions
diff --git a/compiler/closureiters.nim b/compiler/closureiters.nim index e474e5ba2..8bdd04ca7 100644 --- a/compiler/closureiters.nim +++ b/compiler/closureiters.nim @@ -18,7 +18,8 @@ # dec a # # Should be transformed to: -# STATE0: +# case :state +# of 0: # if a > 0: # echo "hi" # :state = 1 # Next state @@ -26,19 +27,14 @@ # else: # :state = 2 # Next state # break :stateLoop # Proceed to the next state -# STATE1: +# of 1: # dec a # :state = 0 # Next state # break :stateLoop # Proceed to the next state -# STATE2: +# of 2: # :state = -1 # End of execution - -# The transformation should play well with lambdalifting, however depending -# on situation, it can be called either before or after lambdalifting -# transformation. As such we behave slightly differently, when accessing -# iterator state, or using temp variables. If lambdalifting did not happen, -# we just create local variables, so that they will be lifted further on. -# Otherwise, we utilize existing env, created by lambdalifting. +# else: +# return # Lambdalifting treats :state variable specially, it should always end up # as the first field in env. Currently C codegen depends on this behavior. @@ -104,12 +100,13 @@ # Is transformed to (yields are left in place for example simplicity, # in reality the code is subdivided even more, as described above): # -# STATE0: # Try +# case :state +# of 0: # Try # yield 0 # raise ... # :state = 2 # What would happen should we not raise # break :stateLoop -# STATE1: # Except +# of 1: # Except # yield 1 # :tmpResult = 3 # Return # :unrollFinally = true # Return @@ -117,32 +114,41 @@ # break :stateLoop # :state = 2 # What would happen should we not return # break :stateLoop -# STATE2: # Finally +# of 2: # Finally # yield 2 # if :unrollFinally: # This node is created by `newEndFinallyNode` # if :curExc.isNil: -# return :tmpResult +# if nearestFinally == 0: +# return :tmpResult +# else: +# :state = nearestFinally # bubble up # else: # closureIterSetupExc(nil) # raise # state = -1 # Goto next state. In this case we just exit # break :stateLoop +# else: +# return import ast, msgs, idents, renderer, magicsys, lowerings, lambdalifting, modulegraphs, lineinfos, - tables, options + options + +import std/tables + +when defined(nimPreviewSlimSystem): + import std/assertions type Ctx = object g: ModuleGraph fn: PSym - stateVarSym: PSym # :state variable. nil if env already introduced by lambdalifting tmpResultSym: PSym # Used when we return, but finally has to interfere unrollFinallySym: PSym # Indicates that we're unrolling finally states (either exception happened or premature return) curExcSym: PSym # Current exception - states: seq[PNode] # The resulting states. Every state is an nkState node. + states: seq[tuple[label: int, body: PNode]] # The resulting states. blockLevel: int # Temp used to transform break and continue stmts stateLoopLabel: PSym # Label to break on, when jumping between states. exitStateIdx: int # index of the last state @@ -154,15 +160,23 @@ type nearestFinally: int # Index of the nearest finally block. For try/except it # is their finally. For finally it is parent finally. Otherwise -1 idgen: IdGenerator + varStates: Table[ItemId, int] # Used to detect if local variable belongs to multiple states + stateVarSym: PSym # :state variable. nil if env already introduced by lambdalifting + # remove if -d:nimOptIters is default, treating it as always nil + nimOptItersEnabled: bool # tracks if -d:nimOptIters is enabled + # should be default when issues are fixed, see #24094 const nkSkip = {nkEmpty..nkNilLit, nkTemplateDef, nkTypeSection, nkStaticStmt, nkCommentStmt, nkMixinStmt, nkBindStmt} + procDefs + emptyStateLabel = -1 + localNotSeen = -1 + localRequiresLifting = -2 proc newStateAccess(ctx: var Ctx): PNode = if ctx.stateVarSym.isNil: result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), - getStateField(ctx.g, ctx.fn), ctx.fn.info) + getStateField(ctx.g, ctx.fn), ctx.fn.info) else: result = newSymNode(ctx.stateVarSym) @@ -177,9 +191,10 @@ proc newStateAssgn(ctx: var Ctx, stateNo: int = -2): PNode = ctx.newStateAssgn(newIntTypeNode(stateNo, ctx.g.getSysType(TLineInfo(), tyInt))) proc newEnvVar(ctx: var Ctx, name: string, typ: PType): PSym = - result = newSym(skVar, getIdent(ctx.g.cache, name), nextSymId(ctx.idgen), ctx.fn, ctx.fn.info) + result = newSym(skVar, getIdent(ctx.g.cache, name), ctx.idgen, ctx.fn, ctx.fn.info) result.typ = typ - assert(not typ.isNil) + result.flags.incl sfNoInit + assert(not typ.isNil, "Env var needs a type") if not ctx.stateVarSym.isNil: # We haven't gone through labmda lifting yet, so just create a local var, @@ -190,7 +205,7 @@ proc newEnvVar(ctx: var Ctx, name: string, typ: PType): PSym = else: let envParam = getEnvParam(ctx.fn) # let obj = envParam.typ.lastSon - result = addUniqueField(envParam.typ.lastSon, result, ctx.g.cache, ctx.idgen) + result = addUniqueField(envParam.typ.elementType, result, ctx.g.cache, ctx.idgen) proc newEnvVarAccess(ctx: Ctx, s: PSym): PNode = if ctx.stateVarSym.isNil: @@ -198,9 +213,12 @@ proc newEnvVarAccess(ctx: Ctx, s: PSym): PNode = else: result = newSymNode(s) +proc newTempVarAccess(ctx: Ctx, s: PSym): PNode = + result = newSymNode(s, ctx.fn.info) + proc newTmpResultAccess(ctx: var Ctx): PNode = if ctx.tmpResultSym.isNil: - ctx.tmpResultSym = ctx.newEnvVar(":tmpResult", ctx.fn.typ[0]) + ctx.tmpResultSym = ctx.newEnvVar(":tmpResult", ctx.fn.typ.returnType) ctx.newEnvVarAccess(ctx.tmpResultSym) proc newUnrollFinallyAccess(ctx: var Ctx, info: TLineInfo): PNode = @@ -220,10 +238,7 @@ proc newState(ctx: var Ctx, n, gotoOut: PNode): int = result = ctx.states.len let resLit = ctx.g.newIntLit(n.info, result) - let s = newNodeI(nkState, n.info) - s.add(resLit) - s.add(n) - ctx.states.add(s) + ctx.states.add((result, n)) ctx.exceptionTable.add(ctx.curExcHandlingState) if not gotoOut.isNil: @@ -242,9 +257,26 @@ proc addGotoOut(n: PNode, gotoOut: PNode): PNode = if result.len == 0 or result[^1].kind != nkGotoState: result.add(gotoOut) -proc newTempVar(ctx: var Ctx, typ: PType): PSym = - result = ctx.newEnvVar(":tmpSlLower" & $ctx.tempVarId, typ) +proc newTempVarDef(ctx: Ctx, s: PSym, initialValue: PNode): PNode = + var v = initialValue + if v == nil: + v = ctx.g.emptyNode + newTree(nkVarSection, newTree(nkIdentDefs, newSymNode(s), ctx.g.emptyNode, v)) + +proc newEnvVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode + +proc newTempVar(ctx: var Ctx, typ: PType, parent: PNode, initialValue: PNode = nil): PSym = + if ctx.nimOptItersEnabled: + result = newSym(skVar, getIdent(ctx.g.cache, ":tmpSlLower" & $ctx.tempVarId), ctx.idgen, ctx.fn, ctx.fn.info) + else: + result = ctx.newEnvVar(":tmpSlLower" & $ctx.tempVarId, typ) inc ctx.tempVarId + result.typ = typ + assert(not typ.isNil, "Temp var needs a type") + if ctx.nimOptItersEnabled: + parent.add(ctx.newTempVarDef(result, initialValue)) + elif initialValue != nil: + parent.add(ctx.newEnvVarAsgn(result, initialValue)) proc hasYields(n: PNode): bool = # TODO: This is very inefficient. It traverses the node, looking for nkYieldStmt. @@ -252,10 +284,11 @@ proc hasYields(n: PNode): bool = of nkYieldStmt: result = true of nkSkip: - discard + result = false else: - for c in n: - if c.hasYields: + result = false + for i in ord(n.kind == nkCast)..<n.len: + if n[i].hasYields: result = true break @@ -319,7 +352,7 @@ proc collectExceptState(ctx: var Ctx, n: PNode): PNode {.inline.} = var ifBranch: PNode if c.len > 1: - var cond: PNode + var cond: PNode = nil for i in 0..<c.len - 1: assert(c[i].kind == nkType) let nextCond = newTree(nkCall, @@ -382,26 +415,29 @@ proc getFinallyNode(ctx: var Ctx, n: PNode): PNode = proc hasYieldsInExpressions(n: PNode): bool = case n.kind of nkSkip: - discard + result = false of nkStmtListExpr: if isEmptyType(n.typ): + result = false for c in n: if c.hasYieldsInExpressions: return true else: result = n.hasYields of nkCast: + result = false for i in 1..<n.len: if n[i].hasYieldsInExpressions: return true else: + result = false for c in n: if c.hasYieldsInExpressions: return true proc exprToStmtList(n: PNode): tuple[s, res: PNode] = assert(n.kind == nkStmtListExpr) - result.s = newNodeI(nkStmtList, n.info) + result = (newNodeI(nkStmtList, n.info), nil) result.s.sons = @[] var n = n @@ -412,28 +448,50 @@ proc exprToStmtList(n: PNode): tuple[s, res: PNode] = result.res = n +proc newTempVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode = + if isEmptyType(v.typ): + result = v + else: + result = newTree(nkFastAsgn, ctx.newTempVarAccess(s), v) + result.info = v.info proc newEnvVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode = - result = newTree(nkFastAsgn, ctx.newEnvVarAccess(s), v) - result.info = v.info + # unused with -d:nimOptIters + if isEmptyType(v.typ): + result = v + else: + result = newTree(nkFastAsgn, ctx.newEnvVarAccess(s), v) + result.info = v.info proc addExprAssgn(ctx: Ctx, output, input: PNode, sym: PSym) = + var input = input if input.kind == nkStmtListExpr: let (st, res) = exprToStmtList(input) output.add(st) - output.add(ctx.newEnvVarAsgn(sym, res)) + input = res + if ctx.nimOptItersEnabled: + output.add(ctx.newTempVarAsgn(sym, input)) else: output.add(ctx.newEnvVarAsgn(sym, input)) proc convertExprBodyToAsgn(ctx: Ctx, exprBody: PNode, res: PSym): PNode = result = newNodeI(nkStmtList, exprBody.info) - if exprBody.typ != nil: - ctx.addExprAssgn(result, exprBody, res) + ctx.addExprAssgn(result, exprBody, res) proc newNotCall(g: ModuleGraph; e: PNode): PNode = result = newTree(nkCall, newSymNode(g.getSysMagic(e.info, "not", mNot), e.info), e) result.typ = g.getSysType(e.info, tyBool) +proc boolLit(g: ModuleGraph; info: TLineInfo; value: bool): PNode = + result = newIntLit(g, info, ord value) + result.typ = getSysType(g, info, tyBool) + +proc captureVar(c: var Ctx, s: PSym) = + if c.varStates.getOrDefault(s.itemId) != localRequiresLifting: + c.varStates[s.itemId] = localRequiresLifting # Mark this variable for lifting + let e = getEnvParam(c.fn) + discard addField(e.typ.elementType, s, c.g.cache, c.idgen) + proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = result = n case n.kind @@ -487,12 +545,12 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = if ns: needsSplit = true - var tmp: PSym + var tmp: PSym = nil let isExpr = not isEmptyType(n.typ) if isExpr: - tmp = ctx.newTempVar(n.typ) result = newNodeI(nkStmtListExpr, n.info) result.typ = n.typ + tmp = ctx.newTempVar(n.typ, result) else: result = newNodeI(nkStmtList, n.info) @@ -543,7 +601,11 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = else: internalError(ctx.g.config, "lowerStmtListExpr(nkIf): " & $branch.kind) - if isExpr: result.add(ctx.newEnvVarAccess(tmp)) + if isExpr: + if ctx.nimOptItersEnabled: + result.add(ctx.newTempVarAccess(tmp)) + else: + result.add(ctx.newEnvVarAccess(tmp)) of nkTryStmt, nkHiddenTryStmt: var ns = false @@ -557,7 +619,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = if isExpr: result = newNodeI(nkStmtListExpr, n.info) result.typ = n.typ - let tmp = ctx.newTempVar(n.typ) + let tmp = ctx.newTempVar(n.typ, result) n[0] = ctx.convertExprBodyToAsgn(n[0], tmp) for i in 1..<n.len: @@ -573,7 +635,10 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = else: internalError(ctx.g.config, "lowerStmtListExpr(nkTryStmt): " & $branch.kind) result.add(n) - result.add(ctx.newEnvVarAccess(tmp)) + if ctx.nimOptItersEnabled: + result.add(ctx.newTempVarAccess(tmp)) + else: + result.add(ctx.newEnvVarAccess(tmp)) of nkCaseStmt: var ns = false @@ -586,9 +651,9 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = let isExpr = not isEmptyType(n.typ) if isExpr: - let tmp = ctx.newTempVar(n.typ) result = newNodeI(nkStmtListExpr, n.info) result.typ = n.typ + let tmp = ctx.newTempVar(n.typ, result) if n[0].kind == nkStmtListExpr: let (st, ex) = exprToStmtList(n[0]) @@ -605,7 +670,16 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = else: internalError(ctx.g.config, "lowerStmtListExpr(nkCaseStmt): " & $branch.kind) result.add(n) - result.add(ctx.newEnvVarAccess(tmp)) + if ctx.nimOptItersEnabled: + result.add(ctx.newTempVarAccess(tmp)) + else: + result.add(ctx.newEnvVarAccess(tmp)) + elif n[0].kind == nkStmtListExpr: + result = newNodeI(nkStmtList, n.info) + let (st, ex) = exprToStmtList(n[0]) + result.add(st) + n[0] = ex + result.add(n) of nkCallKinds, nkChckRange, nkChckRangeF, nkChckRange64: var ns = false @@ -629,10 +703,14 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = result.add(st) cond = ex - let tmp = ctx.newTempVar(cond.typ) - result.add(ctx.newEnvVarAsgn(tmp, cond)) + let tmp = ctx.newTempVar(cond.typ, result, cond) + # result.add(ctx.newTempVarAsgn(tmp, cond)) - var check = ctx.newEnvVarAccess(tmp) + var check: PNode + if ctx.nimOptItersEnabled: + check = ctx.newTempVarAccess(tmp) + else: + check = ctx.newEnvVarAccess(tmp) if n[0].sym.magic == mOr: check = ctx.g.newNotCall(check) @@ -642,12 +720,18 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = let (st, ex) = exprToStmtList(cond) ifBody.add(st) cond = ex - ifBody.add(ctx.newEnvVarAsgn(tmp, cond)) + if ctx.nimOptItersEnabled: + ifBody.add(ctx.newTempVarAsgn(tmp, cond)) + else: + ifBody.add(ctx.newEnvVarAsgn(tmp, cond)) let ifBranch = newTree(nkElifBranch, check, ifBody) let ifNode = newTree(nkIfStmt, ifBranch) result.add(ifNode) - result.add(ctx.newEnvVarAccess(tmp)) + if ctx.nimOptItersEnabled: + result.add(ctx.newTempVarAccess(tmp)) + else: + result.add(ctx.newEnvVarAccess(tmp)) else: for i in 0..<n.len: if n[i].kind == nkStmtListExpr: @@ -656,9 +740,12 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = n[i] = ex if n[i].kind in nkCallKinds: # XXX: This should better be some sort of side effect tracking - let tmp = ctx.newTempVar(n[i].typ) - result.add(ctx.newEnvVarAsgn(tmp, n[i])) - n[i] = ctx.newEnvVarAccess(tmp) + let tmp = ctx.newTempVar(n[i].typ, result, n[i]) + # result.add(ctx.newTempVarAsgn(tmp, n[i])) + if ctx.nimOptItersEnabled: + n[i] = ctx.newTempVarAccess(tmp) + else: + n[i] = ctx.newEnvVarAccess(tmp) result.add(n) @@ -674,6 +761,12 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = let (st, ex) = exprToStmtList(c[^1]) result.add(st) c[^1] = ex + for i in 0 .. c.len - 3: + if c[i].kind == nkSym: + let s = c[i].sym + if sfForceLift in s.flags: + ctx.captureVar(s) + result.add(varSect) of nkDiscardStmt, nkReturnStmt, nkRaiseStmt: @@ -704,7 +797,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = n[^1] = ex result.add(n) - of nkAsgn, nkFastAsgn: + of nkAsgn, nkFastAsgn, nkSinkAsgn: var ns = false for i in 0..<n.len: n[i] = ctx.lowerStmtListExprs(n[i], ns) @@ -759,7 +852,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = let check = newTree(nkIfStmt, branch) let newBody = newTree(nkStmtList, st, check, n[1]) - n[0] = newSymNode(ctx.g.getSysSym(n[0].info, "true")) + n[0] = ctx.g.boolLit(n[0].info, true) n[1] = newBody of nkDotExpr, nkCheckedFieldExpr: @@ -796,7 +889,10 @@ proc newEndFinallyNode(ctx: var Ctx, info: TLineInfo): PNode = # Generate the following code: # if :unrollFinally: # if :curExc.isNil: - # return :tmpResult + # if nearestFinally == 0: + # return :tmpResult + # else: + # :state = nearestFinally # bubble up # else: # raise let curExc = ctx.newCurExcAccess() @@ -805,11 +901,20 @@ proc newEndFinallyNode(ctx: var Ctx, info: TLineInfo): PNode = let cmp = newTree(nkCall, newSymNode(ctx.g.getSysMagic(info, "==", mEqRef), info), curExc, nilnode) cmp.typ = ctx.g.getSysType(info, tyBool) - let asgn = newTree(nkFastAsgn, - newSymNode(getClosureIterResult(ctx.g, ctx.fn, ctx.idgen), info), - ctx.newTmpResultAccess()) + let retStmt = + if ctx.nearestFinally == 0: + # last finally, we can return + let retValue = if ctx.fn.typ.returnType.isNil: + ctx.g.emptyNode + else: + newTree(nkFastAsgn, + newSymNode(getClosureIterResult(ctx.g, ctx.fn, ctx.idgen), info), + ctx.newTmpResultAccess()) + newTree(nkReturnStmt, retValue) + else: + # bubble up to next finally + newTree(nkGotoState, ctx.g.newIntLit(info, ctx.nearestFinally)) - let retStmt = newTree(nkReturnStmt, asgn) let branch = newTree(nkElifBranch, cmp, retStmt) let nullifyExc = newTree(nkCall, newSymNode(ctx.g.getCompilerProc("closureIterSetupExc")), nilnode) @@ -829,7 +934,9 @@ proc transformReturnsInTry(ctx: var Ctx, n: PNode): PNode = case n.kind of nkReturnStmt: # We're somewhere in try, transform to finally unrolling - assert(ctx.nearestFinally != 0) + if ctx.nearestFinally == 0: + # return is within the finally + return result = newNodeI(nkStmtList, n.info) @@ -842,7 +949,7 @@ proc transformReturnsInTry(ctx: var Ctx, n: PNode): PNode = if n[0].kind != nkEmpty: let asgnTmpResult = newNodeI(nkAsgn, n.info) asgnTmpResult.add(ctx.newTmpResultAccess()) - let x = if n[0].kind in {nkAsgn, nkFastAsgn}: n[0][1] else: n[0] + let x = if n[0].kind in {nkAsgn, nkFastAsgn, nkSinkAsgn}: n[0][1] else: n[0] asgnTmpResult.add(x) result.add(asgnTmpResult) @@ -853,6 +960,13 @@ proc transformReturnsInTry(ctx: var Ctx, n: PNode): PNode = of nkSkip: discard + of nkTryStmt: + if n.hasYields: + # the inner try will handle these transformations + discard + else: + for i in 0..<n.len: + n[i] = ctx.transformReturnsInTry(n[i]) else: for i in 0..<n.len: n[i] = ctx.transformReturnsInTry(n[i]) @@ -1092,10 +1206,10 @@ proc skipEmptyStates(ctx: Ctx, stateIdx: int): int = let label = stateIdx if label == ctx.exitStateIdx: break var newLabel = label - if label == -1: + if label == emptyStateLabel: newLabel = ctx.exitStateIdx else: - let fs = skipStmtList(ctx, ctx.states[label][1]) + let fs = skipStmtList(ctx, ctx.states[label].body) if fs.kind == nkGotoState: newLabel = fs[0].intVal.int if label == newLabel: break @@ -1104,7 +1218,7 @@ proc skipEmptyStates(ctx: Ctx, stateIdx: int): int = if maxJumps == 0: assert(false, "Internal error") - result = ctx.states[stateIdx][0].intVal.int + result = ctx.states[stateIdx].label proc skipThroughEmptyStates(ctx: var Ctx, n: PNode): PNode= result = n @@ -1119,10 +1233,10 @@ proc skipThroughEmptyStates(ctx: var Ctx, n: PNode): PNode= n[i] = ctx.skipThroughEmptyStates(n[i]) proc newArrayType(g: ModuleGraph; n: int, t: PType; idgen: IdGenerator; owner: PSym): PType = - result = newType(tyArray, nextTypeId(idgen), owner) + result = newType(tyArray, idgen, owner) - let rng = newType(tyRange, nextTypeId(idgen), owner) - rng.n = newTree(nkRange, g.newIntLit(owner.info, 0), g.newIntLit(owner.info, n)) + let rng = newType(tyRange, idgen, owner) + rng.n = newTree(nkRange, g.newIntLit(owner.info, 0), g.newIntLit(owner.info, n - 1)) rng.rawAddSon(t) result.rawAddSon(rng) @@ -1222,11 +1336,10 @@ proc wrapIntoTryExcept(ctx: var Ctx, n: PNode): PNode {.inline.} = proc wrapIntoStateLoop(ctx: var Ctx, n: PNode): PNode = # while true: # block :stateLoop: - # gotoState :state # local vars decl (if needed) # body # Might get wrapped in try-except let loopBody = newNodeI(nkStmtList, n.info) - result = newTree(nkWhileStmt, newSymNode(ctx.g.getSysSym(n.info, "true")), loopBody) + result = newTree(nkWhileStmt, ctx.g.boolLit(n.info, true), loopBody) result.info = n.info let localVars = newNodeI(nkStmtList, n.info) @@ -1241,11 +1354,7 @@ proc wrapIntoStateLoop(ctx: var Ctx, n: PNode): PNode = let blockStmt = newNodeI(nkBlockStmt, n.info) blockStmt.add(newSymNode(ctx.stateLoopLabel)) - let gs = newNodeI(nkGotoState, n.info) - gs.add(ctx.newStateAccess()) - gs.add(ctx.g.newIntLit(n.info, ctx.states.len - 1)) - - var blockBody = newTree(nkStmtList, gs, localVars, n) + var blockBody = newTree(nkStmtList, localVars, n) if ctx.hasExceptions: blockBody = ctx.wrapIntoTryExcept(blockBody) @@ -1258,29 +1367,28 @@ proc deleteEmptyStates(ctx: var Ctx) = # Apply new state indexes and mark unused states with -1 var iValid = 0 - for i, s in ctx.states: - let body = skipStmtList(ctx, s[1]) + for i, s in ctx.states.mpairs: + let body = skipStmtList(ctx, s.body) if body.kind == nkGotoState and i != ctx.states.len - 1 and i != 0: # This is an empty state. Mark with -1. - s[0].intVal = -1 + s.label = emptyStateLabel else: - s[0].intVal = iValid + s.label = iValid inc iValid for i, s in ctx.states: - let body = skipStmtList(ctx, s[1]) + let body = skipStmtList(ctx, s.body) if body.kind != nkGotoState or i == 0: - discard ctx.skipThroughEmptyStates(s) + discard ctx.skipThroughEmptyStates(s.body) let excHandlState = ctx.exceptionTable[i] if excHandlState < 0: ctx.exceptionTable[i] = -ctx.skipEmptyStates(-excHandlState) elif excHandlState != 0: ctx.exceptionTable[i] = ctx.skipEmptyStates(excHandlState) - var i = 0 + var i = 1 # ignore the entry and the exit while i < ctx.states.len - 1: - let fs = skipStmtList(ctx, ctx.states[i][1]) - if fs.kind == nkGotoState and i != 0: + if ctx.states[i].label == emptyStateLabel: ctx.states.delete(i) ctx.exceptionTable.delete(i) else: @@ -1315,7 +1423,7 @@ proc freshVars(n: PNode; c: var FreshVarsContext): PNode = let idefs = copyNode(it) for v in 0..it.len-3: if it[v].kind == nkSym: - let x = copySym(it[v].sym, nextSymId(c.idgen)) + let x = copySym(it[v].sym, c.idgen) c.tab[it[v].sym.id] = x idefs.add newSymNode(x) else: @@ -1326,6 +1434,7 @@ proc freshVars(n: PNode; c: var FreshVarsContext): PNode = else: result.add it of nkRaiseStmt: + result = nil localError(c.config, c.info, "unsupported control flow: 'finally: ... raise' duplicated because of 'break'") else: result = n @@ -1339,20 +1448,24 @@ proc preprocess(c: var PreprocessContext; n: PNode): PNode = # detect: 'finally: raises X' which is currently not supported. We produce # an error for this case for now. All this will be done properly with Yuriy's # patch. + result = n case n.kind of nkTryStmt: let f = n.lastSon + var didAddSomething = false if f.kind == nkFinally: c.finallys.add f.lastSon + didAddSomething = true for i in 0 ..< n.len: result[i] = preprocess(c, n[i]) - if f.kind == nkFinally: + if didAddSomething: discard c.finallys.pop() of nkWhileStmt, nkBlockStmt: + if not n.hasYields: return n c.blocks.add((n, c.finallys.len)) for i in 0 ..< n.len: result[i] = preprocess(c, n[i]) @@ -1376,7 +1489,7 @@ proc preprocess(c: var PreprocessContext; n: PNode): PNode = result = newNodeI(nkStmtList, n.info) for i in countdown(c.finallys.high, fin): var vars = FreshVarsContext(tab: initTable[int, PSym](), config: c.config, info: n.info, idgen: c.idgen) - result.add freshVars(preprocess(c, c.finallys[i]), vars) + result.add freshVars(copyTree(c.finallys[i]), vars) c.idgen = vars.idgen result.add n of nkSkip: discard @@ -1384,19 +1497,78 @@ proc preprocess(c: var PreprocessContext; n: PNode): PNode = for i in 0 ..< n.len: result[i] = preprocess(c, n[i]) +proc detectCapturedVars(c: var Ctx, n: PNode, stateIdx: int) = + case n.kind + of nkSym: + let s = n.sym + if s.kind in {skResult, skVar, skLet, skForVar, skTemp} and sfGlobal notin s.flags and s.owner == c.fn: + let vs = c.varStates.getOrDefault(s.itemId, localNotSeen) + if vs == localNotSeen: # First seing this variable + c.varStates[s.itemId] = stateIdx + elif vs == localRequiresLifting: + discard # Sym already marked + elif vs != stateIdx: + c.captureVar(s) + of nkReturnStmt: + if n[0].kind in {nkAsgn, nkFastAsgn, nkSinkAsgn}: + # we have a `result = result` expression produced by the closure + # transform, let's not touch the LHS in order to make the lifting pass + # correct when `result` is lifted + detectCapturedVars(c, n[0][1], stateIdx) + else: + detectCapturedVars(c, n[0], stateIdx) + else: + for i in 0 ..< n.safeLen: + detectCapturedVars(c, n[i], stateIdx) + +proc detectCapturedVars(c: var Ctx) = + for i, s in c.states: + detectCapturedVars(c, s.body, i) + +proc liftLocals(c: var Ctx, n: PNode): PNode = + result = n + case n.kind + of nkSym: + let s = n.sym + if c.varStates.getOrDefault(s.itemId) == localRequiresLifting: + # lift + let e = getEnvParam(c.fn) + let field = getFieldFromObj(e.typ.elementType, s) + assert(field != nil) + result = rawIndirectAccess(newSymNode(e), field, n.info) + # elif c.varStates.getOrDefault(s.itemId, localNotSeen) != localNotSeen: + # echo "Not lifting ", s.name.s + + of nkReturnStmt: + if n[0].kind in {nkAsgn, nkFastAsgn, nkSinkAsgn}: + # we have a `result = result` expression produced by the closure + # transform, let's not touch the LHS in order to make the lifting pass + # correct when `result` is lifted + n[0][1] = liftLocals(c, n[0][1]) + else: + n[0] = liftLocals(c, n[0]) + else: + for i in 0 ..< n.safeLen: + n[i] = liftLocals(c, n[i]) + proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n: PNode): PNode = - var ctx: Ctx - ctx.g = g - ctx.fn = fn - ctx.idgen = idgen + var ctx = Ctx(g: g, fn: fn, idgen: idgen, + # should be default when issues are fixed, see #24094: + nimOptItersEnabled: isDefined(g.config, "nimOptIters")) if getEnvParam(fn).isNil: - # Lambda lifting was not done yet. Use temporary :state sym, which will - # be handled specially by lambda lifting. Local temp vars (if needed) - # should follow the same logic. - ctx.stateVarSym = newSym(skVar, getIdent(ctx.g.cache, ":state"), nextSymId(idgen), fn, fn.info) - ctx.stateVarSym.typ = g.createClosureIterStateType(fn, idgen) - ctx.stateLoopLabel = newSym(skLabel, getIdent(ctx.g.cache, ":stateLoop"), nextSymId(idgen), fn, fn.info) + if ctx.nimOptItersEnabled: + # The transformation should always happen after at least partial lambdalifting + # is performed, so that the closure iter environment is always created upfront. + doAssert(false, "Env param not created before iter transformation") + else: + # Lambda lifting was not done yet. Use temporary :state sym, which will + # be handled specially by lambda lifting. Local temp vars (if needed) + # should follow the same logic. + ctx.stateVarSym = newSym(skVar, getIdent(ctx.g.cache, ":state"), idgen, fn, fn.info) + ctx.stateVarSym.typ = g.createClosureIterStateType(fn, idgen) + + ctx.stateLoopLabel = newSym(skLabel, getIdent(ctx.g.cache, ":stateLoop"), idgen, fn, fn.info) var pc = PreprocessContext(finallys: @[], config: g.config, idgen: idgen) var n = preprocess(pc, n.toStmtList) #echo "transformed into ", n @@ -1417,21 +1589,30 @@ proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n: # Optimize empty states away ctx.deleteEmptyStates() - # Make new body by concatenating the list of states - result = newNodeI(nkStmtList, n.info) + let caseDispatcher = newTreeI(nkCaseStmt, n.info, + ctx.newStateAccess()) + + if ctx.nimOptItersEnabled: + # Lamdalifting will not touch our locals, it is our responsibility to lift those that + # need it. + detectCapturedVars(ctx) + for s in ctx.states: - assert(s.len == 2) - let body = s[1] - s.sons.del(1) - result.add(s) - result.add(body) + let body = ctx.transformStateAssignments(s.body) + caseDispatcher.add newTreeI(nkOfBranch, body.info, g.newIntLit(body.info, s.label), body) + + caseDispatcher.add newTreeI(nkElse, n.info, newTreeI(nkReturnStmt, n.info, g.emptyNode)) + + result = wrapIntoStateLoop(ctx, caseDispatcher) + if ctx.nimOptItersEnabled: + result = liftLocals(ctx, result) - result = ctx.transformStateAssignments(result) - result = ctx.wrapIntoStateLoop(result) + when false: + echo "TRANSFORM TO STATES: " + echo renderTree(result) - # echo "TRANSFORM TO STATES: " - # echo renderTree(result) + echo "exception table:" + for i, e in ctx.exceptionTable: + echo i, " -> ", e - # echo "exception table:" - # for i, e in ctx.exceptionTable: - # echo i, " -> ", e + echo "ENV: ", renderTree(getEnvParam(fn).typ.elementType.n) |