summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorYuriy Glukhov <yuriy.glukhov@gmail.com>2018-05-08 01:25:08 +0300
committerYuriy Glukhov <yuriy.glukhov@gmail.com>2018-05-09 22:25:28 +0300
commit14ca79fe1f1fafb8e3aff2e4c27bcb94c0595792 (patch)
tree80eba539abe709a759fc2853cbf080c0a44c6c78
parentac86b8ce615cbd55074dfd27f42ed0368d84b1fd (diff)
downloadNim-14ca79fe1f1fafb8e3aff2e4c27bcb94c0595792.tar.gz
More elaborate nkStmtListExpr lowering
-rw-r--r--compiler/closureiters.nim404
-rw-r--r--tests/iter/tyieldintry.nim459
2 files changed, 687 insertions, 176 deletions
diff --git a/compiler/closureiters.nim b/compiler/closureiters.nim
index a8e7e0274..a30b4e10e 100644
--- a/compiler/closureiters.nim
+++ b/compiler/closureiters.nim
@@ -149,7 +149,6 @@ type
     exitStateIdx: int # index of the last state
     tempVarId: int # unique name counter
     tempVars: PNode # Temp var decls, nkVarSection
-    loweredStmtListExpr: PNode # Temporary used for nkStmtListExpr lowering
     exceptionTable: seq[int] # For state `i` jump to state `exceptionTable[i]` if exception is raised
     hasExceptions: bool # Does closure have yield in try?
     curExcHandlingState: int # Negative for except, positive for finally
@@ -177,6 +176,7 @@ proc newStateAssgn(ctx: var Ctx, stateNo: int = -2): PNode =
 proc newEnvVar(ctx: var Ctx, name: string, typ: PType): PSym =
   result = newSym(skVar, getIdent(name), ctx.fn, ctx.fn.info)
   result.typ = typ
+  assert(not typ.isNil)
 
   if not ctx.stateVarSym.isNil:
     # We haven't gone through labmda lifting yet, so just create a local var,
@@ -197,7 +197,6 @@ proc newEnvVarAccess(ctx: Ctx, s: PSym): PNode =
 
 proc newTmpResultAccess(ctx: var Ctx): PNode =
   if ctx.tmpResultSym.isNil:
-    debug(ctx.fn.typ)
     ctx.tmpResultSym = ctx.newEnvVar(":tmpResult", ctx.fn.typ[0])
   ctx.newEnvVarAccess(ctx.tmpResultSym)
 
@@ -245,9 +244,8 @@ proc addGotoOut(n: PNode, gotoOut: PNode): PNode =
   if result.len != 0 and result.sons[^1].kind != nkGotoState:
     result.add(gotoOut)
 
-proc newTempVarAccess(ctx: var Ctx, typ: PType, i: TLineInfo): PNode =
-  let s = ctx.newEnvVar(":tmpSlLower" & $ctx.tempVarId, typ)
-  result = ctx.newEnvVarAccess(s)
+proc newTempVar(ctx: var Ctx, typ: PType): PSym =
+  result = ctx.newEnvVar(":tmpSlLower" & $ctx.tempVarId, typ)
   inc ctx.tempVarId
 
 proc hasYields(n: PNode): bool =
@@ -390,35 +388,382 @@ proc hasYieldsInExpressions(n: PNode): bool =
       nkSym, nkIdent, procDefs, nkTemplateDef:
     discard
   of nkStmtListExpr:
-    result = n.hasYields
-  of nkStmtList, nkWhileStmt, nkCaseStmt, nkIfStmt:
-    discard
+    if isEmptyType(n.typ):
+      for c in n:
+        if c.hasYieldsInExpressions:
+          return true
+    else:
+      result = n.hasYields
   else:
     for c in n:
       if c.hasYieldsInExpressions:
         return true
 
