# # # The Nim Compiler # (c) Copyright 2018 Nim Contributors # # See the file "copying.txt", included in this # distribution, for details about the copyright. # # This file implements closure iterator transformations. # The main idea is to split the closure iterator body to top level statements. # The body is split by yield statement. # # Example: # while a > 0: # echo "hi" # yield a # dec a # # Should be transformed to: # STATE0: # if a > 0: # echo "hi" # :state = 1 # Next state # return a # yield # else: # :state = 2 # Next state # break :stateLoop # Proceed to the next state # STATE1: # dec a # :state = 0 # Next state # break :stateLoop # Proceed to the next state # STATE2: # :state = -1 # End of execution # The transformation should play well with lambdalifting, however depending # on situation, it can be called either before or after lambdalifting # transformation. As such we behave slightly differently, when accessing # iterator state, or using temp variables. If lambdalifting did not happen, # we just create local variables, so that they will be lifted further on. # Otherwise, we utilize existing env, created by lambdalifting. # Lambdalifting treats :state variable specially, it should always end up # as the first field in env. Currently C codegen depends on this behavior. # One special subtransformation is nkStmtListExpr lowering. # Example: # template foo(): int = # yield 1 # 2 # # iterator it(): int {.closure.} = # if foo() == 2: # yield 3 # # If a nkStmtListExpr has yield inside, it has first to be lowered to: # yield 1 # :tmpSlLower = 2 # if :tmpSlLower == 2: # yield 3 # nkTryStmt Transformations: # If the iter has an nkTryStmt with a yield inside # - the closure iter is promoted to have exceptions (ctx.hasExceptions = true) # - exception table is created. This is a const array, where # `abs(exceptionTable[i])` is a state idx to which we should jump from state # `i` should exception be raised in state `i`. For all states in `try` block # the target state is `except` block. For all states in `except` block # the target state is `finally` block. For all other states there is no # target state (0, as the first block can never be neither except nor finally). # `exceptionTable[i]` is < 0 if `abs(exceptionTable[i])` is except block, # and > 0, for finally block. # - local variable :curExc is created # - the iter body is wrapped into a # try: # closureIterSetupExc(:curExc) # ...body... # catch: # :state = exceptionTable[:state] # if :state == 0: raise # No state that could handle exception # :unrollFinally = :state > 0 # Target state is finally # if :state < 0: # :state = -:state # :curExc = getCurrentException() # # nkReturnStmt within a try/except/finally now has to behave differently as we # want the nearest finally block to be executed before the return, thus it is # transformed to: # :tmpResult = returnValue (if return doesn't have a value, this is skipped) # :unrollFinally = true # goto nearestFinally (or -1 if not exists) # # Example: # # try: # yield 0 # raise ... # except: # yield 1 # return 3 # finally: # yield 2 # # Is transformed to (yields are left in place for example simplicity, # in reality the code is subdivided even more, as described above): # # STATE0: # Try # yield 0 # raise ... # :state = 2 # What would happen should we not raise # break :stateLoop # STATE1: # Except # yield 1 # :tmpResult = 3 # Return # :unrollFinally = true # Return # :state = 2 # Goto Finally # break :stateLoop # :state = 2 # What would happen should we not return # break :stateLoop # STATE2: # Finally # yield 2 # if :unrollFinally: # This node is created by `newEndFinallyNode` # if :curExc.isNil: # return :tmpResult # else: # closureIterSetupExc(nil) # raise # state = -1 # Goto next state. In this case we just exit # break :stateLoop import ast, msgs, idents, renderer, magicsys, lowerings, lambdalifting, modulegraphs, lineinfos type Ctx = object g: ModuleGraph fn: PSym stateVarSym: PSym # :state variable. nil if env already introduced by lambdalifting tmpResultSym: PSym # Used when we return, but finally has to interfere unrollFinallySym: PSym # Indicates that we're unrolling finally states (either exception happened or premature return) curExcSym: PSym # Current exception states: seq[PNode] # The resulting states. Every state is an nkState node. blockLevel: int # Temp used to transform break and continue stmts stateLoopLabel: PSym # Label to break on, when jumping between states. exitStateIdx: int # index of the last state tempVarId: int # unique name counter tempVars: PNode # Temp var decls, nkVarSection exceptionTable: seq[int] # For state `i` jump to state `exceptionTable[i]` if exception is raised hasExceptions: bool # Does closure have yield in try? curExcHandlingState: int # Negative for except, positive for finally nearestFinally: int # Index of the nearest finally block. For try/except it # is their finally. For finally it is parent finally. Otherwise -1 const nkSkip = { nkEmpty..nkNilLit, nkTemplateDef, nkTypeSection, nkStaticStmt, nkCommentStmt } + procDefs proc newStateAccess(ctx: var Ctx): PNode = if ctx.stateVarSym.isNil: result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), getStateField(ctx.g, ctx.fn), ctx.fn.info) else: result = newSymNode(ctx.stateVarSym) proc newStateAssgn(ctx: var Ctx, toValue: PNode): PNode = # Creates state assignment: # :state = toValue newTree(nkAsgn, ctx.newStateAccess(), toValue) proc newStateAssgn(ctx: var Ctx, stateNo: int = -2): PNode = # Creates state assignment: # :state = stateNo ctx.newStateAssgn(newIntTypeNode(stateNo, ctx.g.getSysType(TLineInfo(), tyInt))) proc newEnvVar(ctx: var Ctx, name: string, typ: PType): PSym = result = newSym(skVar, getIdent(ctx.g.cache, name), ctx.fn, ctx.fn.info) result.typ = typ assert(not typ.isNil) if not ctx.stateVarSym.isNil: # We haven't gone through labmda lifting yet, so just create a local var, # it will be lifted later if ctx.tempVars.isNil: ctx.tempVars = newNodeI(nkVarSection, ctx.fn.info) addVar(ctx.tempVars, newSymNode(result)) else: let envParam = getEnvParam(ctx.fn) # let obj = envParam.typ.lastSon result = addUniqueField(envParam.typ.lastSon, result, ctx.g.cache) proc newEnvVarAccess(ctx: Ctx, s: PSym): PNode = if ctx.stateVarSym.isNil: result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), s, ctx.fn.info) else: result = newSymNode(s) proc newTmpResultAccess(ctx: var Ctx): PNode = if ctx.tmpResultSym.isNil: ctx.tmpResultSym = ctx.newEnvVar(":tmpResult", ctx.fn.typ[0]) ctx.newEnvVarAccess(ctx.tmpResultSym) proc newUnrollFinallyAccess(ctx: var Ctx, info: TLineInfo): PNode = if ctx.unrollFinallySym.isNil: ctx.unrollFinallySym = ctx.newEnvVar(":unrollFinally", ctx.g.getSysType(info, tyBool)) ctx.newEnvVarAccess(ctx.unrollFinallySym) proc newCurExcAccess(ctx: var Ctx): PNode = if ctx.curExcSym.isNil: ctx.curExcSym = ctx.newEnvVar(":curExc", ctx.g.callCodegenProc("getCurrentException").typ) ctx.newEnvVarAccess(ctx.curExcSym) proc newState(ctx: var Ctx, n, gotoOut: PNode): int = # Creates a new state, adds it to the context fills out `gotoOut` so that it # will goto this state. # Returns index of the newly created state result = ctx.states.len let resLit = ctx.g.newIntLit(n.info, result) let s = newNodeI(nkState, n.info) s.add(resLit) s.add(n) ctx.states.add(s) ctx.exceptionTable.add(ctx.curExcHandlingState) if not gotoOut.isNil: assert(gotoOut.len == 0) gotoOut.add(ctx.g.newIntLit(gotoOut.info, result)) proc toStmtList(n: PNode): PNode = result = n if result.kind notin {nkStmtList, nkStmtListExpr}: result = newNodeI(nkStmtList, n.info) result.add(n) proc addGotoOut(n: PNode, gotoOut: PNode): PNode = # Make sure `n` is a stmtlist, and ends with `gotoOut` result = toStmtList(n) if result.len == 0 or result.sons[^1].kind != nkGotoState: result.add(gotoOut) proc newTempVar(ctx: var Ctx, typ: PType): PSym = result = ctx.newEnvVar(":tmpSlLower" & $ctx.tempVarId, typ) inc ctx.tempVarId proc hasYields(n: PNode): bool = # TODO: This is very inefficient. It traverses the node, looking for nkYieldStmt. case n.kind of nkYieldStmt: result = true of nkSkip: discard else: for c in n: if c.hasYields: result = true break proc transformBreaksAndContinuesInWhile(ctx: var Ctx, n: PNode, before, after: PNode): PNode = result = n case n.kind of nkSkip: discard of nkWhileStmt: discard # Do not recurse into nested whiles of nkContinueStmt: result = before of nkBlockStmt: inc ctx.blockLevel result[1] = ctx.transformBreaksAndContinuesInWhile(result[1], before, after) dec ctx.blockLevel of nkBreakStmt: if ctx.blockLevel == 0: result = after else: for i in 0 ..< n.len: n[i] = ctx.transformBreaksAndContinuesInWhile(n[i], before, after) proc transformBreaksInBlock(ctx: var Ctx, n: PNode, label, after: PNode): PNode = result = n case n.kind of nkSkip: discard of nkBlockStmt, nkWhileStmt: inc ctx.blockLevel result[1] = ctx.transformBreaksInBlock(result[1], label, after) dec ctx.blockLevel of nkBreakStmt: if n[0].kind == nkEmpty: if ctx.blockLevel == 0: result = after else: if label.kind == nkSym and n[0].sym == label.sym: result = after else: for i in 0 ..< n.len: n[i] = ctx.transformBreaksInBlock(n[i], label, after) proc newNullifyCurExc(ctx: var Ctx, info: TLineInfo): PNode = # :curEcx = nil let curExc = ctx.newCurExcAccess() curExc.info = info let nilnode = newNode(nkNilLit) nilnode.typ = curExc.typ result = newTree(nkAsgn, curExc, nilnode) proc newOr(g: ModuleGraph, a, b: PNode): PNode {.inline.} = result = newTree(nkCall, newSymNode(g.getSysMagic(a.info, "or", mOr)), a, b) result.typ = g.getSysType(a.info, tyBool) result.info = a.info proc collectExceptState(ctx: var Ctx, n: PNode): PNode {.inline.} = var ifStmt = newNodeI(nkIfStmt, n.info) let g = ctx.g for c in n: if c.kind == nkExceptBranch: var ifBranch: PNode if c.len > 1: var cond: PNode for i in 0 .. c.len - 2: assert(c[i].kind == nkType) let nextCond = newTree(nkCall, newSymNode(g.getSysMagic(c.info, "of", mOf)), g.callCodegenProc("getCurrentException"), c[i]) nextCond.typ = ctx.g.getSysType(c.info, tyBool) nextCond.info = c.info if cond.isNil: cond = nextCond else: cond = g.newOr(cond, nextCond) ifBranch = newNodeI(nkElifBranch, c.info) ifBranch.add(cond) else: if ifStmt.len == 0: ifStmt = newNodeI(nkStmtList, c.info) ifBranch = newNodeI(nkStmtList, c.info) else: ifBranch = newNodeI(nkElse, c.info) ifBranch.add(c[^1]) ifStmt.add(ifBranch) if ifStmt.len != 0: result = newTree(nkStmtList, ctx.newNullifyCurExc(n.info), ifStmt) else: result = ctx.g.emptyNode proc addElseToExcept(ctx: var Ctx, n: PNode) = if n.kind == nkStmtList and n[1].kind == nkIfStmt and n[1][^1].kind != nkElse: # Not all cases are covered let branchBody = newNodeI(nkStmtList, n.info) block: # :unrollFinally = true branchBody.add(newTree(nkAsgn, ctx.newUnrollFinallyAccess(n.info), newIntTypeNode(1, ctx.g.getSysType(n.info, tyBool)))) block: # :curExc = getCurrentException() branchBody.add(newTree(nkAsgn, ctx.newCurExcAccess(), ctx.g.callCodegenProc("getCurrentException"))) block: # goto nearestFinally branchBody.add(newTree(nkGotoState, ctx.g.newIntLit(n.info, ctx.nearestFinally))) let elseBranch = newTree(nkElse, branchBody) n[1].add(elseBranch) proc getFinallyNode(ctx: var Ctx, n: PNode): PNode = result = n[^1] if result.kind == nkFinally: result = result[0] else: result = ctx.g.emptyNode proc hasYieldsInExpressions(n: PNode): bool = case n.kind of nkSkip: discard of nkStmtListExpr: if isEmptyType(n.typ): for c in n: if c.hasYieldsInExpressions: return true else: result = n.hasYields else: for c in n: if c.hasYieldsInExpressions: return true proc exprToStmtList(n: PNode): tuple[s, res: PNode] = assert(n.kind == nkStmtListExpr) result.s = newNodeI(nkStmtList, n.info) result.s.sons = @[] var n = n while n.kind == nkStmtListExpr: result.s.sons.add(n.sons) result.s.sons.setLen(result.s.sons.len - 1) # delete last son n = n[^1] result.res = n proc newEnvVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode = result = newTree(nkFastAsgn, ctx.newEnvVarAccess(s), v) result.info = v.info proc addExprAssgn(ctx: Ctx, output, input: PNode, sym: PSym) = if input.kind == nkStmtListExpr: let (st, res) = exprToStmtList(input) output.add(st) output.add(ctx.newEnvVarAsgn(sym, res)) else: output.add(ctx.newEnvVarAsgn(sym, input)) proc convertExprBodyToAsgn(ctx: Ctx, exprBody: PNode, res: PSym): PNode = result = newNodeI(nkStmtList, exprBody.info) ctx.addExprAssgn(result, exprBody, res) proc newNotCall(g: ModuleGraph; e: PNode): PNode = result = newTree(nkCall, newSymNode(g.getSysMagic(e.info, "not", mNot), e.info), e) result.typ = g.getSysType(e.info, tyBool) proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = result = n case n.kind of nkSkip: discard of nkYieldStmt: var ns = false for i in 0 ..< n.len: n[i] = ctx.lowerStmtListExprs(n[i], ns) if ns: result = newNodeI(nkStmtList, n.info) let (st, ex) = exprToStmtList(n[0]) result.add(st) n[0] = ex result.add(n) needsSplit = true of nkPar, nkObjConstr, nkTupleConstr, nkBracket: var ns = false for i in 0 ..< n.len: n[i] = ctx.lowerStmtListExprs(n[i], ns) if ns: needsSplit = true result = newNodeI(nkStmtListExpr, n.info) if n.typ.isNil: internalError(ctx.g.config, "lowerStmtListExprs: constr typ.isNil") result.typ = n.typ for i in 0 ..< n.len: case n[i].kind of nkExprColonExpr: if n[i][1].kind == nkStmtListExpr: let (st, ex) = exprToStmtList(n[i][1]) result.add(st) n[i][1] = ex of nkStmtListExpr: let (st, ex) = exprToStmtList(n[i]) result.add(st) n[i] = ex else: discard result.add(n) of nkIfStmt, nkIfExpr: var ns = false for i in 0 ..< n.len: n[i] = ctx.lowerStmtListExprs(n[i], ns) if ns: needsSplit = true var tmp: PSym var s: PNode let isExpr = not isEmptyType(n.typ) if isExpr: tmp = ctx.newTempVar(n.typ) result = newNodeI(nkStmtListExpr, n.info) result.typ = n.typ else: result = newNodeI(nkStmtList, n.info) var curS = result for branch in n: case branch.kind of nkElseExpr, nkElse: if isExpr: let branchBody = newNodeI(nkStmtList, branch.info) ctx.addExprAssgn(branchBody, branch[0], tmp) let newBranch = newTree(nkElse, branchBody) curS.add(newBranch) else: curS.add(branch) of nkElifExpr, nkElifBranch: var newBranch: PNode if branch[0].kind == nkStmtListExpr: let (st, res) = exprToStmtList(branch[0]) let elseBody = newTree(nkStmtList, st) newBranch = newTree(nkElifBranch, res, branch[1]) let newIf = newTree(nkIfStmt, newBranch) elseBody.add(newIf) if curS.kind == nkIfStmt: let newElse = newNodeI(nkElse, branch.info) newElse.add(elseBody) curS.add(newElse) else: curS.add(elseBody) curS = newIf else: newBranch = branch if curS.kind == nkIfStmt: curS.add(newBranch) else: let newIf = newTree(nkIfStmt, newBranch) curS.add(newIf) curS = newIf if isExpr: let branchBody = newNodeI(nkStmtList, branch[1].info) ctx.addExprAssgn(branchBody, branch[1], tmp) newBranch[1] = branchBody else: internalError(ctx.g.config, "lowerStmtListExpr(nkIf): " & $branch.kind) if isExpr: result.add(ctx.newEnvVarAccess(tmp)) of nkTryStmt, nkHiddenTryStmt: var ns = false for i in 0 ..< n.len: n[i] = ctx.lowerStmtListExprs(n[i], ns) if ns: needsSplit = true let isExpr = not isEmptyType(n.typ) if isExpr: result = newNodeI(nkStmtListExpr, n.info) result.typ = n.typ let tmp = ctx.newTempVar(n.typ) n[0] = ctx.convertExprBodyToAsgn(n[0], tmp) for i in 1 ..< n.len: let branch = n[i] case branch.kind of nkExceptBranch: if branch[0].kind == nkType: branch[1] = ctx.convertExprBodyToAsgn(branch[1], tmp) else: branch[0] = ctx.convertExprBodyToAsgn(branch[0], tmp) of nkFinally: discard else: internalError(ctx.g.config, "lowerStmtListExpr(nkTryStmt): " & $branch.kind) result.add(n) result.add(ctx.newEnvVarAccess(tmp)) of nkCaseStmt: var ns = false for i in 0 ..< n.len: n[i] = ctx.lowerStmtListExprs(n[i], ns) if ns: needsSplit = true let isExpr = not isEmptyType(n.typ) if isExpr: let tmp = ctx.newTempVar(n.typ) result = newNodeI(nkStmtListExpr, n.info) result.typ = n.typ if n[0].kind == nkStmtListExpr: let (st, ex) = exprToStmtList(n[0]) result.add(st) n[0] = ex for i in 1 ..< n.len: let branch = n[i] case branch.kind of nkOfBranch: branch[^1] = ctx.convertExprBodyToAsgn(branch[^1], tmp) of nkElse: branch[0] = ctx.convertExprBodyToAsgn(branch[0], tmp) else: internalError(ctx.g.config, "lowerStmtListExpr(nkCaseStmt): " & $branch.kind) result.add(n) result.add(ctx.newEnvVarAccess(tmp)) of nkCallKinds, nkChckRange, nkChckRangeF, nkChckRange64: var ns = false for i in 0 ..< n.len: n[i] = ctx.lowerStmtListExprs(n[i], ns) if ns: needsSplit = true let isExpr = not isEmptyType(n.typ) if isExpr: result = newNodeI(nkStmtListExpr, n.info) result.typ = n.typ else: result = newNodeI(nkStmtList, n.info) if n[0].kind == nkSym and n[0].sym.magic in {mAnd, mOr}: # `and`/`or` short cirquiting var cond = n[1] if cond.kind == nkStmtListExpr: let (st, ex) = exprToStmtList(cond) result.add(st) cond = ex let tmp = ctx.newTempVar(cond.typ) result.add(ctx.newEnvVarAsgn(tmp, cond)) var check = ctx.newEnvVarAccess(tmp) if n[0].sym.magic == mOr: check = ctx.g.newNotCall(check) cond = n[2] let ifBody = newNodeI(nkStmtList, cond.info) if cond.kind == nkStmtListExpr: let (st, ex) = exprToStmtList(cond) ifBody.add(st) cond = ex ifBody.add(ctx.newEnvVarAsgn(tmp, cond)) let ifBranch = newTree(nkElifBranch, check, ifBody) let ifNode = newTree(nkIfStmt, ifBranch) result.add(ifNode) result.add(ctx.newEnvVarAccess(tmp)) else: for i in 0 ..< n.len: if n[i].kind == nkStmtListExpr: let (st, ex) = exprToStmtList(n[i]) result.add(st) n[i] = ex if n[i].kind in nkCallKinds: # XXX: This should better be some sort of side effect tracking let tmp = ctx.newTempVar(n[i].typ) result.add(ctx.newEnvVarAsgn(tmp, n[i])) n[i] = ctx.newEnvVarAccess(tmp) result.add(n) of nkVarSection, nkLetSection: result = newNodeI(nkStmtList, n.info) for c in n: let varSect = newNodeI(n.kind, n.info) varSect.add(c) var ns = false c[^1] = ctx.lowerStmtListExprs(c[^1], ns) if ns: needsSplit = true let (st, ex) = exprToStmtList(c[^1]) result.add(st) c[^1] = ex result.add(varSect) of nkDiscardStmt, nkReturnStmt, nkRaiseStmt: var ns = false for i in 0 ..< n.len: n[i] = ctx.lowerStmtListExprs(n[i], ns) if ns: needsSplit = true result = newNodeI(nkStmtList, n.info) let (st, ex) = exprToStmtList(n[0]) result.add(st) n[0] = ex result.add(n) of nkCast, nkHiddenStdConv, nkHiddenSubConv, nkConv, nkObjDownConv: var ns = false for i in 0 ..< n.len: n[i] = ctx.lowerStmtListExprs(n[i], ns) if ns: needsSplit = true result = newNodeI(nkStmtListExpr, n.info) result.typ = n.typ let (st, ex) = exprToStmtList(n[^1]) result.add(st) n[^1] = ex result.add(n) of nkAsgn, nkFastAsgn: var ns = false for i in 0 ..< n.len: n[i] = ctx.lowerStmtListExprs(n[i], ns) if ns: needsSplit = true result = newNodeI(nkStmtList, n.info) if n[0].kind == nkStmtListExpr: let (st, ex) = exprToStmtList(n[0]) result.add(st) n[0] = ex if n[1].kind == nkStmtListExpr: let (st, ex) = exprToStmtList(n[1]) result.add(st) n[1] = ex result.add(n) of nkBracketExpr: var lhsNeedsSplit = false var rhsNeedsSplit = false n[0] = ctx.lowerStmtListExprs(n[0], lhsNeedsSplit) n[1] = ctx.lowerStmtListExprs(n[1], rhsNeedsSplit) if lhsNeedsSplit or rhsNeedsSplit: needsSplit = true result = newNodeI(nkStmtListExpr, n.info) if lhsNeedsSplit: let (st, ex) = exprToStmtList(n[0]) result.add(st) n[0] = ex if rhsNeedsSplit: let (st, ex) = exprToStmtList(n[1]) result.add(st) n[1] = ex result.add(n) of nkWhileStmt: var ns = false var condNeedsSplit = false n[0] = ctx.lowerStmtListExprs(n[0], condNeedsSplit) var bodyNeedsSplit = false n[1] = ctx.lowerStmtListExprs(n[1], bodyNeedsSplit) if condNeedsSplit or bodyNeedsSplit: needsSplit = true if condNeedsSplit: let (st, ex) = exprToStmtList(n[0]) let brk = newTree(nkBreakStmt, ctx.g.emptyNode) let branch = newTree(nkElifBranch, ctx.g.newNotCall(ex), brk) let check = newTree(nkIfStmt, branch) let newBody = newTree(nkStmtList, st, check, n[1]) n[0] = newSymNode(ctx.g.getSysSym(n[0].info, "true")) n[1] = newBody of nkDotExpr: var ns = false n[0] = ctx.lowerStmtListExprs(n[0], ns) if ns: needsSplit = true result = newNodeI(nkStmtListExpr, n.info) result.typ = n.typ let (st, ex) = exprToStmtList(n[0]) result.add(st) n[0] = ex result.add(n) of nkBlockExpr: var ns = false n[1] = ctx.lowerStmtListExprs(n[1], ns) if ns: needsSplit = true result = newNodeI(nkStmtListExpr, n.info) result.typ = n.typ let (st, ex) = exprToStmtList(n[1]) n.kind = nkBlockStmt n.typ = nil n[1] = st result.add(n) result.add(ex) else: for i in 0 ..< n.len: n[i] = ctx.lowerStmtListExprs(n[i], needsSplit) proc newEndFinallyNode(ctx: var Ctx, info: TLineInfo): PNode = # Generate the following code: # if :unrollFinally: # if :curExc.isNil: # return :tmpResult # else: # raise let curExc = ctx.newCurExcAccess() let nilnode = newNode(nkNilLit) nilnode.typ = curExc.typ let cmp = newTree(nkCall, newSymNode(ctx.g.getSysMagic(info, "==", mEqRef), info), curExc, nilnode) cmp.typ = ctx.g.getSysType(info, tyBool) let asgn = newTree(nkFastAsgn, newSymNode(getClosureIterResult(ctx.g, ctx.fn), info), ctx.newTmpResultAccess()) let retStmt = newTree(nkReturnStmt, asgn) let branch = newTree(nkElifBranch, cmp, retStmt) let nullifyExc = newTree(nkCall, newSymNode(ctx.g.getCompilerProc("closureIterSetupExc")), nilnode) nullifyExc.info = info let raiseStmt = newTree(nkRaiseStmt, curExc) raiseStmt.info = info let elseBranch = newTree(nkElse, newTree(nkStmtList, nullifyExc, raiseStmt)) let ifBody = newTree(nkIfStmt, branch, elseBranch) let elifBranch = newTree(nkElifBranch, ctx.newUnrollFinallyAccess(info), ifBody) elifBranch.info = info result = newTree(nkIfStmt, elifBranch) proc transformReturnsInTry(ctx: var Ctx, n: PNode): PNode = result = n # TODO: This is very inefficient. It traverses the node, looking for nkYieldStmt. case n.kind of nkReturnStmt: # We're somewhere in try, transform to finally unrolling assert(ctx.nearestFinally != 0) result = newNodeI(nkStmtList, n.info) block: # :unrollFinally = true let asgn = newNodeI(nkAsgn, n.info) asgn.add(ctx.newUnrollFinallyAccess(n.info)) asgn.add(newIntTypeNode(1, ctx.g.getSysType(n.info, tyBool))) result.add(asgn) if n[0].kind != nkEmpty: let asgnTmpResult = newNodeI(nkAsgn, n.info) asgnTmpResult.add(ctx.newTmpResultAccess()) asgnTmpResult.add(n[0]) result.add(asgnTmpResult) result.add(ctx.newNullifyCurExc(n.info)) let goto = newTree(nkGotoState, ctx.g.newIntLit(n.info, ctx.nearestFinally)) result.add(goto) of nkSkip: discard else: for i in 0 ..< n.len: n[i] = ctx.transformReturnsInTry(n[i]) proc transformClosureIteratorBody(ctx: var Ctx, n: PNode, gotoOut: PNode): PNode = result = n case n.kind of nkSkip: discard of nkStmtList, nkStmtListExpr: result = addGotoOut(result, gotoOut) for i in 0 ..< n.len: if n[i].hasYields: # Create a new split let go = newNodeI(nkGotoState, n[i].info) n[i] = ctx.transformClosureIteratorBody(n[i], go) let s = newNodeI(nkStmtList, n[i + 1].info) for j in i + 1 ..< n.len: s.add(n[j]) n.sons.setLen(i + 1) discard ctx.newState(s, go) if ctx.transformClosureIteratorBody(s, gotoOut) != s: internalError(ctx.g.config, "transformClosureIteratorBody != s") break of nkYieldStmt: result = newNodeI(nkStmtList, n.info) result.add(n) result.add(gotoOut) of nkElse, nkElseExpr: result[0] = addGotoOut(result[0], gotoOut) result[0] = ctx.transformClosureIteratorBody(result[0], gotoOut) of nkElifBranch, nkElifExpr, nkOfBranch: result[^1] = addGotoOut(result[^1], gotoOut) result[^1] = ctx.transformClosureIteratorBody(result[^1], gotoOut) of nkIfStmt, nkCaseStmt: for i in 0 ..< n.len: n[i] = ctx.transformClosureIteratorBody(n[i], gotoOut) if n[^1].kind != nkElse: # We don't have an else branch, but every possible branch has to end with # gotoOut, so add else here. let elseBranch = newTree(nkElse, gotoOut) n.add(elseBranch) of nkWhileStmt: # while e: # s # -> # BEGIN_STATE: # if e: # s # goto BEGIN_STATE # else: # goto OUT result = newNodeI(nkGotoState, n.info) let s = newNodeI(nkStmtList, n.info) discard ctx.newState(s, result) let ifNode = newNodeI(nkIfStmt, n.info) let elifBranch = newNodeI(nkElifBranch, n.info) elifBranch.add(n[0]) var body = addGotoOut(n[1], result) body = ctx.transformBreaksAndContinuesInWhile(body, result, gotoOut) body = ctx.transformClosureIteratorBody(body, result) elifBranch.add(body) ifNode.add(elifBranch) let elseBranch = newTree(nkElse, gotoOut) ifNode.add(elseBranch) s.add(ifNode) of nkBlockStmt: result[1] = addGotoOut(result[1], gotoOut) result[1] = ctx.transformBreaksInBlock(result[1], result[0], gotoOut) result[1] = ctx.transformClosureIteratorBody(result[1], gotoOut) of nkTryStmt, nkHiddenTryStmt: # See explanation above about how this works ctx.hasExceptions = true result = newNodeI(nkGotoState, n.info) var tryBody = toStmtList(n[0]) var exceptBody = ctx.collectExceptState(n) var finallyBody = newTree(nkStmtList, getFinallyNode(ctx, n)) finallyBody = ctx.transformReturnsInTry(finallyBody) finallyBody.add(ctx.newEndFinallyNode(finallyBody.info)) # The following index calculation is based on the knowledge how state # indexes are assigned let tryIdx = ctx.states.len var exceptIdx, finallyIdx: int if exceptBody.kind != nkEmpty: exceptIdx = -(tryIdx + 1) finallyIdx = tryIdx + 2 else: exceptIdx = tryIdx + 1 finallyIdx = tryIdx + 1 let outToFinally = newNodeI(nkGotoState, finallyBody.info) block: # Create initial states. let oldExcHandlingState = ctx.curExcHandlingState ctx.curExcHandlingState = exceptIdx let realTryIdx = ctx.newState(tryBody, result) assert(realTryIdx == tryIdx) if exceptBody.kind != nkEmpty: ctx.curExcHandlingState = finallyIdx let realExceptIdx = ctx.newState(exceptBody, nil) assert(realExceptIdx == -exceptIdx) ctx.curExcHandlingState = oldExcHandlingState let realFinallyIdx = ctx.newState(finallyBody, outToFinally) assert(realFinallyIdx == finallyIdx) block: # Subdivide the states let oldNearestFinally = ctx.nearestFinally ctx.nearestFinally = finallyIdx let oldExcHandlingState = ctx.curExcHandlingState ctx.curExcHandlingState = exceptIdx if ctx.transformReturnsInTry(tryBody) != tryBody: internalError(ctx.g.config, "transformReturnsInTry != tryBody") if ctx.transformClosureIteratorBody(tryBody, outToFinally) != tryBody: internalError(ctx.g.config, "transformClosureIteratorBody != tryBody") ctx.curExcHandlingState = finallyIdx ctx.addElseToExcept(exceptBody) if ctx.transformReturnsInTry(exceptBody) != exceptBody: internalError(ctx.g.config, "transformReturnsInTry != exceptBody") if ctx.transformClosureIteratorBody(exceptBody, outToFinally) != exceptBody: internalError(ctx.g.config, "transformClosureIteratorBody != exceptBody") ctx.curExcHandlingState = oldExcHandlingState ctx.nearestFinally = oldNearestFinally if ctx.transformClosureIteratorBody(finallyBody, gotoOut) != finallyBody: internalError(ctx.g.config, "transformClosureIteratorBody != finallyBody") of nkGotoState, nkForStmt: internalError(ctx.g.config, "closure iter " & $n.kind) else: for i in 0 ..< n.len: n[i] = ctx.transformClosureIteratorBody(n[i], gotoOut) proc stateFromGotoState(n: PNode): int = assert(n.kind == nkGotoState) result = n[0].intVal.int proc transformStateAssignments(ctx: var Ctx, n: PNode): PNode = # This transforms 3 patterns: ########################## 1 # yield e # goto STATE # -> # :state = STATE # return e ########################## 2 # goto STATE # -> # :state = STATE # break :stateLoop ########################## 3 # return e # -> # :state = -1 # return e # result = n case n.kind of nkStmtList, nkStmtListExpr: if n.len != 0 and n[0].kind == nkYieldStmt: assert(n.len == 2) assert(n[1].kind == nkGotoState) result = newNodeI(nkStmtList, n.info) result.add(ctx.newStateAssgn(stateFromGotoState(n[1]))) var retStmt = newNodeI(nkReturnStmt, n.info) if n[0].sons[0].kind != nkEmpty: var a = newNodeI(nkAsgn, n[0].sons[0].info) var retVal = n[0].sons[0] #liftCapturedVars(n.sons[0], owner, d, c) addSon(a, newSymNode(getClosureIterResult(ctx.g, ctx.fn))) addSon(a, retVal) retStmt.add(a) else: retStmt.add(ctx.g.emptyNode) result.add(retStmt) else: for i in 0 ..< n.len: n[i] = ctx.transformStateAssignments(n[i]) of nkSkip: discard of nkReturnStmt: result = newNodeI(nkStmtList, n.info) result.add(ctx.newStateAssgn(-1)) result.add(n) of nkGotoState: result = newNodeI(nkStmtList, n.info) result.add(ctx.newStateAssgn(stateFromGotoState(n))) let breakState = newNodeI(nkBreakStmt, n.info) breakState.add(newSymNode(ctx.stateLoopLabel)) result.add(breakState) else: for i in 0 ..< n.len: n[i] = ctx.transformStateAssignments(n[i]) proc skipStmtList(ctx: Ctx; n: PNode): PNode = result = n while result.kind in {nkStmtList}: if result.len == 0: return ctx.g.emptyNode result = result[0] proc skipEmptyStates(ctx: Ctx, stateIdx: int): int = # Returns first non-empty state idx for `stateIdx`. Returns `stateIdx` if # it is not empty var maxJumps = ctx.states.len # maxJumps used only for debugging purposes. var stateIdx = stateIdx while true: let label = stateIdx if label == ctx.exitStateIdx: break var newLabel = label if label == -1: newLabel = ctx.exitStateIdx else: let fs = skipStmtList(ctx, ctx.states[label][1]) if fs.kind == nkGotoState: newLabel = fs[0].intVal.int if label == newLabel: break stateIdx = newLabel dec maxJumps if maxJumps == 0: assert(false, "Internal error") result = ctx.states[stateIdx][0].intVal.int proc skipThroughEmptyStates(ctx: var Ctx, n: PNode): PNode = result = n case n.kind of nkSkip: discard of nkGotoState: result = copyTree(n) result[0].intVal = ctx.skipEmptyStates(result[0].intVal.int) else: for i in 0 ..< n.len: n[i] = ctx.skipThroughEmptyStates(n[i]) proc newArrayType(g: ModuleGraph; n: int, t: PType, owner: PSym): PType = result = newType(tyArray, owner) let rng = newType(tyRange, owner) rng.n = newTree(nkRange, g.newIntLit(owner.info, 0), g.newIntLit(owner.info, n)) rng.rawAddSon(t) result.rawAddSon(rng) result.rawAddSon(t) proc createExceptionTable(ctx: var Ctx): PNode {.inline.} = result = newNodeI(nkBracket, ctx.fn.info) result.typ = ctx.g.newArrayType(ctx.exceptionTable.len, ctx.g.getSysType(ctx.fn.info, tyInt16), ctx.fn) for i in ctx.exceptionTable: let elem = newIntNode(nkIntLit, i) elem.typ = ctx.g.getSysType(ctx.fn.info, tyInt16) result.add(elem) proc newCatchBody(ctx: var Ctx, info: TLineInfo): PNode {.inline.} = # Generates the code: # :state = exceptionTable[:state] # if :state == 0: raise # :unrollFinally = :state > 0 # if :state < 0: # :state = -:state # :curExc = getCurrentException() result = newNodeI(nkStmtList, info) let intTyp = ctx.g.getSysType(info, tyInt) let boolTyp = ctx.g.getSysType(info, tyBool) # :state = exceptionTable[:state] block: # exceptionTable[:state] let getNextState = newTree(nkBracketExpr, ctx.createExceptionTable(), ctx.newStateAccess()) getNextState.typ = intTyp # :state = exceptionTable[:state] result.add(ctx.newStateAssgn(getNextState)) # if :state == 0: raise block: let cond = newTree(nkCall, ctx.g.getSysMagic(info, "==", mEqI).newSymNode(), ctx.newStateAccess(), newIntTypeNode(0, intTyp)) cond.typ = boolTyp let raiseStmt = newTree(nkRaiseStmt, ctx.g.emptyNode) let ifBranch = newTree(nkElifBranch, cond, raiseStmt) let ifStmt = newTree(nkIfStmt, ifBranch) result.add(ifStmt) # :unrollFinally = :state > 0 block: let cond = newTree(nkCall, ctx.g.getSysMagic(info, "<", mLtI).newSymNode, newIntTypeNode(0, intTyp), ctx.newStateAccess()) cond.typ = boolTyp let asgn = newTree(nkAsgn, ctx.newUnrollFinallyAccess(info), cond) result.add(asgn) # if :state < 0: :state = -:state block: let cond = newTree(nkCall, ctx.g.getSysMagic(info, "<", mLtI).newSymNode, ctx.newStateAccess(), newIntTypeNode(0, intTyp)) cond.typ = boolTyp let negateState = newTree(nkCall, ctx.g.getSysMagic(info, "-", mUnaryMinusI).newSymNode, ctx.newStateAccess()) negateState.typ = intTyp let ifBranch = newTree(nkElifBranch, cond, ctx.newStateAssgn(negateState)) let ifStmt = newTree(nkIfStmt, ifBranch) result.add(ifStmt) # :curExc = getCurrentException() block: result.add(newTree(nkAsgn, ctx.newCurExcAccess(), ctx.g.callCodegenProc("getCurrentException"))) proc wrapIntoTryExcept(ctx: var Ctx, n: PNode): PNode {.inline.} = let setupExc = newTree(nkCall, newSymNode(ctx.g.getCompilerProc("closureIterSetupExc")), ctx.newCurExcAccess()) let tryBody = newTree(nkStmtList, setupExc, n) let exceptBranch = newTree(nkExceptBranch, ctx.newCatchBody(ctx.fn.info)) result = newTree(nkTryStmt, tryBody, exceptBranch) proc wrapIntoStateLoop(ctx: var Ctx, n: PNode): PNode = # while true: # block :stateLoop: # gotoState :state # local vars decl (if needed) # body # Might get wrapped in try-except let loopBody = newNodeI(nkStmtList, n.info) result = newTree(nkWhileStmt, newSymNode(ctx.g.getSysSym(n.info, "true")), loopBody) result.info = n.info let localVars = newNodeI(nkStmtList, n.info) if not ctx.stateVarSym.isNil: let varSect = newNodeI(nkVarSection, n.info) addVar(varSect, newSymNode(ctx.stateVarSym)) localVars.add(varSect) if not ctx.tempVars.isNil: localVars.add(ctx.tempVars) let blockStmt = newNodeI(nkBlockStmt, n.info) blockStmt.add(newSymNode(ctx.stateLoopLabel)) let gs = newNodeI(nkGotoState, n.info) gs.add(ctx.newStateAccess()) gs.add(ctx.g.newIntLit(n.info, ctx.states.len - 1)) var blockBody = newTree(nkStmtList, gs, localVars, n) if ctx.hasExceptions: blockBody = ctx.wrapIntoTryExcept(blockBody) blockStmt.add(blockBody) loopBody.add(blockStmt) proc deleteEmptyStates(ctx: var Ctx) = let goOut = newTree(nkGotoState, ctx.g.newIntLit(TLineInfo(), -1)) ctx.exitStateIdx = ctx.newState(goOut, nil) # Apply new state indexes and mark unused states with -1 var iValid = 0 for i, s in ctx.states: let body = skipStmtList(ctx, s[1]) if body.kind == nkGotoState and i != ctx.states.len - 1 and i != 0: # This is an empty state. Mark with -1. s[0].intVal = -1 else: s[0].intVal = iValid inc iValid for i, s in ctx.states: let body = skipStmtList(ctx, s[1]) if body.kind != nkGotoState or i == 0: discard ctx.skipThroughEmptyStates(s) let excHandlState = ctx.exceptionTable[i] if excHandlState < 0: ctx.exceptionTable[i] = -ctx.skipEmptyStates(-excHandlState) elif excHandlState != 0: ctx.exceptionTable[i] = ctx.skipEmptyStates(excHandlState) var i = 0 while i < ctx.states.len - 1: let fs = skipStmtList(ctx, ctx.states[i][1]) if fs.kind == nkGotoState and i != 0: ctx.states.delete(i) ctx.exceptionTable.delete(i) else: inc i proc transformClosureIterator*(g: ModuleGraph; fn: PSym, n: PNode): PNode = var ctx: Ctx ctx.g = g ctx.fn = fn if getEnvParam(fn).isNil: # Lambda lifting was not done yet. Use temporary :state sym, which will # be handled specially by lambda lifting. Local temp vars (if needed) # should folllow the same logic. ctx.stateVarSym = newSym(skVar, getIdent(ctx.g.cache, ":state"), fn, fn.info) ctx.stateVarSym.typ = g.createClosureIterStateType(fn) ctx.stateLoopLabel = newSym(skLabel, getIdent(ctx.g.cache, ":stateLoop"), fn, fn.info) var n = n.toStmtList discard ctx.newState(n, nil) let gotoOut = newTree(nkGotoState, g.newIntLit(n.info, -1)) var ns = false n = ctx.lowerStmtListExprs(n, ns) if n.hasYieldsInExpressions(): internalError(ctx.g.config, "yield in expr not lowered") # Splitting transformation discard ctx.transformClosureIteratorBody(n, gotoOut) # Optimize empty states away ctx.deleteEmptyStates() # Make new body by concating the list of states result = newNodeI(nkStmtList, n.info) for s in ctx.states: assert(s.len == 2) let body = s[1] s.sons.del(1) result.add(s) result.add(body) result = ctx.transformStateAssignments(result) result = ctx.wrapIntoStateLoop(result) # echo "TRANSFORM TO STATES: " # echo renderTree(result) # echo "exception table:" # for i, e in ctx.exceptionTable: # echo i, " -> ", e