diff options
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/closureiters.nim | 404 |
1 files changed, 372 insertions, 32 deletions
diff --git a/compiler/closureiters.nim b/compiler/closureiters.nim index a8e7e0274..a30b4e10e 100644 --- a/compiler/closureiters.nim +++ b/compiler/closureiters.nim @@ -149,7 +149,6 @@ type exitStateIdx: int # index of the last state tempVarId: int # unique name counter tempVars: PNode # Temp var decls, nkVarSection - loweredStmtListExpr: PNode # Temporary used for nkStmtListExpr lowering exceptionTable: seq[int] # For state `i` jump to state `exceptionTable[i]` if exception is raised hasExceptions: bool # Does closure have yield in try? curExcHandlingState: int # Negative for except, positive for finally @@ -177,6 +176,7 @@ proc newStateAssgn(ctx: var Ctx, stateNo: int = -2): PNode = proc newEnvVar(ctx: var Ctx, name: string, typ: PType): PSym = result = newSym(skVar, getIdent(name), ctx.fn, ctx.fn.info) result.typ = typ + assert(not typ.isNil) if not ctx.stateVarSym.isNil: # We haven't gone through labmda lifting yet, so just create a local var, @@ -197,7 +197,6 @@ proc newEnvVarAccess(ctx: Ctx, s: PSym): PNode = proc newTmpResultAccess(ctx: var Ctx): PNode = if ctx.tmpResultSym.isNil: - debug(ctx.fn.typ) ctx.tmpResultSym = ctx.newEnvVar(":tmpResult", ctx.fn.typ[0]) ctx.newEnvVarAccess(ctx.tmpResultSym) @@ -245,9 +244,8 @@ proc addGotoOut(n: PNode, gotoOut: PNode): PNode = if result.len != 0 and result.sons[^1].kind != nkGotoState: result.add(gotoOut) -proc newTempVarAccess(ctx: var Ctx, typ: PType, i: TLineInfo): PNode = - let s = ctx.newEnvVar(":tmpSlLower" & $ctx.tempVarId, typ) - result = ctx.newEnvVarAccess(s) +proc newTempVar(ctx: var Ctx, typ: PType): PSym = + result = ctx.newEnvVar(":tmpSlLower" & $ctx.tempVarId, typ) inc ctx.tempVarId proc hasYields(n: PNode): bool = @@ -390,35 +388,382 @@ proc hasYieldsInExpressions(n: PNode): bool = nkSym, nkIdent, procDefs, nkTemplateDef: discard of nkStmtListExpr: - result = n.hasYields - of nkStmtList, nkWhileStmt, nkCaseStmt, nkIfStmt: - discard + if isEmptyType(n.typ): + for c in n: + if c.hasYieldsInExpressions: + return true + else: + result = n.hasYields else: for c in n: if c.hasYieldsInExpressions: return true -proc lowerStmtListExpr(ctx: var Ctx, n: PNode): PNode = +proc exprToStmtList(n: PNode): tuple[s, res: PNode] = + assert(n.kind == nkStmtListExpr) + + var parent = n + var lastSon = n[^1] + + while lastSon.kind == nkStmtListExpr: + parent = lastSon + lastSon = lastSon[^1] + + result.s = newNodeI(nkStmtList, n.info) + result.s.sons = parent.sons + result.s.sons.setLen(result.s.sons.len - 1) # delete last son + result.res = lastSon + +proc newEnvVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode = + result = newNode(nkFastAsgn) + result.add(ctx.newEnvVarAccess(s)) + result.add(v) + +proc addExprAssgn(ctx: Ctx, output, input: PNode, sym: PSym) = + if input.kind == nkStmtListExpr: + let (st, res) = exprToStmtList(input) + output.add(st) + output.add(ctx.newEnvVarAsgn(sym, res)) + else: + output.add(ctx.newEnvVarAsgn(sym, input)) + +proc convertExprBodyToAsgn(ctx: Ctx, exprBody: PNode, res: PSym): PNode = + result = newNode(nkStmtList) + ctx.addExprAssgn(result, exprBody, res) + +proc newNotCall(e: PNode): PNode = + result = newNode(nkCall) + result.add(newSymNode(getSysMagic("not", mNot))) + result.add(e) + result.typ = getSysType(tyBool) + +proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = result = n case n.kind of nkCharLit..nkUInt64Lit, nkFloatLit..nkFloat128Lit, nkStrLit..nkTripleStrLit, nkSym, nkIdent, procDefs, nkTemplateDef: discard - of nkStmtListExpr: - if n.hasYields: - for i in 0 .. n.len - 2: - ctx.loweredStmtListExpr.add(n[i]) - let tv = ctx.newTempVarAccess(n.typ, n[^1].info) - let asgn = newNode(nkAsgn) - asgn.add(tv) - asgn.add(n[^1]) - ctx.loweredStmtListExpr.add(asgn) - result = tv + of nkYieldStmt: + var ns = false + for i in 0 ..< n.len: + n[i] = ctx.lowerStmtListExprs(n[i], ns) + + if ns: + assert(n[0].kind == nkStmtListExpr) + result = newNodeI(nkStmtList, n.info) + let (st, ex) = exprToStmtList(n[0]) + result.add(st) + n[0] = ex + result.add(n) + + needsSplit = true + + of nkPar, nkObjConstr, nkTupleConstr, nkBracket, nkArgList: + var ns = false + for i in 0 ..< n.len: + n[i] = ctx.lowerStmtListExprs(n[i], ns) + + if ns: + needsSplit = true + + result = newNodeI(nkStmtListExpr, n.info) + if n.typ.isNil: internalError("lowerStmtListExprs: constr typ.isNil") + result.typ = n.typ + + for i in 0 ..< n.len: + if n[i].kind == nkStmtListExpr: + let (st, ex) = exprToStmtList(n[i]) + result.add(st) + n[i] = ex + result.add(n) + + of nkIfStmt, nkIfExpr: + var ns = false + for i in 0 ..< n.len: + n[i] = ctx.lowerStmtListExprs(n[i], ns) + + if ns: + needsSplit = true + var tmp: PSym + var s: PNode + let isExpr = not isEmptyType(n.typ) + if isExpr: + tmp = ctx.newTempVar(n.typ) + result = newNode(nkStmtListExpr) + result.typ = n.typ + else: + result = newNode(nkStmtList) + + var curS = result + + for branch in n: + case branch.kind + of nkElseExpr, nkElse: + if isExpr: + var newBranch = newNodeI(nkElse, branch.info) + let branchBody = newNode(nkStmtList) + ctx.addExprAssgn(branchBody, branch[0], tmp) + newBranch.add(branchBody) + curS.add(newBranch) + else: + curS.add(branch) + + of nkElifExpr, nkElifBranch: + var newBranch: PNode + if branch[0].kind == nkStmtListExpr: + let elseBody = newNode(nkStmtList) + + let (st, res) = exprToStmtList(branch[0]) + elseBody.add(st) + + newBranch = newNodeI(nkElifBranch, branch.info) + newBranch.add(res) + newBranch.add(branch[1]) + + let newIf = newNodeI(nkIfStmt, branch.info) + newIf.add(newBranch) + elseBody.add(newIf) + if curS.kind == nkIfStmt: + let newElse = newNodeI(nkElse, branch.info) + newElse.add(elseBody) + curS.add(newElse) + else: + curS.add(elseBody) + curS = newIf + else: + newBranch = branch + if curS.kind == nkIfStmt: + curS.add(newBranch) + else: + let newIf = newNodeI(nkIfStmt, branch.info) + newIf.add(newBranch) + curS.add(newIf) + curS = newIf + + if isExpr: + let branchBody = newNode(nkStmtList) + ctx.addExprAssgn(branchBody, branch[1], tmp) + newBranch[1] = branchBody + + else: + internalError("lowerStmtListExpr(nkIf): " & $branch.kind) + + if isExpr: result.add(ctx.newEnvVarAccess(tmp)) + + of nkTryStmt: + var ns = false + for i in 0 ..< n.len: + n[i] = ctx.lowerStmtListExprs(n[i], ns) + + if ns: + needsSplit = true + let isExpr = not isEmptyType(n.typ) + + if isExpr: + result = newNodeI(nkStmtListExpr, n.info) + result.typ = n.typ + let tmp = ctx.newTempVar(n.typ) + + n[0] = ctx.convertExprBodyToAsgn(n[0], tmp) + for i in 1 ..< n.len: + let branch = n[i] + case branch.kind + of nkExceptBranch: + if branch[0].kind == nkType: + branch[1] = ctx.convertExprBodyToAsgn(branch[1], tmp) + else: + branch[0] = ctx.convertExprBodyToAsgn(branch[0], tmp) + of nkFinally: + discard + else: + internalError("lowerStmtListExpr(nkTryStmt): " & $branch.kind) + result.add(n) + result.add(ctx.newEnvVarAccess(tmp)) + + of nkCaseStmt: + var ns = false + for i in 0 ..< n.len: + n[i] = ctx.lowerStmtListExprs(n[i], ns) + + if ns: + needsSplit = true + + let isExpr = not isEmptyType(n.typ) + + if isExpr: + let tmp = ctx.newTempVar(n.typ) + result = newNodeI(nkStmtListExpr, n.info) + result.typ = n.typ + + if n[0].kind == nkStmtListExpr: + let (st, ex) = exprToStmtList(n[0]) + result.add(st) + n[0] = ex + + for i in 1 ..< n.len: + let branch = n[i] + case branch.kind + of nkOfBranch: + branch[1] = ctx.convertExprBodyToAsgn(branch[1], tmp) + of nkElse: + branch[0] = ctx.convertExprBodyToAsgn(branch[0], tmp) + else: + internalError("lowerStmtListExpr(nkCaseStmt): " & $branch.kind) + result.add(n) + result.add(ctx.newEnvVarAccess(tmp)) + + of nkCallKinds: + var ns = false + for i in 0 ..< n.len: + n[i] = ctx.lowerStmtListExprs(n[i], ns) + + if ns: + needsSplit = true + let isExpr = not isEmptyType(n.typ) + + if isExpr: + result = newNodeI(nkStmtListExpr, n.info) + result.typ = n.typ + else: + result = newNode(nkStmtList, n.info) + + if n[0].kind == nkSym and n[0].sym.magic in {mAnd, mOr}: # `and`/`or` short cirquiting + var cond = n[1] + if cond.kind == nkStmtListExpr: + let (st, ex) = exprToStmtList(cond) + result.add(st) + cond = ex + + let tmp = ctx.newTempVar(cond.typ) + result.add(ctx.newEnvVarAsgn(tmp, cond)) + + let ifNode = newNode(nkIfStmt) + let ifBranch = newNode(nkElifBranch) + + var check = ctx.newEnvVarAccess(tmp) + if n[0].sym.magic == mOr: + check = newNotCall(check) + ifBranch.add(check) + + cond = n[2] + let ifBody = newNode(nkStmtList) + if cond.kind == nkStmtListExpr: + let (st, ex) = exprToStmtList(cond) + ifBody.add(st) + cond = ex + ifBody.add(ctx.newEnvVarAsgn(tmp, cond)) + ifBranch.add(ifBody) + ifNode.add(ifBranch) + result.add(ifNode) + result.add(ctx.newEnvVarAccess(tmp)) + else: + for i in 0 ..< n.len: + if n[i].kind == nkStmtListExpr: + let (st, ex) = exprToStmtList(n[i]) + result.add(st) + n[i] = ex + + if n[i].kind in nkCallKinds: # XXX: This should better be some sort of side effect tracking + let tmp = ctx.newTempVar(n[i].typ) + result.add(ctx.newEnvVarAsgn(tmp, n[i])) + n[i] = ctx.newEnvVarAccess(tmp) + + result.add(n) + + of nkVarSection, nkLetSection: + result = newNodeI(nkStmtList, n.info) + for c in n: + let varSect = newNodeI(n.kind, n.info) + varSect.add(c) + var ns = false + c[^1] = ctx.lowerStmtListExprs(c[^1], ns) + if ns: + needsSplit = true + assert(c[^1].kind == nkStmtListExpr) + let (st, ex) = exprToStmtList(c[^1]) + result.add(st) + c[^1] = ex + result.add(varSect) + + of nkDiscardStmt, nkReturnStmt, nkRaiseStmt: + var ns = false + for i in 0 ..< n.len: + n[i] = ctx.lowerStmtListExprs(n[i], ns) + if ns: + needsSplit = true + result = newNodeI(nkStmtList, n.info) + let (st, ex) = exprToStmtList(n[0]) + result.add(st) + n[0] = ex + result.add(n) + + of nkCast: + var ns = false + for i in 0 ..< n.len: + n[i] = ctx.lowerStmtListExprs(n[i], ns) + + if ns: + needsSplit = true + result = newNodeI(nkStmtListExpr, n.info) + result.typ = n.typ + let (st, ex) = exprToStmtList(n[1]) + result.add(st) + n[1] = ex + result.add(n) + + of nkAsgn, nkFastAsgn: + var ns = false + for i in 0 ..< n.len: + n[i] = ctx.lowerStmtListExprs(n[i], ns) + + if ns: + needsSplit = true + result = newNodeI(nkStmtList, n.info) + if n[0].kind == nkStmtListExpr: + let (st, ex) = exprToStmtList(n[0]) + result.add(st) + n[0] = ex + + if n[1].kind == nkStmtListExpr: + let (st, ex) = exprToStmtList(n[1]) + result.add(st) + n[1] = ex + + result.add(n) + + of nkWhileStmt: + var ns = false + + var condNeedsSplit = false + n[0] = ctx.lowerStmtListExprs(n[0], condNeedsSplit) + var bodyNeedsSplit = false + n[1] = ctx.lowerStmtListExprs(n[1], bodyNeedsSplit) + + if condNeedsSplit or bodyNeedsSplit: + needsSplit = true + + if condNeedsSplit: + let newBody = newNode(nkStmtList) + + let (st, ex) = exprToStmtList(n[0]) + newBody.add(st) + let check = newNode(nkIfStmt) + let branch = newNode(nkElifBranch) + branch.add(newNotCall(ex)) + let brk = newNode(nkBreakStmt) + brk.add(emptyNode) + branch.add(brk) + check.add(branch) + newBody.add(check) + newBody.add(n[1]) + + n[0] = newSymNode(getSysSym("true")) + n[1] = newBody else: for i in 0 ..< n.len: - n[i] = ctx.lowerStmtListExpr(n[i]) + n[i] = ctx.lowerStmtListExprs(n[i], needsSplit) proc newEndFinallyNode(ctx: var Ctx): PNode = # Generate the following code: @@ -448,7 +793,7 @@ proc newEndFinallyNode(ctx: var Ctx): PNode = branch.add(cmp) let retStmt = newNode(nkReturnStmt) - let asgn = newNode(nkAsgn) + let asgn = newNode(nkFastAsgn) addSon(asgn, newSymNode(getClosureIterResult(ctx.fn))) addSon(asgn, ctx.newTmpResultAccess()) retStmt.add(asgn) @@ -482,7 +827,7 @@ proc transformReturnsInTry(ctx: var Ctx, n: PNode): PNode = asgn.add(newIntTypeNode(nkIntLit, 1, getSysType(tyBool))) result.add(asgn) - if n[0].kind != nkEmpty: # TODO: And not void! + if n[0].kind != nkEmpty: let asgnTmpResult = newNodeI(nkAsgn, n.info) asgnTmpResult.add(ctx.newTmpResultAccess()) asgnTmpResult.add(n[0]) @@ -508,17 +853,15 @@ proc transformClosureIteratorBody(ctx: var Ctx, n: PNode, gotoOut: PNode): PNode nkSym, nkIdent, procDefs, nkTemplateDef: discard - of nkStmtList: + of nkStmtList, nkStmtListExpr: + assert(isEmptyType(n.typ), "nkStmtListExpr not lowered") + result = addGotoOut(result, gotoOut) for i in 0 ..< n.len: if n[i].hasYieldsInExpressions: # Lower nkStmtListExpr nodes inside `n[i]` first - assert(ctx.loweredStmtListExpr.isNil) - ctx.loweredStmtListExpr = newNodeI(nkStmtList, n.info) - n[i] = ctx.lowerStmtListExpr(n[i]) - ctx.loweredStmtListExpr.add(n[i]) - n[i] = ctx.loweredStmtListExpr - ctx.loweredStmtListExpr = nil + var ns = false + n[i] = ctx.lowerStmtListExprs(n[i], ns) if n[i].hasYields: # Create a new split @@ -534,9 +877,6 @@ proc transformClosureIteratorBody(ctx: var Ctx, n: PNode, gotoOut: PNode): PNode discard ctx.transformClosureIteratorBody(s, gotoOut) break - of nkStmtListExpr: - assert(false, "nkStmtListExpr not lowered") - of nkYieldStmt: result = newNodeI(nkStmtList, n.info) result.add(n) |