-proc lowerStmtListExpr(ctx: var Ctx, n: PNode): PNode =
+proc exprToStmtList(n: PNode): tuple[s, res: PNode] =
+  assert(n.kind == nkStmtListExpr)
+
+  var parent = n
+  var lastSon = n[^1]
+
+  while lastSon.kind == nkStmtListExpr:
+    parent = lastSon
+    lastSon = lastSon[^1]
+
+  result.s = newNodeI(nkStmtList, n.info)
+  result.s.sons = parent.sons
+  result.s.sons.setLen(result.s.sons.len - 1) # delete last son
+  result.res = lastSon
+
+proc newEnvVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode =
+  result = newNode(nkFastAsgn)
+  result.add(ctx.newEnvVarAccess(s))
+  result.add(v)
+
+proc addExprAssgn(ctx: Ctx, output, input: PNode, sym: PSym) =
+  if input.kind == nkStmtListExpr:
+    let (st, res) = exprToStmtList(input)
+    output.add(st)
+    output.add(ctx.newEnvVarAsgn(sym, res))
+  else:
+    output.add(ctx.newEnvVarAsgn(sym, input))
+
+proc convertExprBodyToAsgn(ctx: Ctx, exprBody: PNode, res: PSym): PNode =
+  result = newNode(nkStmtList)
+  ctx.addExprAssgn(result, exprBody, res)
+
+proc newNotCall(e: PNode): PNode =
+  result = newNode(nkCall)
+  result.add(newSymNode(getSysMagic("not", mNot)))
+  result.add(e)
+  result.typ = getSysType(tyBool)
+
+proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
   result = n
   case n.kind
   of nkCharLit..nkUInt64Lit, nkFloatLit..nkFloat128Lit, nkStrLit..nkTripleStrLit,
       nkSym, nkIdent, procDefs, nkTemplateDef:
     discard
-  of nkStmtListExpr:
-    if n.hasYields:
-      for i in 0 .. n.len - 2:
-        ctx.loweredStmtListExpr.add(n[i])
 
