summary refs log tree commit diff stats
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/closureiters.nim404
1 files changed, 372 insertions, 32 deletions
diff --git a/compiler/closureiters.nim b/compiler/closureiters.nim
index a8e7e0274..a30b4e10e 100644
--- a/compiler/closureiters.nim
+++ b/compiler/closureiters.nim
@@ -149,7 +149,6 @@ type
     exitStateIdx: int # index of the last state
     tempVarId: int # unique name counter
     tempVars: PNode # Temp var decls, nkVarSection
-    loweredStmtListExpr: PNode # Temporary used for nkStmtListExpr lowering
     exceptionTable: seq[int] # For state `i` jump to state `exceptionTable[i]` if exception is raised
     hasExceptions: bool # Does closure have yield in try?
     curExcHandlingState: int # Negative for except, positive for finally
@@ -177,6 +176,7 @@ proc newStateAssgn(ctx: var Ctx, stateNo: int = -2): PNode =
 proc newEnvVar(ctx: var Ctx, name: string, typ: PType): PSym =
   result = newSym(skVar, getIdent(name), ctx.fn, ctx.fn.info)
   result.typ = typ
+  assert(not typ.isNil)
 
   if not ctx.stateVarSym.isNil:
     # We haven't gone through labmda lifting yet, so just create a local var,
@@ -197,7 +197,6 @@ proc newEnvVarAccess(ctx: Ctx, s: PSym): PNode =
 
 proc newTmpResultAccess(ctx: var Ctx): PNode =
   if ctx.tmpResultSym.isNil:
-    debug(ctx.fn.typ)
     ctx.tmpResultSym = ctx.newEnvVar(":tmpResult", ctx.fn.typ[0])
   ctx.newEnvVarAccess(ctx.tmpResultSym)
 
@@ -245,9 +244,8 @@ proc addGotoOut(n: PNode, gotoOut: PNode): PNode =
   if result.len != 0 and result.sons[^1].kind != nkGotoState:
     result.add(gotoOut)
 
-proc newTempVarAccess(ctx: var Ctx, typ: PType, i: TLineInfo): PNode =
-  let s = ctx.newEnvVar(":tmpSlLower" & $ctx.tempVarId, typ)
-  result = ctx.newEnvVarAccess(s)
+proc newTempVar(ctx: var Ctx, typ: PType): PSym =
+  result = ctx.newEnvVar(":tmpSlLower" & $ctx.tempVarId, typ)
   inc ctx.tempVarId
 
 proc hasYields(n: PNode): bool =
@@ -390,35 +388,382 @@ proc hasYieldsInExpressions(n: PNode): bool =
       nkSym, nkIdent, procDefs, nkTemplateDef:
     discard
   of nkStmtListExpr:
-    result = n.hasYields
-  of nkStmtList, nkWhileStmt, nkCaseStmt, nkIfStmt:
-    discard
+    if isEmptyType(n.typ):
+      for c in n:
+        if c.hasYieldsInExpressions:
+          return true
+    else:
+      result = n.hasYields
   else:
     for c in n:
       if c.hasYieldsInExpressions:
         return true
 
