diff options
author | Yuriy Glukhov <yuriy.glukhov@gmail.com> | 2018-05-04 15:23:47 +0300 |
---|---|---|
committer | Yuriy Glukhov <yuriy.glukhov@gmail.com> | 2018-05-09 22:25:28 +0300 |
commit | ce634909281ffc8efbc7d192f557ffe38f49e740 (patch) | |
tree | 6d8ce63c9009e299262586774ddabbaa875b4b89 | |
parent | 0ed6c3e476e421827d081e9ab0d9fcb0d3de5eb2 (diff) | |
download | Nim-ce634909281ffc8efbc7d192f557ffe38f49e740.tar.gz |
Yield in try
-rw-r--r-- | compiler/closureiters.nim | 586 | ||||
-rw-r--r-- | compiler/semexprs.nim | 2 | ||||
-rw-r--r-- | lib/system/embedded.nim | 3 | ||||
-rw-r--r-- | lib/system/excpt.nim | 4 | ||||
-rw-r--r-- | tests/async/tasynctry2.nim | 4 | ||||
-rw-r--r-- | tests/iter/tyieldintry.nim | 201 |
6 files changed, 709 insertions, 91 deletions
diff --git a/compiler/closureiters.nim b/compiler/closureiters.nim index 7653176de..504f70347 100644 --- a/compiler/closureiters.nim +++ b/compiler/closureiters.nim @@ -59,6 +59,80 @@ # if :tmpSlLower == 2: # yield 3 +# nkTryStmt Transformations: +# If the iter has an nkTryStmt with a yield inside +# - the closure iter is promoted to have exceptions (ctx.hasExceptions = true) +# - exception table is created. This is a const array, where +# `abs(exceptionTable[i])` is a state idx to which we should jump from state +# `i` should exception be raised in state `i`. For all states in `try` block +# the target state is `except` block. For all states in `except` block +# the target state is `finally` block. For all other states there is no +# target state (0, as the first block can never be neither except nor finally). +# `exceptionTable[i]` is < 0 if `abs(exceptionTable[i])` is except block, +# and > 0, for finally block. +# - local variable :curExc is created +# - the iter body is wrapped into a +# try: +# closureIterSetupExc(:curExc) +# ...body... +# catch: +# :state = exceptionTable[:state] +# if :state == 0: raise # No state that could handle exception +# :unrollFinally = :state > 0 # Target state is finally +# if :state < 0: +# :state = -:state +# :curExc = getCurrentException() +# +# nkReturnStmt within a try/except/finally now has to behave differently as we +# want the nearest finally block to be executed before the return, thus it is +# transformed to: +# :tmpResult = returnValue (if return doesn't have a value, this is skipped) +# :unrollFinally = true +# goto nearestFinally (or -1 if not exists) +# +# Every finally block calls closureIterEndFinally() upon its successful +# completion. +# +# Example: +# +# try: +# yield 0 +# raise ... +# except: +# yield 1 +# return 3 +# finally: +# yield 2 +# +# Is transformed to (yields are left in place for example simplicity, +# in reality the code is subdivided even more, as described above): +# +# STATE0: # Try +# yield 0 +# raise ... +# :state = 2 # What would happen should we not raise +# break :stateLoop +# STATE1: # Except +# yield 1 +# :tmpResult = 3 # Return +# :unrollFinally = true # Return +# :state = 2 # Goto Finally +# break :stateLoop +# :state = 2 # What would happen should we not return +# break :stateLoop +# STATE2: # Finally +# yield 2 +# if :unrollFinally: # This node is created by `newEndFinallyNode` +# when nearestFinally == 0: # Pseudocode. The `when` is not emitted in reality +# if :curExc.isNil: +# return :tmpResult +# else: +# raise +# else: +# :state = nearestFinally +# break :stateLoop +# state = -1 # Goto next state. In this case we just exit +# break :stateLoop import intsets, strutils, options, ast, astalgo, trees, treetab, msgs, os, options, @@ -69,6 +143,10 @@ type Ctx = object fn: PSym stateVarSym: PSym # :state variable. nil if env already introduced by lambdalifting + tmpResultSym: PSym # Used when we return, but finally has to interfere + unrollFinallySym: PSym # Indicates that we're unrolling finally states (either exception happened or premature return) + curExcSym: PSym # Current exception + states: seq[PNode] # The resulting states. Every state is an nkState node. blockLevel: int # Temp used to transform break and continue stmts stateLoopLabel: PSym # Label to break on, when jumping between states. @@ -76,20 +154,66 @@ type tempVarId: int # unique name counter tempVars: PNode # Temp var decls, nkVarSection loweredStmtListExpr: PNode # Temporary used for nkStmtListExpr lowering + exceptionTable: seq[int] # For state `i` jump to state `exceptionTable[i]` if exception is raised + hasExceptions: bool # Does closure have yield in try? + curExcHandlingState: int # Negative for except, positive for finally + nearestFinally: int # Index of the nearest finally block. For try/except it + # is their finally. For finally it is parent finally. Otherwise -1 + +proc newStateAccess(ctx: var Ctx): PNode = + if ctx.stateVarSym.isNil: + result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), getStateField(ctx.fn), ctx.fn.info) + else: + result = newSymNode(ctx.stateVarSym) + +proc newStateAssgn(ctx: var Ctx, toValue: PNode): PNode = + # Creates state assignment: + # :state = toValue + result = newNode(nkAsgn) + result.add(ctx.newStateAccess()) + result.add(toValue) proc newStateAssgn(ctx: var Ctx, stateNo: int = -2): PNode = - # Creates state assignmen: + # Creates state assignment: # :state = stateNo + ctx.newStateAssgn(newIntTypeNode(nkIntLit, stateNo, getSysType(tyInt))) - result = newNode(nkAsgn) +proc newEnvVar(ctx: var Ctx, name: string, typ: PType): PSym = + result = newSym(skVar, getIdent(name), ctx.fn, ctx.fn.info) + result.typ = typ + + if not ctx.stateVarSym.isNil: + # We haven't gone through labmda lifting yet, so just create a local var, + # it will be lifted later + if ctx.tempVars.isNil: + ctx.tempVars = newNode(nkVarSection) + addVar(ctx.tempVars, newSymNode(result)) + else: + let envParam = getEnvParam(ctx.fn) + # let obj = envParam.typ.lastSon + result = addUniqueField(envParam.typ.lastSon, result) + +proc newEnvVarAccess(ctx: Ctx, s: PSym): PNode = if ctx.stateVarSym.isNil: - let state = getStateField(ctx.fn) - assert state != nil - result.add(rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), - state, result.info)) + result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), s, ctx.fn.info) else: - result.add(newSymNode(ctx.stateVarSym)) - result.add(newIntTypeNode(nkIntLit, stateNo, getSysType(tyInt))) + result = newSymNode(s) + +proc newTmpResultAccess(ctx: var Ctx): PNode = + if ctx.tmpResultSym.isNil: + debug(ctx.fn.typ) + ctx.tmpResultSym = ctx.newEnvVar(":tmpResult", ctx.fn.typ[0]) + ctx.newEnvVarAccess(ctx.tmpResultSym) + +proc newUnrollFinallyAccess(ctx: var Ctx): PNode = + if ctx.unrollFinallySym.isNil: + ctx.unrollFinallySym = ctx.newEnvVar(":unrollFinally", getSysType(tyBool)) + ctx.newEnvVarAccess(ctx.unrollFinallySym) + +proc newCurExcAccess(ctx: var Ctx): PNode = + if ctx.curExcSym.isNil: + ctx.curExcSym = ctx.newEnvVar(":curExc", callCodegenProc("getCurrentException", emptyNode).typ) + ctx.newEnvVarAccess(ctx.curExcSym) proc setStateInAssgn(stateAssgn: PNode, stateNo: int) = assert stateAssgn.kind == nkAsgn @@ -107,6 +231,8 @@ proc newState(ctx: var Ctx, n, gotoOut: PNode): int = s.add(resLit) s.add(n) ctx.states.add(s) + ctx.exceptionTable.add(ctx.curExcHandlingState) + if not gotoOut.isNil: assert(gotoOut.len == 0) gotoOut.add(newIntLit(result)) @@ -119,27 +245,13 @@ proc toStmtList(n: PNode): PNode = proc addGotoOut(n: PNode, gotoOut: PNode): PNode = # Make sure `n` is a stmtlist, and ends with `gotoOut` - result = toStmtList(n) if result.len != 0 and result.sons[^1].kind != nkGotoState: result.add(gotoOut) proc newTempVarAccess(ctx: var Ctx, typ: PType, i: TLineInfo): PNode = - let s = newSym(skVar, getIdent(":tmpSlLower" & $ctx.tempVarId), ctx.fn, i) - s.typ = typ - - if not ctx.stateVarSym.isNil: - # We haven't gone through labmda lifting yet, so just create a local var, - # it will be lifted later - if ctx.tempVars.isNil: - ctx.tempVars = newNode(nkVarSection) - addVar(ctx.tempVars, newSymNode(s)) - - result = newSymNode(s) - else: - # Lambda lifting is done, insert temp var to env. - result = freshVarForClosureIter(s, ctx.fn) - + let s = ctx.newEnvVar(":tmpSlLower" & $ctx.tempVarId, typ) + result = ctx.newEnvVarAccess(s) inc ctx.tempVarId proc hasYields(n: PNode): bool = @@ -197,7 +309,17 @@ proc transformBreaksInBlock(ctx: var Ctx, n: PNode, label, after: PNode): PNode for i in 0 ..< n.len: n[i] = ctx.transformBreaksInBlock(n[i], label, after) -proc collectExceptState(n: PNode): PNode = +proc newNullifyCurExc(ctx: var Ctx): PNode = + # :curEcx = nil + result = newNode(nkAsgn) + let curExc = ctx.newCurExcAccess() + result.add(curExc) + + let nilnode = newNode(nkNilLit) + nilnode.typ = curExc.typ + result.add(nilnode) + +proc collectExceptState(ctx: var Ctx, n: PNode): PNode = var ifStmt = newNode(nkIfStmt) for c in n: if c.kind == nkExceptBranch: @@ -208,8 +330,10 @@ proc collectExceptState(n: PNode): PNode = assert(c.len == 2) ifBranch = newNode(nkElifBranch) let expression = newNodeI(nkCall, n.info) + expression.add(newSymNode(getSysMagic("of", mOf))) expression.add(callCodegenProc("getCurrentException", emptyNode)) expression.add(c[0]) + expression.typ = getSysType(tyBool) ifBranch.add(expression) branchBody = c[1] else: @@ -226,10 +350,37 @@ proc collectExceptState(n: PNode): PNode = if ifStmt.len != 0: result = newNode(nkStmtList) + result.add(ctx.newNullifyCurExc()) result.add(ifStmt) else: result = emptyNode +proc addElseToExcept(ctx: var Ctx, n: PNode) = + if n.kind == nkStmtList and n[1].kind == nkIfStmt and n[1][^1].kind != nkElse: + # Not all cases are covered + let elseBranch = newNode(nkElse) + let branchBody = newNode(nkStmtList) + + block: # :unrollFinally = true + let asgn = newNode(nkAsgn) + asgn.add(ctx.newUnrollFinallyAccess()) + asgn.add(newIntTypeNode(nkIntLit, 1, getSysType(tyBool))) + branchBody.add(asgn) + + block: # :curExc = getCurrentException() + let asgn = newNode(nkAsgn) + asgn.add(ctx.newCurExcAccess) + asgn.add(callCodegenProc("getCurrentException", emptyNode)) + branchBody.add(asgn) + + block: # goto nearestFinally + let goto = newNode(nkGotoState) + goto.add(newIntLit(ctx.nearestFinally)) + branchBody.add(goto) + + elseBranch.add(branchBody) + n[1].add(elseBranch) + proc getFinallyNode(n: PNode): PNode = result = n[^1] if result.kind == nkFinally: @@ -273,6 +424,101 @@ proc lowerStmtListExpr(ctx: var Ctx, n: PNode): PNode = for i in 0 ..< n.len: n[i] = ctx.lowerStmtListExpr(n[i]) +proc newEndFinallyNode(ctx: var Ctx): PNode = + # Generate the following code: + # if :unrollFinally: + # when nearestFinally == 0: # Pseudocode. The `when` is not emitted in reality + # if :curExc.isNil: + # return :tmpResult + # else: + # raise + # else: + # goto nearestFinally + # :state = nearestFinally + # break :stateLoop + + result = newNode(nkIfStmt) + + let elifBranch = newNode(nkElifBranch) + elifBranch.add(ctx.newUnrollFinallyAccess()) + result.add(elifBranch) + + var ifBody: PNode + + if ctx.nearestFinally == 0 or true: + ifBody = newNode(nkIfStmt) + let branch = newNode(nkElifBranch) + + let cmp = newNode(nkCall) + cmp.add(getSysMagic("==", mEqRef).newSymNode) + let curExc = ctx.newCurExcAccess() + let nilnode = newNode(nkNilLit) + nilnode.typ = curExc.typ + cmp.add(curExc) + cmp.add(nilnode) + cmp.typ = getSysType(tyBool) + branch.add(cmp) + + var retStmt = newNode(nkReturnStmt) + if true: + var a = newNode(nkAsgn) + addSon(a, newSymNode(getClosureIterResult(ctx.fn))) + addSon(a, ctx.newTmpResultAccess()) + retStmt.add(a) + else: + retStmt.add(emptyNode) + branch.add(retStmt) + + let elseBranch = newNode(nkElse) + let raiseStmt = newNode(nkRaiseStmt) + + # The C++ backend requires `getCurrentException` here. + raiseStmt.add(callCodegenProc("getCurrentException", emptyNode)) + elseBranch.add(raiseStmt) + + ifBody.add(branch) + ifBody.add(elseBranch) + else: + ifBody = newNode(nkGotoState) + ifBody.add(newIntLit(ctx.nearestFinally)) + + elifBranch.add(ifBody) + +proc transformReturnsInTry(ctx: var Ctx, n: PNode): PNode = + result = n + # TODO: This is very inefficient. It traverses the node, looking for nkYieldStmt. + case n.kind + of nkReturnStmt: + # We're somewhere in try, transform to finally unrolling + assert(ctx.nearestFinally != 0) + + result = newNodeI(nkStmtList, n.info) + + block: # :unrollFinally = true + let asgn = newNodeI(nkAsgn, n.info) + asgn.add(ctx.newUnrollFinallyAccess()) + asgn.add(newIntTypeNode(nkIntLit, 1, getSysType(tyBool))) + result.add(asgn) + + if n[0].kind != nkEmpty: # TODO: And not void! + let asgnTmpResult = newNodeI(nkAsgn, n.info) + asgnTmpResult.add(ctx.newTmpResultAccess()) + asgnTmpResult.add(n[0]) + result.add(asgnTmpResult) + + result.add(ctx.newNullifyCurExc()) + + let goto = newNodeI(nkGotoState, n.info) + goto.add(newIntLit(ctx.nearestFinally)) + result.add(goto) + + of nkCharLit..nkUInt64Lit, nkFloatLit..nkFloat128Lit, nkStrLit..nkTripleStrLit, + nkSym, nkIdent, procDefs, nkTemplateDef: + discard + else: + for i in 0 ..< n.len: + n[i] = ctx.transformReturnsInTry(n[i]) + proc transformClosureIteratorBody(ctx: var Ctx, n: PNode, gotoOut: PNode): PNode = result = n case n.kind: @@ -310,7 +556,6 @@ proc transformClosureIteratorBody(ctx: var Ctx, n: PNode, gotoOut: PNode): PNode assert(false, "nkStmtListExpr not lowered") of nkYieldStmt: - # echo "YIELD!" result = newNodeI(nkStmtList, n.info) result.add(n) result.add(gotoOut) @@ -371,34 +616,64 @@ proc transformClosureIteratorBody(ctx: var Ctx, n: PNode, gotoOut: PNode): PNode result[1] = ctx.transformClosureIteratorBody(result[1], gotoOut) of nkTryStmt: - var tryBody = toStmtList(n[0]) + # See explanation above about how this works + ctx.hasExceptions = true - # let popTry = newNode(nkPar) - # popTry.add(newIdentNode(getIdent("popTry"), n.info)) + result = newNode(nkGotoState) + var tryBody = toStmtList(n[0]) + var exceptBody = ctx.collectExceptState(n) var finallyBody = newNode(nkStmtList) - # finallyBody.add(popTry) finallyBody.add(getFinallyNode(n)) + finallyBody = ctx.transformReturnsInTry(finallyBody) + finallyBody.add(ctx.newEndFinallyNode()) - var tryCatchOut = newNode(nkGotoState) + # The following index calculation is based on the knowledge how state + # indexes are assigned + let tryIdx = ctx.states.len + var exceptIdx, finallyIdx: int + if exceptBody.kind != nkEmpty: + exceptIdx = -(tryIdx + 1) + finallyIdx = tryIdx + 2 + else: + exceptIdx = tryIdx + 1 + finallyIdx = tryIdx + 1 - tryBody = ctx.transformClosureIteratorBody(tryBody, tryCatchOut) - var exceptBody = collectExceptState(n) + let outToFinally = newNode(nkGotoState) - var exceptIdx = -1 - if exceptBody.kind != nkEmpty: - exceptBody = ctx.transformClosureIteratorBody(exceptBody, tryCatchOut) - exceptIdx = ctx.newState(exceptBody, nil) + block: # Create initial states. + let oldExcHandlingState = ctx.curExcHandlingState + ctx.curExcHandlingState = exceptIdx + let realTryIdx = ctx.newState(tryBody, result) + assert(realTryIdx == tryIdx) + + if exceptBody.kind != nkEmpty: + ctx.curExcHandlingState = finallyIdx + let realExceptIdx = ctx.newState(exceptBody, nil) + assert(realExceptIdx == -exceptIdx) - finallyBody = ctx.transformClosureIteratorBody(finallyBody, gotoOut) - let finallyIdx = ctx.newState(finallyBody, tryCatchOut) + ctx.curExcHandlingState = oldExcHandlingState + let realFinallyIdx = ctx.newState(finallyBody, outToFinally) + assert(realFinallyIdx == finallyIdx) - # let pushTry = newNode(nkPar) #newCall(newSym("pushTry"), newIntLit(exceptIdx)) - # pushTry.add(newIdentNode(getIdent("pushTry"), n.info)) - # pushTry.add(newIntLit(exceptIdx)) - # pushTry.add(newIntLit(finallyIdx)) - # tryBody.sons.insert(pushTry, 0) + block: # Subdivide the states + let oldNearestFinally = ctx.nearestFinally + ctx.nearestFinally = finallyIdx - result = tryBody + let oldExcHandlingState = ctx.curExcHandlingState + + ctx.curExcHandlingState = exceptIdx + + discard ctx.transformReturnsInTry(tryBody) + discard ctx.transformClosureIteratorBody(tryBody, outToFinally) + + ctx.curExcHandlingState = finallyIdx + ctx.addElseToExcept(exceptBody) + discard ctx.transformReturnsInTry(exceptBody) + discard ctx.transformClosureIteratorBody(exceptBody, outToFinally) + + ctx.curExcHandlingState = oldExcHandlingState + ctx.nearestFinally = oldNearestFinally + discard ctx.transformClosureIteratorBody(finallyBody, gotoOut) of nkGotoState, nkForStmt: internalError("closure iter " & $n.kind) @@ -460,7 +735,6 @@ proc tranformStateAssignments(ctx: var Ctx, n: PNode): PNode = discard of nkReturnStmt: - result = newNodeI(nkStmtList, n.info) result.add(ctx.newStateAssgn(-1)) result.add(n) @@ -483,6 +757,29 @@ proc skipStmtList(n: PNode): PNode = if result.len == 0: return emptyNode result = result[0] +proc skipEmptyStates(ctx: Ctx, stateIdx: int): int = + # Returns first non-empty state idx for `stateIdx`. Returns `stateIdx` if + # it is not empty + var maxJumps = ctx.states.len # maxJumps used only for debugging purposes. + var stateIdx = stateIdx + while true: + let label = stateIdx + if label == ctx.exitStateIdx: break + var newLabel = label + if label == -1: + newLabel = ctx.exitStateIdx + else: + let fs = ctx.states[label][1].skipStmtList() + if fs.kind == nkGotoState: + newLabel = fs[0].intVal.int + if label == newLabel: break + stateIdx = newLabel + dec maxJumps + if maxJumps == 0: + assert(false, "Internal error") + + result = ctx.states[stateIdx][0].intVal.int + proc skipThroughEmptyStates(ctx: var Ctx, n: PNode): PNode = result = n case n.kind @@ -490,31 +787,143 @@ proc skipThroughEmptyStates(ctx: var Ctx, n: PNode): PNode = nkSym, nkIdent, procDefs, nkTemplateDef: discard of nkGotoState: - var maxJumps = ctx.states.len # maxJumps used only for debugging purposes. result = copyTree(n) - while true: - let label = result[0].intVal.int - if label == ctx.exitStateIdx: break - var newLabel = label - if label == -1: - newLabel = ctx.exitStateIdx - else: - let fs = ctx.states[label][1].skipStmtList() - if fs.kind == nkGotoState: - newLabel = fs[0].intVal.int - if label == newLabel: break - result[0].intVal = newLabel - dec maxJumps - if maxJumps == 0: - assert(false, "Internal error") - - let label = result[0].intVal.int - result[0].intVal = ctx.states[label][0].intVal + result[0].intVal = ctx.skipEmptyStates(result[0].intVal.int) else: for i in 0 ..< n.len: n[i] = ctx.skipThroughEmptyStates(n[i]) +proc newArrayType(n: int, t: PType, owner: PSym): PType = + result = newType(tyArray, owner) + + let rng = newType(tyRange, owner) + rng.n = newNode(nkRange) + rng.n.add(newIntLit(0)) + rng.n.add(newIntLit(n)) + rng.rawAddSon(t) + + result.rawAddSon(rng) + result.rawAddSon(t) + +proc createExceptionTable(ctx: var Ctx): PNode = + result = newNode(nkBracket) + result.typ = newArrayType(ctx.exceptionTable.len, getSysType(tyInt16), ctx.fn) + + for i in ctx.exceptionTable: + let elem = newIntNode(nkIntLit, i) + elem.typ = getSysType(tyInt16) + result.add(elem) + +proc newCatchBody(ctx: var Ctx): PNode {.inline.} = + # Generates the code: + # :state = exceptionTable[:state] + # if :state == 0: raise + # :unrollFinally = :state > 0 + # if :state < 0: + # :state = -:state + # :curExc = getCurrentException() + + result = newNode(nkStmtList) + + # :state = exceptionTable[:state] + block: + + # exceptionTable[:state] + let getNextState = newNode(nkBracketExpr) + getNextState.add(ctx.createExceptionTable) + getNextState.add(ctx.newStateAccess()) + getNextState.typ = getSysType(tyInt) + + # :state = exceptionTable[:state] + result.add(ctx.newStateAssgn(getNextState)) + + # if :state == 0: raise + block: + let ifStmt = newNode(nkIfStmt) + let ifBranch = newNode(nkElifBranch) + let cond = newNode(nkCall) + cond.add(getSysMagic("==", mEqI).newSymNode) + cond.add(ctx.newStateAccess()) + cond.add(newIntTypeNode(nkIntLit, 0, getSysType(tyInt))) + cond.typ = getSysType(tyBool) + ifBranch.add(cond) + + let raiseStmt = newNode(nkRaiseStmt) + raiseStmt.add(emptyNode) + + ifBranch.add(raiseStmt) + ifStmt.add(ifBranch) + result.add(ifStmt) + + # :unrollFinally = :state > 0 + block: + let asgn = newNode(nkAsgn) + asgn.add(ctx.newUnrollFinallyAccess()) + + let cond = newNode(nkCall) + cond.add(getSysMagic("<", mLtI).newSymNode) + cond.add(newIntTypeNode(nkIntLit, 0, getSysType(tyInt))) + cond.add(ctx.newStateAccess()) + cond.typ = getSysType(tyBool) + asgn.add(cond) + result.add(asgn) + + # if :state < 0: :state = -:state + block: + let ifStmt = newNode(nkIfStmt) + let ifBranch = newNode(nkElifBranch) + let cond = newNode(nkCall) + cond.add(getSysMagic("<", mLtI).newSymNode) + cond.add(ctx.newStateAccess()) + cond.add(newIntTypeNode(nkIntLit, 0, getSysType(tyInt))) + cond.typ = getSysType(tyBool) + ifBranch.add(cond) + + let negateState = newNode(nkCall) + negateState.add(getSysMagic("-", mUnaryMinusI).newSymNode) + negateState.add(ctx.newStateAccess()) + negateState.typ = getSysType(tyInt) + + ifBranch.add(ctx.newStateAssgn(negateState)) + ifStmt.add(ifBranch) + result.add(ifStmt) + + # :curExc = getCurrentException() + block: + let getCurExc = callCodegenProc("getCurrentException", emptyNode) + let asgn = newNode(nkAsgn) + asgn.add(ctx.newCurExcAccess()) + asgn.add(getCurExc) + result.add(asgn) + +proc wrapIntoTryExcept(ctx: var Ctx, n: PNode): PNode = + result = newNode(nkTryStmt) + + let tryBody = newNode(nkStmtList) + + let setupExc = newNode(nkCall) + setupExc.add(newSymNode(getCompilerProc("closureIterSetupExc"))) + + tryBody.add(setupExc) + + tryBody.add(n) + result.add(tryBody) + + let catchNode = newNode(nkExceptBranch) + result.add(catchNode) + + let catchBody = newNode(nkStmtList) + catchBody.add(ctx.newCatchBody()) + catchNode.add(catchBody) + + setupExc.add(ctx.newCurExcAccess()) + proc wrapIntoStateLoop(ctx: var Ctx, n: PNode): PNode = + # while true: + # block :stateLoop: + # gotoState :state + # body # Might get wrapped in try-except + result = newNode(nkWhileStmt) result.add(newSymNode(getSysSym("true"))) @@ -532,19 +941,19 @@ proc wrapIntoStateLoop(ctx: var Ctx, n: PNode): PNode = let blockStmt = newNodeI(nkBlockStmt, n.info) blockStmt.add(newSymNode(ctx.stateLoopLabel)) - let blockBody = newNodeI(nkStmtList, n.info) - blockStmt.add(blockBody) + var blockBody = newNodeI(nkStmtList, n.info) let gs = newNodeI(nkGotoState, n.info) - if ctx.stateVarSym.isNil: - gs.add(rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), getStateField(ctx.fn), n.info)) - else: - gs.add(newSymNode(ctx.stateVarSym)) - + gs.add(ctx.newStateAccess()) gs.add(newIntLit(ctx.states.len - 1)) + blockBody.add(gs) blockBody.add(n) - # gs.add(rawIndirectAccess(newSymNode(ctx.fn.getHiddenParam), getStateField(ctx.fn), n.info)) + + if ctx.hasExceptions: + blockBody = ctx.wrapIntoTryExcept(blockBody) + + blockStmt.add(blockBody) loopBody.add(blockStmt) @@ -558,7 +967,7 @@ proc deleteEmptyStates(ctx: var Ctx) = var iValid = 0 for i, s in ctx.states: let body = s[1].skipStmtList() - if body.kind == nkGotoState and i != ctx.states.len - 1: + if body.kind == nkGotoState and i != ctx.states.len - 1 and i != 0: # This is an empty state. Mark with -1. s[0].intVal = -1 else: @@ -567,14 +976,20 @@ proc deleteEmptyStates(ctx: var Ctx) = for i, s in ctx.states: let body = s[1].skipStmtList() - if body.kind != nkGotoState: + if body.kind != nkGotoState or i == 0: discard ctx.skipThroughEmptyStates(s) + let excHandlState = ctx.exceptionTable[i] + if excHandlState < 0: + ctx.exceptionTable[i] = -ctx.skipEmptyStates(-excHandlState) + elif excHandlState != 0: + ctx.exceptionTable[i] = ctx.skipEmptyStates(excHandlState) var i = 0 while i < ctx.states.len - 1: let fs = ctx.states[i][1].skipStmtList() - if fs.kind == nkGotoState: + if fs.kind == nkGotoState and i != 0: ctx.states.delete(i) + ctx.exceptionTable.delete(i) else: inc i @@ -591,6 +1006,7 @@ proc transformClosureIterator*(fn: PSym, n: PNode): PNode = ctx.states = @[] ctx.stateLoopLabel = newSym(skLabel, getIdent(":stateLoop"), fn, fn.info) + ctx.exceptionTable = @[] let n = n.toStmtList discard ctx.newState(n, nil) @@ -613,19 +1029,11 @@ proc transformClosureIterator*(fn: PSym, n: PNode): PNode = result.add(body) result = ctx.tranformStateAssignments(result) - - # Add excpetion handling - var hasExceptions = false - if hasExceptions: - discard # TODO: - # result = wrapIntoTryCatch(result) - - # while true: - # block :stateLoop: - # gotoState - # body result = ctx.wrapIntoStateLoop(result) - # echo "TRANSFORM TO STATES2: " - # debug(result) + # echo "TRANSFORM TO STATES: " # echo renderTree(result) + + # echo "exception table:" + # for i, e in ctx.exceptionTable: + # echo i, " -> ", e diff --git a/compiler/semexprs.nim b/compiler/semexprs.nim index 1ef284a77..79010bfde 100644 --- a/compiler/semexprs.nim +++ b/compiler/semexprs.nim @@ -1544,7 +1544,7 @@ proc semYield(c: PContext, n: PNode): PNode = checkSonsLen(n, 1) if c.p.owner == nil or c.p.owner.kind != skIterator: localError(n.info, errYieldNotAllowedHere) - elif c.p.inTryStmt > 0 and c.p.owner.typ.callConv != ccInline: + elif oldIterTransf in c.features and c.p.inTryStmt > 0 and c.p.owner.typ.callConv != ccInline: localError(n.info, errYieldNotAllowedInTryStmt) elif n.sons[0].kind != nkEmpty: n.sons[0] = semExprWithType(c, n.sons[0]) # check for type compatibility: diff --git a/lib/system/embedded.nim b/lib/system/embedded.nim index 46e84e056..4d453fcca 100644 --- a/lib/system/embedded.nim +++ b/lib/system/embedded.nim @@ -41,3 +41,6 @@ proc reraiseException() {.compilerRtl.} = proc writeStackTrace() = discard proc setControlCHook(hook: proc () {.noconv.}) = discard + +proc closureIterSetupExc(e: ref Exception) {.compilerproc, inline.} = + sysFatal(ReraiseError, "exception handling is not available") diff --git a/lib/system/excpt.nim b/lib/system/excpt.nim index fb38948f7..dabfe010e 100644 --- a/lib/system/excpt.nim +++ b/lib/system/excpt.nim @@ -131,6 +131,10 @@ proc popCurrentExceptionEx(id: uint) {.compilerRtl.} = quitOrDebug() prev.up = cur.up +proc closureIterSetupExc(e: ref Exception) {.compilerproc, inline.} = + if not e.isNil: + currException = e + # some platforms have native support for stack traces: const nativeStackTraceSupported* = (defined(macosx) or defined(linux)) and diff --git a/tests/async/tasynctry2.nim b/tests/async/tasynctry2.nim index 444a058be..f82b6cfe0 100644 --- a/tests/async/tasynctry2.nim +++ b/tests/async/tasynctry2.nim @@ -1,10 +1,12 @@ discard """ file: "tasynctry2.nim" errormsg: "\'yield\' cannot be used within \'try\' in a non-inlined iterator" - line: 15 + line: 17 """ import asyncdispatch +{.experimental: "oldIterTransf".} + proc foo(): Future[bool] {.async.} = discard proc test5(): Future[int] {.async.} = diff --git a/tests/iter/tyieldintry.nim b/tests/iter/tyieldintry.nim new file mode 100644 index 000000000..9cb199c5b --- /dev/null +++ b/tests/iter/tyieldintry.nim @@ -0,0 +1,201 @@ +discard """ +targets: "c cpp" +output: "ok" +""" +var closureIterResult = newSeq[int]() + +proc checkpoint(arg: int) = + closureIterResult.add(arg) + +type + TestException = object of Exception + AnotherException = object of Exception + +proc testClosureIterAux(it: iterator(): int, exceptionExpected: bool, expectedResults: varargs[int]) = + closureIterResult.setLen(0) + + var exceptionCaught = false + + try: + for i in it(): + closureIterResult.add(i) + except TestException: + exceptionCaught = true + + if closureIterResult != @expectedResults or exceptionCaught != exceptionExpected: + if closureIterResult != @expectedResults: + echo "Expected: ", @expectedResults + echo "Actual: ", closureIterResult + if exceptionCaught != exceptionExpected: + echo "Expected exception: ", exceptionExpected + echo "Got exception: ", exceptionCaught + doAssert(false) + +proc test(it: iterator(): int, expectedResults: varargs[int]) = + testClosureIterAux(it, false, expectedResults) + +proc testExc(it: iterator(): int, expectedResults: varargs[int]) = + testClosureIterAux(it, true, expectedResults) + +proc raiseException() = + raise newException(TestException, "Test exception!") + +block: + iterator it(): int {.closure.} = + var i = 5 + while i != 0: + yield i + if i == 3: + yield 123 + dec i + + test(it, 5, 4, 3, 123, 2, 1) + +block: + iterator it(): int {.closure.} = + yield 0 + try: + checkpoint(1) + raiseException() + except TestException: + checkpoint(2) + yield 3 + checkpoint(4) + finally: + checkpoint(5) + + checkpoint(6) + + test(it, 0, 1, 2, 3, 4, 5, 6) + +block: + iterator it(): int {.closure.} = + yield 0 + try: + yield 1 + checkpoint(2) + finally: + checkpoint(3) + yield 4 + checkpoint(5) + yield 6 + + test(it, 0, 1, 2, 3, 4, 5, 6) + +block: + iterator it(): int {.closure.} = + yield 0 + try: + yield 1 + raiseException() + yield 2 + finally: + checkpoint(3) + yield 4 + checkpoint(5) + yield 6 + + testExc(it, 0, 1, 3, 4, 5, 6) + +block: + iterator it(): int {.closure.} = + try: + try: + raiseException() + except AnotherException: + yield 123 + finally: + checkpoint(3) + finally: + checkpoint(4) + + testExc(it, 3, 4) + +block: + iterator it(): int {.closure.} = + try: + yield 1 + raiseException() + except AnotherException: + checkpoint(123) + finally: + checkpoint(2) + checkpoint(3) + + testExc(it, 1, 2) + +block: + iterator it(): int {.closure.} = + try: + yield 0 + try: + yield 1 + try: + yield 2 + raiseException() + except AnotherException: + yield 123 + finally: + yield 3 + except AnotherException: + yield 124 + finally: + yield 4 + checkpoint(1234) + except: + yield 5 + checkpoint(6) + finally: + checkpoint(7) + yield 8 + checkpoint(9) + + test(it, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + +block: + iterator it(): int {.closure.} = + try: + yield 0 + return 2 + finally: + checkpoint(1) + checkpoint(123) + + test(it, 0, 1) + +block: + iterator it(): int {.closure.} = + try: + try: + yield 0 + raiseException() + finally: + checkpoint(1) + except TestException: + yield 2 + return + finally: + yield 3 + + checkpoint(123) + + test(it, 0, 1, 2, 3) + +block: + iterator it(): int {.closure.} = + try: + try: + yield 0 + raiseException() + finally: + return # Return in finally should stop exception propagation + except AnotherException: + yield 2 + return + finally: + yield 3 + checkpoint(123) + + test(it, 0, 3) + +echo "ok" |