-      let tv = ctx.newTempVarAccess(n.typ, n[^1].info)
-      let asgn = newNode(nkAsgn)
-      asgn.add(tv)
-      asgn.add(n[^1])
-      ctx.loweredStmtListExpr.add(asgn)
-      result = tv
+  of nkYieldStmt:
+    var ns = false
+    for i in 0 ..< n.len:
+      n[i] = ctx.lowerStmtListExprs(n[i], ns)
+
+    if ns:
+      assert(n[0].kind == nkStmtListExpr)
+      result = newNodeI(nkStmtList, n.info)
+      let (st, ex) = exprToStmtList(n[0])
+      result.add(st)
+      n[0] = ex
+      result.add(n)
+
+    needsSplit = true
+
+  of nkPar, nkObjConstr, nkTupleConstr, nkBracket, nkArgList:
+    var ns = false
+    for i in 0 ..< n.len:
+      n[i] = ctx.lowerStmtListExprs(n[i], ns)
+
+    if ns:
+      needsSplit = true
+
+      result = newNodeI(nkStmtListExpr, n.info)
+      if n.typ.isNil: internalError("lowerStmtListExprs: constr typ.isNil")
+      result.typ = n.typ
+
+      for i in 0 ..< n.len:
+        if n[i].kind == nkStmtListExpr:
+          let (st, ex) = exprToStmtList(n[i])
+          result.add(st)
+          n[i] = ex
+      result.add(n)
+
+  of nkIfStmt, nkIfExpr:
+    var ns = false
+    for i in 0 ..< n.len:
+      n[i] = ctx.lowerStmtListExprs(n[i], ns)
+
+    if ns:
+      needsSplit = true
+      var tmp: PSym
+      var s: PNode
+      let isExpr = not isEmptyType(n.typ)
+      if isExpr:
+        tmp = ctx.newTempVar(n.typ)
+        result = newNode(nkStmtListExpr)
+        result.typ = n.typ
+      else:
+        result = newNode(nkStmtList)
+
+      var curS = result
+
+      for branch in n:
+        case branch.kind
+        of nkElseExpr, nkElse:
+          if isExpr:
+            var newBranch = newNodeI(nkElse, branch.info)
+            let branchBody = newNode(nkStmtList)
+            ctx.addExprAssgn(branchBody, branch[0], tmp)
+            newBranch.add(branchBody)
+            curS.add(newBranch)
+          else:
+            curS.add(branch)
+
+        of nkElifExpr, nkElifBranch:
+          var newBranch: PNode
+          if branch[0].kind == nkStmtListExpr:
+            let elseBody = newNode(nkStmtList)
+
+            let (st, res) = exprToStmtList(branch[0])
+            elseBody.add(st)
+
+            newBranch = newNodeI(nkElifBranch, branch.info)
+            newBranch.add(res)
+            newBranch.add(branch[1])
+
+            let newIf = newNodeI(nkIfStmt, branch.info)
+            newIf.add(newBranch)
+            elseBody.add(newIf)
+            if curS.kind == nkIfStmt:
+              let newElse = newNodeI(nkElse, branch.info)
+              newElse.add(elseBody)
+              curS.add(newElse)
+            else:
+              curS.add(elseBody)
+            curS = newIf
+          else:
+            newBranch = branch
+            if curS.kind == nkIfStmt:
+              curS.add(newBranch)
+            else:
+              let newIf = newNodeI(nkIfStmt, branch.info)
+              newIf.add(newBranch)
+              curS.add(newIf)
+              curS = newIf
+
+          if isExpr:
+            let branchBody = newNode(nkStmtList)
+            ctx.addExprAssgn(branchBody, branch[1], tmp)
+            newBranch[1] = branchBody
+
+        else:
+          internalError("lowerStmtListExpr(nkIf): " & $branch.kind)
+
+      if isExpr: result.add(ctx.newEnvVarAccess(tmp))
+
+  of nkTryStmt:
+    var ns = false
+    for i in 0 ..< n.len:
+      n[i] = ctx.lowerStmtListExprs(n[i], ns)
+
+    if ns:
+      needsSplit = true
+      let isExpr = not isEmptyType(n.typ)
+
+      if isExpr:
+        result = newNodeI(nkStmtListExpr, n.info)
+        result.typ = n.typ
+        let tmp = ctx.newTempVar(n.typ)
+
+        n[0] = ctx.convertExprBodyToAsgn(n[0], tmp)
+        for i in 1 ..< n.len:
+          let branch = n[i]
+          case branch.kind
+          of nkExceptBranch:
+            if branch[0].kind == nkType:
+              branch[1] = ctx.convertExprBodyToAsgn(branch[1], tmp)
+            else:
+              branch[0] = ctx.convertExprBodyToAsgn(branch[0], tmp)
+          of nkFinally:
+            discard
+          else:
+            internalError("lowerStmtListExpr(nkTryStmt): " & $branch.kind)
+        result.add(n)
+        result.add(ctx.newEnvVarAccess(tmp))
+
+  of nkCaseStmt:
+    var ns = false
+    for i in 0 ..< n.len:
+      n[i] = ctx.lowerStmtListExprs(n[i], ns)
+
+    if ns:
+      needsSplit = true
+
+      let isExpr = not isEmptyType(n.typ)
+
+      if isExpr:
+        let tmp = ctx.newTempVar(n.typ)
+        result = newNodeI(nkStmtListExpr, n.info)
+        result.typ = n.typ
+
+        if n[0].kind == nkStmtListExpr:
+          let (st, ex) = exprToStmtList(n[0])
+          result.add(st)
+          n[0] = ex
+
+        for i in 1 ..< n.len:
+          let branch = n[i]
+          case branch.kind
+          of nkOfBranch:
+            branch[1] = ctx.convertExprBodyToAsgn(branch[1], tmp)
+          of nkElse:
+            branch[0] = ctx.convertExprBodyToAsgn(branch[0], tmp)
+          else:
+            internalError("lowerStmtListExpr(nkCaseStmt): " & $branch.kind)
+        result.add(n)
+        result.add(ctx.newEnvVarAccess(tmp))
+
+  of nkCallKinds:
+    var ns = false
+    for i in 0 ..< n.len:
+      n[i] = ctx.lowerStmtListExprs(n[i], ns)
+
+    if ns:
+      needsSplit = true
+      let isExpr = not isEmptyType(n.typ)
+
+      if isExpr:
+        result = newNodeI(nkStmtListExpr, n.info)
+        result.typ = n.typ
+      else:
+        result = newNode(nkStmtList, n.info)
+
+      if n[0].kind == nkSym and n[0].sym.magic in {mAnd, mOr}: # `and`/`or` short cirquiting
+        var cond = n[1]
+        if cond.kind == nkStmtListExpr:
+          let (st, ex) = exprToStmtList(cond)
+          result.add(st)
+          cond = ex
+
+        let tmp = ctx.newTempVar(cond.typ)
+        result.add(ctx.newEnvVarAsgn(tmp, cond))
+
+        let ifNode = newNode(nkIfStmt)
+        let ifBranch = newNode(nkElifBranch)
+
+        var check = ctx.newEnvVarAccess(tmp)
+        if n[0].sym.magic == mOr:
+          check = newNotCall(check)
+        ifBranch.add(check)
+
+        cond = n[2]
+        let ifBody = newNode(nkStmtList)
+        if cond.kind == nkStmtListExpr:
+          let (st, ex) = exprToStmtList(cond)
+          ifBody.add(st)
+          cond = ex
+        ifBody.add(ctx.newEnvVarAsgn(tmp, cond))
+        ifBranch.add(ifBody)
+        ifNode.add(ifBranch)
+        result.add(ifNode)
+        result.add(ctx.newEnvVarAccess(tmp))
+      else:
+        for i in 0 ..< n.len:
+          if n[i].kind == nkStmtListExpr:
+            let (st, ex) = exprToStmtList(n[i])
+            result.add(st)
+            n[i] = ex
+
+          if n[i].kind in nkCallKinds: # XXX: This should better be some sort of side effect tracking
+            let tmp = ctx.newTempVar(n[i].typ)
+            result.add(ctx.newEnvVarAsgn(tmp, n[i]))
+            n[i] = ctx.newEnvVarAccess(tmp)
+
+        result.add(n)
+
+  of nkVarSection, nkLetSection:
+    result = newNodeI(nkStmtList, n.info)
+    for c in n:
+      let varSect = newNodeI(n.kind, n.info)
+      varSect.add(c)
+      var ns = false
+      c[^1] = ctx.lowerStmtListExprs(c[^1], ns)
+      if ns:
+        needsSplit = true
+        assert(c[^1].kind == nkStmtListExpr)
+        let (st, ex) = exprToStmtList(c[^1])
+        result.add(st)
+        c[^1] = ex
+      result.add(varSect)
+
+  of nkDiscardStmt, nkReturnStmt, nkRaiseStmt:
+    var ns = false
+    for i in 0 ..< n.len:
+      n[i] = ctx.lowerStmtListExprs(n[i], ns)
 