-proc lowerStmtListExpr(ctx: var Ctx, n: PNode): PNode =
+proc exprToStmtList(n: PNode): tuple[s, res: PNode] =
+  assert(n.kind == nkStmtListExpr)
+
+  var parent = n
+  var lastSon = n[^1]
+
+  while lastSon.kind == nkStmtListExpr:
+    parent = lastSon
+    lastSon = lastSon[^1]
+
+  result.s = newNodeI(nkStmtList, n.info)
+  result.s.sons = parent.sons
+  result.s.sons.setLen(result.s.sons.len - 1) # delete last son
+  result.res = lastSon
+
+proc newEnvVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode =
+  result = newNode(nkFastAsgn)
+  result.add(ctx.newEnvVarAccess(s))
+  result.add(v)
+
+proc addExprAssgn(ctx: Ctx, output, input: PNode, sym: PSym) =
+  if input.kind == nkStmtListExpr:
+    let (st, res) = exprToStmtList(input)
+    output.add(st)
+    output.add(ctx.newEnvVarAsgn(sym, res))
+  else:
+    output.add(ctx.newEnvVarAsgn(sym, input))
+
+proc convertExprBodyToAsgn(ctx: Ctx, exprBody: PNode, res: PSym): PNode =
+  result = newNode(nkStmtList)
+  ctx.addExprAssgn(result, exprBody, res)
+
+proc newNotCall(e: PNode): PNode =
+  result = newNode(nkCall)
+  result.add(newSymNode(getSysMagic("not", mNot)))
+  result.add(e)
+  result.typ = getSysType(tyBool)
+
+proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
   result = n
   case n.kind
   of nkCharLit..nkUInt64Lit, nkFloatLit..nkFloat128Lit, nkStrLit..nkTripleStrLit,
       nkSym, nkIdent, procDefs, nkTemplateDef:
     discard
-  of nkStmtListExpr:
-    if n.hasYields:
-      for i in 0 .. n.len - 2:
-        ctx.loweredStmtListExpr.add(n[i])
 