+    if ns:
+      needsSplit = true
+      result = newNodeI(nkStmtList, n.info)
+      let (st, ex) = exprToStmtList(n[0])
+      result.add(st)
+      n[0] = ex
+      result.add(n)
+
+  of nkCast:
+    var ns = false
+    for i in 0 ..< n.len:
+      n[i] = ctx.lowerStmtListExprs(n[i], ns)
+
+    if ns:
+      needsSplit = true
+      result = newNodeI(nkStmtListExpr, n.info)
+      result.typ = n.typ
+      let (st, ex) = exprToStmtList(n[1])
+      result.add(st)
+      n[1] = ex
+      result.add(n)
+
+  of nkAsgn, nkFastAsgn:
+    var ns = false
+    for i in 0 ..< n.len:
+      n[i] = ctx.lowerStmtListExprs(n[i], ns)
+
+    if ns:
+      needsSplit = true
+      result = newNodeI(nkStmtList, n.info)
+      if n[0].kind == nkStmtListExpr:
+        let (st, ex) = exprToStmtList(n[0])
+        result.add(st)
+        n[0] = ex
+
+      if n[1].kind == nkStmtListExpr:
+        let (st, ex) = exprToStmtList(n[1])
+        result.add(st)
+        n[1] = ex
+
+      result.add(n)
+
+  of nkWhileStmt:
+    var ns = false
+
+    var condNeedsSplit = false
+    n[0] = ctx.lowerStmtListExprs(n[0], condNeedsSplit)
+    var bodyNeedsSplit = false
+    n[1] = ctx.lowerStmtListExprs(n[1], bodyNeedsSplit)
+
+    if condNeedsSplit or bodyNeedsSplit:
+      needsSplit = true
+
+      if condNeedsSplit:
+        let newBody = newNode(nkStmtList)
+
+        let (st, ex) = exprToStmtList(n[0])
+        newBody.add(st)
+        let check = newNode(nkIfStmt)
+        let branch = newNode(nkElifBranch)
+        branch.add(newNotCall(ex))
+        let brk = newNode(nkBreakStmt)
+        brk.add(emptyNode)
+        branch.add(brk)
+        check.add(branch)
+        newBody.add(check)
+        newBody.add(n[1])
+
+        n[0] = newSymNode(getSysSym("true"))
+        n[1] = newBody
   else:
     for i in 0 ..< n.len:
-      n[i] = ctx.lowerStmtListExpr(n[i])
+      n[i] = ctx.lowerStmtListExprs(n[i], needsSplit)
 
 proc newEndFinallyNode(ctx: var Ctx): PNode =
   # Generate the following code:
@@ -448,7 +793,7 @@ proc newEndFinallyNode(ctx: var Ctx): PNode =
   branch.add(cmp)
 
   let retStmt = newNode(nkReturnStmt)
-  let asgn = newNode(nkAsgn)
+  let asgn = newNode(nkFastAsgn)
   addSon(asgn, newSymNode(getClosureIterResult(ctx.fn)))
   addSon(asgn, ctx.newTmpResultAccess())
   retStmt.add(asgn)
@@ -482,7 +827,7 @@ proc transformReturnsInTry(ctx: var Ctx, n: PNode): PNode =
       asgn.add(newIntTypeNode(nkIntLit, 1, getSysType(tyBool)))
       result.add(asgn)
 
-    if n[0].kind != nkEmpty: # TODO: And not void!
+    if n[0].kind != nkEmpty:
       let asgnTmpResult = newNodeI(nkAsgn, n.info)
       asgnTmpResult.add(ctx.newTmpResultAccess())
       asgnTmpResult.add(n[0])
@@ -508,17 +853,15 @@ proc transformClosureIteratorBody(ctx: var Ctx, n: PNode, gotoOut: PNode): PNode
         nkSym, nkIdent, procDefs, nkTemplateDef:
       discard
 
-    of nkStmtList:
+    of nkStmtList, nkStmtListExpr:
+      assert(isEmptyType(n.typ), "nkStmtListExpr not lowered")
+
       result = addGotoOut(result, gotoOut)
       for i in 0 ..< n.len:
         if n[i].hasYieldsInExpressions:
           # Lower nkStmtListExpr nodes inside `n[i]` first
-          assert(ctx.loweredStmtListExpr.isNil)
-          ctx.loweredStmtListExpr = newNodeI(nkStmtList, n.info)
-          n[i] = ctx.lowerStmtListExpr(n[i])
-          ctx.loweredStmtListExpr.add(n[i])
-          n[i] = ctx.loweredStmtListExpr
-          ctx.loweredStmtListExpr = nil
+          var ns = false
+          n[i] = ctx.lowerStmtListExprs(n[i], ns)
 
         if n[i].hasYields:
           # Create a new split
@@ -534,9 +877,6 @@ proc transformClosureIteratorBody(ctx: var Ctx, n: PNode, gotoOut: PNode): PNode
           discard ctx.transformClosureIteratorBody(s, gotoOut)
           break
 
-    of nkStmtListExpr:
-      assert(false, "nkStmtListExpr not lowered")
-
     of nkYieldStmt:
       result = newNodeI(nkStmtList, n.info)
       result.add(n)