-      let tv = ctx.newTempVarAccess(n.typ, n[^1].info)
-      let asgn = newNode(nkAsgn)
-      asgn.add(tv)
-      asgn.add(n[^1])
-      ctx.loweredStmtListExpr.add(asgn)
-      result = tv
+  of nkYieldStmt:
+    var ns = false
+    for i in 0 ..< n.len:
+      n[i] = ctx.lowerStmtListExprs(n[i], ns)
+
+    if ns:
+      assert(n[0].kind == nkStmtListExpr)
+      result = newNodeI(nkStmtList, n.info)
+      let (st, ex) = exprToStmtList(n[0])
+      result.add(st)
+      n[0] = ex
+      result.add(n)
+
+    needsSplit = true
+
+  of nkPar, nkObjConstr, nkTupleConstr, nkBracket, nkArgList:
+    var ns = false
+    for i in 0 ..< n.len:
+      n[i] = ctx.lowerStmtListExprs(n[i], ns)
+
+    if ns:
+      needsSplit = true
+
+      result = newNodeI(nkStmtListExpr, n.info)
+      if n.typ.isNil: internalError("lowerStmtListExprs: constr typ.isNil")
+      result.typ = n.typ
+
+      for i in 0 ..< n.len:
+        if n[i].kind == nkStmtListExpr:
+          let (st, ex) = exprToStmtList(n[i])
+          result.add(st)
+          n[i] = ex
+      result.add(n)
+
+  of nkIfStmt, nkIfExpr:
+    var ns = false
+    for i in 0 ..< n.len:
+      n[i] = ctx.lowerStmtListExprs(n[i], ns)
+
+    if ns:
+      needsSplit = true
+      var tmp: PSym
+      var s: PNode
+      let isExpr = not isEmptyType(n.typ)
+      if isExpr:
+        tmp = ctx.newTempVar(n.typ)
+        result = newNode(nkStmtListExpr)
+        result.typ = n.typ
+      else:
+        result = newNode(nkStmtList)
+
+      var curS = result
+
+      for branch in n:
+        case branch.kind
+        of nkElseExpr, nkElse:
+          if isExpr:
+            var newBranch = newNodeI(nkElse, branch.info)
+            let branchBody = newNode(nkStmtList)
+            ctx.addExprAssgn(branchBody, branch[0], tmp)
+            newBranch.add(branchBody)
+            curS.add(newBranch)
+          else:
+            curS.add(branch)
+
+        of nkElifExpr, nkElifBranch:
+          var newBranch: PNode
+          if branch[0].kind == nkStmtListExpr:
+            let elseBody = newNode(nkStmtList)
+
+            let (st, res) = exprToStmtList(branch[0])
+            elseBody.add(st)
+
+            newBranch = newNodeI(nkElifBranch, branch.info)
+            newBranch.add(res)
+            newBranch.add(branch[1])
+
+            let newIf = newNodeI(nkIfStmt, branch.info)
+            newIf.add(newBranch)
+            elseBody.add(newIf)
+            if curS.kind == nkIfStmt:
+              let newElse = newNodeI(nkElse, branch.info)
+              newElse.add(elseBody)
+              curS.add(newElse)
+            else:
+              curS.add(elseBody)
+            curS = newIf
+          else:
+            newBranch = branch
+            if curS.kind == nkIfStmt:
+              curS.add(newBranch)
+            else:
+              let newIf = newNodeI(nkIfStmt, branch.info)
+              newIf.add(newBranch)
+              curS.add(newIf)
+              curS = newIf
+
+          if isExpr:
+            let branchBody = newNode(nkStmtList)
+            ctx.addExprAssgn(branchBody, branch[1], tmp)
+            newBranch[1] = branchBody
+
+        else:
+          internalError("lowerStmtListExpr(nkIf): " & $branch.kind)
+
+      if isExpr: result.add(ctx.newEnvVarAccess(tmp))
+
+  of nkTryStmt:
+    var ns = false
+    for i in 0 ..< n.len:
+      n[i] = ctx.lowerStmtListExprs(n[i], ns)
+
+    if ns:
+      needsSplit = true
+      let isExpr = not isEmptyType(n.typ)
+
+      if isExpr:
+        result = newNodeI(nkStmtListExpr, n.info)
+        result.typ = n.typ
+        let tmp = ctx.newTempVar(n.typ)
+
+        n[0] = ctx.convertExprBodyToAsgn(n[0], tmp)
+        for i in 1 ..< n.len:
+          let branch = n[i]
+          case branch.kind
+          of nkExceptBranch:
+            if branch[0].kind == nkType:
+              branch[1] = ctx.convertExprBodyToAsgn(branch[1], tmp)
+            else:
+              branch[0] = ctx.convertExprBodyToAsgn(branch[0], tmp)
+          of nkFinally:
+            discard
+          else:
+            internalError("lowerStmtListExpr(nkTryStmt): " & $branch.kind)
+        result.add(n)
+        result.add(ctx.newEnvVarAccess(tmp))
+
+  of nkCaseStmt:
+    var ns = false
+    for i in 0 ..< n.len:
+      n[i] = ctx.lowerStmtListExprs(n[i], ns)
+
+    if ns:
+      needsSplit = true
+
+      let isExpr = not isEmptyType(n.typ)
+
+      if isExpr:
+        let tmp = ctx.newTempVar(n.typ)
+        result = newNodeI(nkStmtListExpr, n.info)
+        result.typ = n.typ
+
+        if n[0].kind == nkStmtListExpr:
+          let (st, ex) = exprToStmtList(n[0])
+          result.add(st)
+          n[0] = ex
+
+        for i in 1 ..< n.len:
+          let branch = n[i]
+          case branch.kind
+          of nkOfBranch:
+            branch[1] = ctx.convertExprBodyToAsgn(branch[1], tmp)
+          of nkElse:
+            branch[0] = ctx.convertExprBodyToAsgn(branch[0], tmp)
+          else:
+            internalError("lowerStmtListExpr(nkCaseStmt): " & $branch.kind)
+        result.add(n)
+        result.add(ctx.newEnvVarAccess(tmp))
+
+  of nkCallKinds:
+    var ns = false
+    for i in 0 ..< n.len:
+      n[i] = ctx.lowerStmtListExprs(n[i], ns)
+
+    if ns:
+      needsSplit = true
+      let isExpr = not isEmptyType(n.typ)
+
+      if isExpr:
+        result = newNodeI(nkStmtListExpr, n.info)
+        result.typ = n.typ
+      else:
+        result = newNode(nkStmtList, n.info)
+
+      if n[0].kind == nkSym and n[0].sym.magic in {mAnd, mOr}: # `and`/`or` short cirquiting
+        var cond = n[1]
+        if cond.kind == nkStmtListExpr:
+          let (st, ex) = exprToStmtList(cond)
+          result.add(st)
+          cond = ex
+
+        let tmp = ctx.newTempVar(cond.typ)
+        result.add(ctx.newEnvVarAsgn(tmp, cond))
+
+        let ifNode = newNode(nkIfStmt)
+        let ifBranch = newNode(nkElifBranch)
+
+        var check = ctx.newEnvVarAccess(tmp)
+        if n[0].sym.magic == mOr:
+          check = newNotCall(check)
+        ifBranch.add(check)
+
+        cond = n[2]
+        let ifBody = newNode(nkStmtList)
+        if cond.kind == nkStmtListExpr:
+          let (st, ex) = exprToStmtList(cond)
+          ifBody.add(st)
+          cond = ex
+        ifBody.add(ctx.newEnvVarAsgn(tmp, cond))
+        ifBranch.add(ifBody)
+        ifNode.add(ifBranch)
+        result.add(ifNode)
+        result.add(ctx.newEnvVarAccess(tmp))
+      else:
+        for i in 0 ..< n.len:
+          if n[i].kind == nkStmtListExpr:
+            let (st, ex) = exprToStmtList(n[i])
+            result.add(st)
+            n[i] = ex
+
+          if n[i].kind in nkCallKinds: # XXX: This should better be some sort of side effect tracking
+            let tmp = ctx.newTempVar(n[i].typ)
+            result.add(ctx.newEnvVarAsgn(tmp, n[i]))
+            n[i] = ctx.newEnvVarAccess(tmp)
+
+        result.add(n)
+
+  of nkVarSection, nkLetSection:
+    result = newNodeI(nkStmtList, n.info)
+    for c in n:
+      let varSect = newNodeI(n.kind, n.info)
+      varSect.add(c)
+      var ns = false
+      c[^1] = ctx.lowerStmtListExprs(c[^1], ns)
+      if ns:
+        needsSplit = true
+        assert(c[^1].kind == nkStmtListExpr)
+        let (st, ex) = exprToStmtList(c[^1])
+        result.add(st)
+        c[^1] = ex
+      result.add(varSect)
+
+  of nkDiscardStmt, nkReturnStmt, nkRaiseStmt:
+    var ns = false
+    for i in 0 ..< n.len:
+      n[i] = ctx.lowerStmtListExprs(n[i], ns)
 