diff --git a/tests/iter/tyieldintry.nim b/tests/iter/tyieldintry.nim
index 9cb199c5b..31ec65a83 100644
--- a/tests/iter/tyieldintry.nim
+++ b/tests/iter/tyieldintry.nim
@@ -5,197 +5,368 @@ output: "ok"
 var closureIterResult = newSeq[int]()
 
 proc checkpoint(arg: int) =
-    closureIterResult.add(arg)
+  closureIterResult.add(arg)
 
 type
-    TestException = object of Exception
-    AnotherException = object of Exception
+  TestException = object of Exception
+  AnotherException = object of Exception
 
 proc testClosureIterAux(it: iterator(): int, exceptionExpected: bool, expectedResults: varargs[int]) =
-    closureIterResult.setLen(0)
+  closureIterResult.setLen(0)
 
-    var exceptionCaught = false
+  var exceptionCaught = false
 
-    try:
-        for i in it():
-            closureIterResult.add(i)
-    except TestException:
-        exceptionCaught = true
+  try:
+    for i in it():
+      closureIterResult.add(i)
+  except TestException:
+    exceptionCaught = true
 
-    if closureIterResult != @expectedResults or exceptionCaught != exceptionExpected:
-        if closureIterResult != @expectedResults:
-            echo "Expected: ", @expectedResults
-            echo "Actual: ", closureIterResult
-        if exceptionCaught != exceptionExpected:
-            echo "Expected exception: ", exceptionExpected
-            echo "Got exception: ", exceptionCaught
-        doAssert(false)
+  if closureIterResult != @expectedResults or exceptionCaught != exceptionExpected:
+    if closureIterResult != @expectedResults:
+      echo "Expected: ", @expectedResults
+      echo "Actual: ", closureIterResult
+    if exceptionCaught != exceptionExpected:
+      echo "Expected exception: ", exceptionExpected
+      echo "Got exception: ", exceptionCaught
+    doAssert(false)
 
 proc test(it: iterator(): int, expectedResults: varargs[int]) =
-    testClosureIterAux(it, false, expectedResults)
+  testClosureIterAux(it, false, expectedResults)
 
 proc testExc(it: iterator(): int, expectedResults: varargs[int]) =
-    testClosureIterAux(it, true, expectedResults)
+  testClosureIterAux(it, true, expectedResults)
 
 proc raiseException() =
-    raise newException(TestException, "Test exception!")
+  raise newException(TestException, "Test exception!")
 
 block:
-    iterator it(): int {.closure.} =
-        var i = 5
-        while i != 0:
-            yield i
-            if i == 3:
-                yield 123
-            dec i
+  iterator it(): int {.closure.} =
+    var i = 5
+    while i != 0:
+      yield i
+      if i == 3:
+        yield 123
+      dec i
 
-    test(it, 5, 4, 3, 123, 2, 1)
+  test(it, 5, 4, 3, 123, 2, 1)
 
 block:
-    iterator it(): int {.closure.} =
-        yield 0
-        try:
-            checkpoint(1)
-            raiseException()
-        except TestException:
-            checkpoint(2)
-            yield 3
-            checkpoint(4)
-        finally:
-            checkpoint(5)
+  iterator it(): int {.closure.} =
+    yield 0
+    try:
+      checkpoint(1)
+      raiseException()
+    except TestException:
+      checkpoint(2)
+      yield 3
+      checkpoint(4)
+    finally:
+      checkpoint(5)
 
-        checkpoint(6)
+    checkpoint(6)
 
-    test(it, 0, 1, 2, 3, 4, 5, 6)
+  test(it, 0, 1, 2, 3, 4, 5, 6)
 
 block:
-    iterator it(): int {.closure.} =
-        yield 0
-        try:
-            yield 1
-            checkpoint(2)
-        finally:
-            checkpoint(3)
-            yield 4
-            checkpoint(5)
-            yield 6
+  iterator it(): int {.closure.} =
+    yield 0
+    try:
+      yield 1
+      checkpoint(2)
+    finally:
+      checkpoint(3)
+      yield 4
+      checkpoint(5)
+      yield 6
 
-    test(it, 0, 1, 2, 3, 4, 5, 6)
+  test(it, 0, 1, 2, 3, 4, 5, 6)
 
 block:
-    iterator it(): int {.closure.} =
-        yield 0
-        try:
-            yield 1
-            raiseException()
-            yield 2
-        finally:
-            checkpoint(3)
-            yield 4
-            checkpoint(5)
-            yield 6
+  iterator it(): int {.closure.} =
+    yield 0
+    try:
+      yield 1
+      raiseException()
+      yield 2
+    finally:
+      checkpoint(3)
+      yield 4
+      checkpoint(5)
+      yield 6
 
-    testExc(it, 0, 1, 3, 4, 5, 6)
+  testExc(it, 0, 1, 3, 4, 5, 6)
 
 block:
-    iterator it(): int {.closure.} =
-        try:
-            try:
-                raiseException()
-            except AnotherException:
-                yield 123
-            finally:
-                checkpoint(3)
-        finally:
-            checkpoint(4)
+  iterator it(): int {.closure.} =
+    try:
+      try:
+        raiseException()
+      except AnotherException:
+        yield 123
+      finally:
+        checkpoint(3)
+    finally:
+      checkpoint(4)
+
+  testExc(it, 3, 4)
+
+block:
+  iterator it(): int {.closure.} =
+    try:
+      yield 1
+      raiseException()
+    except AnotherException:
+      checkpoint(123)
+    finally:
+      checkpoint(2)
+    checkpoint(3)
 
-    testExc(it, 3, 4)
+  testExc(it, 1, 2)
 
 block:
-    iterator it(): int {.closure.} =
+  iterator it(): int {.closure.} =
+    try:
+      yield 0
+      try:
+        yield 1
         try:
-            yield 1
-            raiseException()
+          yield 2
+          raiseException()
         except AnotherException:
-            checkpoint(123)
+          yield 123
         finally:
-            checkpoint(2)
-        checkpoint(3)
+          yield 3
+      except AnotherException:
+        yield 124
+      finally:
+        yield 4
+      checkpoint(1234)
+    except:
+      yield 5
+      checkpoint(6)
+    finally:
+      checkpoint(7)
+      yield 8
+    checkpoint(9)
+
+  test(it, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
 
-    testExc(it, 1, 2)
+block:
+  iterator it(): int {.closure.} =
+    try:
+      yield 0
+      return 2
+    finally:
+      checkpoint(1)
+    checkpoint(123)
+
+  test(it, 0, 1)
 
 block:
-    iterator it(): int {.closure.} =
-        try:
-            yield 0
-            try:
-                yield 1
-                try:
-                    yield 2
-                    raiseException()
-                except AnotherException:
-                    yield 123
-                finally:
-                    yield 3
-            except AnotherException:
-                yield 124
-            finally:
-                yield 4
-            checkpoint(1234)
-        except:
-            yield 5
-            checkpoint(6)
-        finally:
-            checkpoint(7)
-            yield 8
-        checkpoint(9)
+  iterator it(): int {.closure.} =
+    try:
+      try:
+        yield 0
+        raiseException()
+      finally:
+        checkpoint(1)
+    except TestException:
+      yield 2
+      return
+    finally:
+      yield 3
 
-    test(it, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
+    checkpoint(123)
+
+  test(it, 0, 1, 2, 3)
 
 block:
-    iterator it(): int {.closure.} =
-        try:
-            yield 0
-            return 2
-        finally:
-            checkpoint(1)
-        checkpoint(123)
+  iterator it(): int {.closure.} =
+    try:
+      try:
+        yield 0
+        raiseException()
+      finally:
+        return # Return in finally should stop exception propagation
+    except AnotherException:
+      yield 2
+      return
+    finally:
+      yield 3
+    checkpoint(123)
+
+  test(it, 0, 3)
+
+block: # Yield in yield
+  iterator it(): int {.closure.} =
+    template foo(): int =
+      yield 1
+      2
+
+    for i in 0 .. 2:
+      checkpoint(0)
+      yield foo()
+
+  test(it, 0, 1, 2, 0, 1, 2, 0, 1, 2)
 
-    test(it, 0, 1)
+block:
+  iterator it(): int {.closure.} =
+    let i = if true:
+        yield 0
+        1
+      else:
+        2
+    yield i
+
+  test(it, 0, 1)
 
 block:
-    iterator it(): int {.closure.} =
-        try:
-            try:
-                yield 0
-                raiseException()
-            finally:
-                checkpoint(1)
-        except TestException:
-            yield 2
-            return
-        finally:
-            yield 3
+  iterator it(): int {.closure.} =
+    var foo = 123
+    let i = try:
+        yield 0
+        raiseException()
+        1
+      except TestException as e:
+        assert(e.msg == "Test exception!")
+        case foo
+        of 1:
+          yield 123
+          2
+        of 123:
+          yield 5
+          6
+        else:
+          7
+    yield i
+
+  test(it, 0, 5, 6)
 
-        checkpoint(123)
+block:
+  iterator it(): int {.closure.} =
+    proc voidFoo(i1, i2, i3: int) =
+      checkpoint(i1)
+      checkpoint(i2)
+      checkpoint(i3)
+
+    proc foo(i1, i2, i3: int): int =
+      voidFoo(i1, i2, i3)
+      i3
+
+    proc bar(i1: int): int =
+      checkpoint(i1)
+
+    template tryexcept: int =
+      try:
+        yield 1
+        raiseException()
+        123
+      except TestException:
+        yield 2
+        checkpoint(3)
+        4
+
+    let e1 = true
+
+    template ifelse1: int =
+      if e1:
+        yield 10
+        11
+      else:
+        12
+
+    template ifelse2: int =
+      if ifelse1() == 12:
+        yield 20
+        21
+      else:
+        yield 22
+        23
+
+    let i = foo(bar(0), tryexcept, ifelse2)
+    discard foo(bar(0), tryexcept, ifelse2)
+    voidFoo(bar(0), tryexcept, ifelse2)
+    yield i
+
+  test(it,
+
+    # let i = foo(bar(0), tryexcept, ifelse2)
+    0, # bar(0)
+    1, 2, 3, # tryexcept
+    10, # ifelse1
+    22, # ifelse22
+    0, 4, 23, # foo
+
+    # discard foo(bar(0), tryexcept, ifelse2)
+    0, # bar(0)
+    1, 2, 3, # tryexcept
+    10, # ifelse1
+    22, # ifelse22
+    0, 4, 23, # foo
+
+    # voidFoo(bar(0), tryexcept, ifelse2)
+    0, # bar(0)
+    1, 2, 3, # tryexcept
+    10, # ifelse1
+    22, # ifelse22
+    0, 4, 23, # foo
+
+    23 # i
+  )
 
-    test(it, 0, 1, 2, 3)
+block:
+  iterator it(): int {.closure.} =
+    checkpoint(0)
+    for i in 0 .. 1:
+      try:
+        yield 1
+        raiseException()
+      except TestException as e:
+        doAssert(e.msg == "Test exception!")
+        yield 2
+      except AnotherException:
+        yield 123
+      except:
+        yield 1234
+      finally:
+        yield 3
+        checkpoint(4)
+        yield 5
+
+  test(it, 0, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5)
 
 block:
-    iterator it(): int {.closure.} =
-        try:
-            try:
-                yield 0
-                raiseException()
-            finally:
-                return # Return in finally should stop exception propagation
-        except AnotherException:
-            yield 2
-            return
-        finally:
-            yield 3
-        checkpoint(123)
+  iterator it(): int {.closure.} =
+    var i = 5
+    template foo(): bool =
+      yield i
+      true
+
+    while foo():
+      dec i
+      if i == 0:
+        break
+
+  test(it, 5, 4, 3, 2, 1)
+
+block: # Short cirquits
+  iterator it(): int {.closure.} =
+    template trueYield: bool =
+      yield 1
+      true
+
+    template falseYield: bool =
+      yield 0
+      false
+
+    if trueYield or falseYield:
+      discard falseYield and trueYield
+
+    if falseYield and trueYield:
+      checkpoint(123)
+
+  test(it, 1, 0, 0)
 
-    test(it, 0, 3)
 
 echo "ok"