+    if ns:
+      needsSplit = true
+      result = newNodeI(nkStmtList, n.info)
+      let (st, ex) = exprToStmtList(n[0])
+      result.add(st)
+      n[0] = ex
+      result.add(n)
+
+  of nkCast:
+    var ns = false
+    for i in 0 ..< n.len:
+      n[i] = ctx.lowerStmtListExprs(n[i], ns)
+
+    if ns:
+      needsSplit = true
+      result = newNodeI(nkStmtListExpr, n.info)
+      result.typ = n.typ
+      let (st, ex) = exprToStmtList(n[1])
+      result.add(st)
+      n[1] = ex
+      result.add(n)
+
+  of nkAsgn, nkFastAsgn:
+    var ns = false
+    for i in 0 ..< n.len:
+      n[i] = ctx.lowerStmtListExprs(n[i], ns)
+
+    if ns:
+      needsSplit = true
+      result = newNodeI(nkStmtList, n.info)
+      if n[0].kind == nkStmtListExpr:
+        let (st, ex) = exprToStmtList(n[0])
+        result.add(st)
+        n[0] = ex
+
+      if n[1].kind == nkStmtListExpr:
+        let (st, ex) = exprToStmtList(n[1])
+        result.add(st)
+        n[1] = ex
+
+      result.add(n)
+
+  of nkWhileStmt:
+    var ns = false
+
+    var condNeedsSplit = false
+    n[0] = ctx.lowerStmtListExprs(n[0], condNeedsSplit)
+    var bodyNeedsSplit = false
+    n[1] = ctx.lowerStmtListExprs(n[1], bodyNeedsSplit)
+
+    if condNeedsSplit or bodyNeedsSplit:
+      needsSplit = true
+
+      if condNeedsSplit:
+        let newBody = newNode(nkStmtList)
+
+        let (st, ex) = exprToStmtList(n[0])
+        newBody.add(st)
+        let check = newNode(nkIfStmt)
+        let branch = newNode(nkElifBranch)
+        branch.add(newNotCall(ex))
+        let brk = newNode(nkBreakStmt)
+        brk.add(emptyNode)
+        branch.add(brk)
+        check.add(branch)
+        newBody.add(check)
+        newBody.add(n[1])
+
+        n[0] = newSymNode(getSysSym("true"))
+        n[1] = newBody
   else:
     for i in 0 ..< n.len:
-      n[i] = ctx.lowerStmtListExpr(n[i])
+      n[i] = ctx.lowerStmtListExprs(n[i], needsSplit)
 
 proc newEndFinallyNode(ctx: var Ctx): PNode =
   # Generate the following code:
@@ -448,7 +793,7 @@ proc newEndFinallyNode(ctx: var Ctx): PNode =
   branch.add(cmp)
 
   let retStmt = newNode(nkReturnStmt)
-  let asgn = newNode(nkAsgn)
+  let asgn = newNode(nkFastAsgn)
   addSon(asgn, newSymNode(getClosureIterResult(ctx.fn)))
   addSon(asgn, ctx.newTmpResultAccess())
   retStmt.add(asgn)
@@ -482,7 +827,7 @@ proc transformReturnsInTry(ctx: var Ctx, n: PNode): PNode =
       asgn.add(newIntTypeNode(nkIntLit, 1, getSysType(tyBool)))
       result.add(asgn)
 
-    if n[0].kind != nkEmpty: # TODO: And not void!
+    if n[0].kind != nkEmpty:
       let asgnTmpResult = newNodeI(nkAsgn, n.info)
       asgnTmpResult.add(ctx.newTmpResultAccess())
       asgnTmpResult.add(n[0])
@@ -508,17 +853,15 @@ proc transformClosureIteratorBody(ctx: var Ctx, n: PNode, gotoOut: PNode): PNode
         nkSym, nkIdent, procDefs, nkTemplateDef:
       discard
 
-    of nkStmtList:
+    of nkStmtList, nkStmtListExpr:
+      assert(isEmptyType(n.typ), "nkStmtListExpr not lowered")
+
       result = addGotoOut(result, gotoOut)
       for i in 0 ..< n.len:
         if n[i].hasYieldsInExpressions:
           # Lower nkStmtListExpr nodes inside `n[i]` first
-          assert(ctx.loweredStmtListExpr.isNil)
-          ctx.loweredStmtListExpr = newNodeI(nkStmtList, n.info)
-          n[i] = ctx.lowerStmtListExpr(n[i])
-          ctx.loweredStmtListExpr.add(n[i])
-          n[i] = ctx.loweredStmtListExpr
-          ctx.loweredStmtListExpr = nil
+          var ns = false
+          n[i] = ctx.lowerStmtListExprs(n[i], ns)
 
         if n[i].hasYields:
           # Create a new split
@@ -534,9 +877,6 @@ proc transformClosureIteratorBody(ctx: var Ctx, n: PNode, gotoOut: PNode): PNode
           discard ctx.transformClosureIteratorBody(s, gotoOut)
           break
 
-    of nkStmtListExpr:
-      assert(false, "nkStmtListExpr not lowered")
-
     of nkYieldStmt:
       result = newNodeI(nkStmtList, n.info)
       result.add(n)