diff options
Diffstat (limited to 'compiler/closureiters.nim')
-rw-r--r-- | compiler/closureiters.nim | 183 |
1 files changed, 120 insertions, 63 deletions
diff --git a/compiler/closureiters.nim b/compiler/closureiters.nim index c1f525488..9ee394111 100644 --- a/compiler/closureiters.nim +++ b/compiler/closureiters.nim @@ -36,13 +36,6 @@ # else: # return -# The transformation should play well with lambdalifting, however depending -# on situation, it can be called either before or after lambdalifting -# transformation. As such we behave slightly differently, when accessing -# iterator state, or using temp variables. If lambdalifting did not happen, -# we just create local variables, so that they will be lifted further on. -# Otherwise, we utilize existing env, created by lambdalifting. - # Lambdalifting treats :state variable specially, it should always end up # as the first field in env. Currently C codegen depends on this behavior. @@ -151,7 +144,6 @@ type Ctx = object g: ModuleGraph fn: PSym - stateVarSym: PSym # :state variable. nil if env already introduced by lambdalifting tmpResultSym: PSym # Used when we return, but finally has to interfere unrollFinallySym: PSym # Indicates that we're unrolling finally states (either exception happened or premature return) curExcSym: PSym # Current exception @@ -168,18 +160,18 @@ type nearestFinally: int # Index of the nearest finally block. For try/except it # is their finally. For finally it is parent finally. Otherwise -1 idgen: IdGenerator + varStates: Table[ItemId, int] # Used to detect if local variable belongs to multiple states const nkSkip = {nkEmpty..nkNilLit, nkTemplateDef, nkTypeSection, nkStaticStmt, nkCommentStmt, nkMixinStmt, nkBindStmt} + procDefs emptyStateLabel = -1 + localNotSeen = -1 + localRequiresLifting = -2 proc newStateAccess(ctx: var Ctx): PNode = - if ctx.stateVarSym.isNil: - result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), - getStateField(ctx.g, ctx.fn), ctx.fn.info) - else: - result = newSymNode(ctx.stateVarSym) + result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), + getStateField(ctx.g, ctx.fn), ctx.fn.info) proc newStateAssgn(ctx: var Ctx, toValue: PNode): PNode = # Creates state assignment: @@ -195,24 +187,17 @@ proc newEnvVar(ctx: var Ctx, name: string, typ: PType): PSym = result = newSym(skVar, getIdent(ctx.g.cache, name), ctx.idgen, ctx.fn, ctx.fn.info) result.typ = typ result.flags.incl sfNoInit - assert(not typ.isNil) - - if not ctx.stateVarSym.isNil: - # We haven't gone through labmda lifting yet, so just create a local var, - # it will be lifted later - if ctx.tempVars.isNil: - ctx.tempVars = newNodeI(nkVarSection, ctx.fn.info) - addVar(ctx.tempVars, newSymNode(result)) - else: - let envParam = getEnvParam(ctx.fn) - # let obj = envParam.typ.lastSon - result = addUniqueField(envParam.typ.elementType, result, ctx.g.cache, ctx.idgen) + assert(not typ.isNil, "Env var needs a type") + + let envParam = getEnvParam(ctx.fn) + # let obj = envParam.typ.lastSon + result = addUniqueField(envParam.typ.elementType, result, ctx.g.cache, ctx.idgen) proc newEnvVarAccess(ctx: Ctx, s: PSym): PNode = - if ctx.stateVarSym.isNil: - result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), s, ctx.fn.info) - else: - result = newSymNode(s) + result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), s, ctx.fn.info) + +proc newTempVarAccess(ctx: Ctx, s: PSym): PNode = + result = newSymNode(s, ctx.fn.info) proc newTmpResultAccess(ctx: var Ctx): PNode = if ctx.tmpResultSym.isNil: @@ -255,9 +240,18 @@ proc addGotoOut(n: PNode, gotoOut: PNode): PNode = if result.len == 0 or result[^1].kind != nkGotoState: result.add(gotoOut) -proc newTempVar(ctx: var Ctx, typ: PType): PSym = - result = ctx.newEnvVar(":tmpSlLower" & $ctx.tempVarId, typ) +proc newTempVarDef(ctx: Ctx, s: PSym, initialValue: PNode): PNode = + var v = initialValue + if v == nil: + v = ctx.g.emptyNode + newTree(nkVarSection, newTree(nkIdentDefs, newSymNode(s), ctx.g.emptyNode, v)) + +proc newTempVar(ctx: var Ctx, typ: PType, parent: PNode, initialValue: PNode = nil): PSym = + result = newSym(skVar, getIdent(ctx.g.cache, ":tmpSlLower" & $ctx.tempVarId), ctx.idgen, ctx.fn, ctx.fn.info) inc ctx.tempVarId + result.typ = typ + assert(not typ.isNil, "Temp var needs a type") + parent.add(ctx.newTempVarDef(result, initialValue)) proc hasYields(n: PNode): bool = # TODO: This is very inefficient. It traverses the node, looking for nkYieldStmt. @@ -429,21 +423,20 @@ proc exprToStmtList(n: PNode): tuple[s, res: PNode] = result.res = n - -proc newEnvVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode = +proc newTempVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode = if isEmptyType(v.typ): result = v else: - result = newTree(nkFastAsgn, ctx.newEnvVarAccess(s), v) + result = newTree(nkFastAsgn, ctx.newTempVarAccess(s), v) result.info = v.info proc addExprAssgn(ctx: Ctx, output, input: PNode, sym: PSym) = if input.kind == nkStmtListExpr: let (st, res) = exprToStmtList(input) output.add(st) - output.add(ctx.newEnvVarAsgn(sym, res)) + output.add(ctx.newTempVarAsgn(sym, res)) else: - output.add(ctx.newEnvVarAsgn(sym, input)) + output.add(ctx.newTempVarAsgn(sym, input)) proc convertExprBodyToAsgn(ctx: Ctx, exprBody: PNode, res: PSym): PNode = result = newNodeI(nkStmtList, exprBody.info) @@ -457,6 +450,12 @@ proc boolLit(g: ModuleGraph; info: TLineInfo; value: bool): PNode = result = newIntLit(g, info, ord value) result.typ = getSysType(g, info, tyBool) +proc captureVar(c: var Ctx, s: PSym) = + if c.varStates.getOrDefault(s.itemId) != localRequiresLifting: + c.varStates[s.itemId] = localRequiresLifting # Mark this variable for lifting + let e = getEnvParam(c.fn) + discard addField(e.typ.elementType, s, c.g.cache, c.idgen) + proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = result = n case n.kind @@ -513,9 +512,9 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = var tmp: PSym = nil let isExpr = not isEmptyType(n.typ) if isExpr: - tmp = ctx.newTempVar(n.typ) result = newNodeI(nkStmtListExpr, n.info) result.typ = n.typ + tmp = ctx.newTempVar(n.typ, result) else: result = newNodeI(nkStmtList, n.info) @@ -566,7 +565,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = else: internalError(ctx.g.config, "lowerStmtListExpr(nkIf): " & $branch.kind) - if isExpr: result.add(ctx.newEnvVarAccess(tmp)) + if isExpr: result.add(ctx.newTempVarAccess(tmp)) of nkTryStmt, nkHiddenTryStmt: var ns = false @@ -580,7 +579,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = if isExpr: result = newNodeI(nkStmtListExpr, n.info) result.typ = n.typ - let tmp = ctx.newTempVar(n.typ) + let tmp = ctx.newTempVar(n.typ, result) n[0] = ctx.convertExprBodyToAsgn(n[0], tmp) for i in 1..<n.len: @@ -596,7 +595,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = else: internalError(ctx.g.config, "lowerStmtListExpr(nkTryStmt): " & $branch.kind) result.add(n) - result.add(ctx.newEnvVarAccess(tmp)) + result.add(ctx.newTempVarAccess(tmp)) of nkCaseStmt: var ns = false @@ -609,9 +608,9 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = let isExpr = not isEmptyType(n.typ) if isExpr: - let tmp = ctx.newTempVar(n.typ) result = newNodeI(nkStmtListExpr, n.info) result.typ = n.typ + let tmp = ctx.newTempVar(n.typ, result) if n[0].kind == nkStmtListExpr: let (st, ex) = exprToStmtList(n[0]) @@ -628,7 +627,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = else: internalError(ctx.g.config, "lowerStmtListExpr(nkCaseStmt): " & $branch.kind) result.add(n) - result.add(ctx.newEnvVarAccess(tmp)) + result.add(ctx.newTempVarAccess(tmp)) elif n[0].kind == nkStmtListExpr: result = newNodeI(nkStmtList, n.info) let (st, ex) = exprToStmtList(n[0]) @@ -658,10 +657,10 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = result.add(st) cond = ex - let tmp = ctx.newTempVar(cond.typ) - result.add(ctx.newEnvVarAsgn(tmp, cond)) + let tmp = ctx.newTempVar(cond.typ, result, cond) + # result.add(ctx.newTempVarAsgn(tmp, cond)) - var check = ctx.newEnvVarAccess(tmp) + var check = ctx.newTempVarAccess(tmp) if n[0].sym.magic == mOr: check = ctx.g.newNotCall(check) @@ -671,12 +670,12 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = let (st, ex) = exprToStmtList(cond) ifBody.add(st) cond = ex - ifBody.add(ctx.newEnvVarAsgn(tmp, cond)) + ifBody.add(ctx.newTempVarAsgn(tmp, cond)) let ifBranch = newTree(nkElifBranch, check, ifBody) let ifNode = newTree(nkIfStmt, ifBranch) result.add(ifNode) - result.add(ctx.newEnvVarAccess(tmp)) + result.add(ctx.newTempVarAccess(tmp)) else: for i in 0..<n.len: if n[i].kind == nkStmtListExpr: @@ -685,9 +684,9 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = n[i] = ex if n[i].kind in nkCallKinds: # XXX: This should better be some sort of side effect tracking - let tmp = ctx.newTempVar(n[i].typ) - result.add(ctx.newEnvVarAsgn(tmp, n[i])) - n[i] = ctx.newEnvVarAccess(tmp) + let tmp = ctx.newTempVar(n[i].typ, result, n[i]) + # result.add(ctx.newTempVarAsgn(tmp, n[i])) + n[i] = ctx.newTempVarAccess(tmp) result.add(n) @@ -703,6 +702,12 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = let (st, ex) = exprToStmtList(c[^1]) result.add(st) c[^1] = ex + for i in 0 .. c.len - 3: + if c[i].kind == nkSym: + let s = c[i].sym + if sfForceLift in s.flags: + ctx.captureVar(s) + result.add(varSect) of nkDiscardStmt, nkReturnStmt, nkRaiseStmt: @@ -1279,13 +1284,6 @@ proc wrapIntoStateLoop(ctx: var Ctx, n: PNode): PNode = result.info = n.info let localVars = newNodeI(nkStmtList, n.info) - if not ctx.stateVarSym.isNil: - let varSect = newNodeI(nkVarSection, n.info) - addVar(varSect, newSymNode(ctx.stateVarSym)) - localVars.add(varSect) - - if not ctx.tempVars.isNil: - localVars.add(ctx.tempVars) let blockStmt = newNodeI(nkBlockStmt, n.info) blockStmt.add(newSymNode(ctx.stateLoopLabel)) @@ -1433,15 +1431,67 @@ proc preprocess(c: var PreprocessContext; n: PNode): PNode = for i in 0 ..< n.len: result[i] = preprocess(c, n[i]) +proc detectCapturedVars(c: var Ctx, n: PNode, stateIdx: int) = + case n.kind + of nkSym: + let s = n.sym + if s.kind in {skResult, skVar, skLet, skForVar, skTemp} and sfGlobal notin s.flags and s.owner == c.fn: + let vs = c.varStates.getOrDefault(s.itemId, localNotSeen) + if vs == localNotSeen: # First seing this variable + c.varStates[s.itemId] = stateIdx + elif vs == localRequiresLifting: + discard # Sym already marked + elif vs != stateIdx: + c.captureVar(s) + of nkReturnStmt: + if n[0].kind in {nkAsgn, nkFastAsgn, nkSinkAsgn}: + # we have a `result = result` expression produced by the closure + # transform, let's not touch the LHS in order to make the lifting pass + # correct when `result` is lifted + detectCapturedVars(c, n[0][1], stateIdx) + else: + detectCapturedVars(c, n[0], stateIdx) + else: + for i in 0 ..< n.safeLen: + detectCapturedVars(c, n[i], stateIdx) + +proc detectCapturedVars(c: var Ctx) = + for i, s in c.states: + detectCapturedVars(c, s.body, i) + +proc liftLocals(c: var Ctx, n: PNode): PNode = + result = n + case n.kind + of nkSym: + let s = n.sym + if c.varStates.getOrDefault(s.itemId) == localRequiresLifting: + # lift + let e = getEnvParam(c.fn) + let field = getFieldFromObj(e.typ.elementType, s) + assert(field != nil) + result = rawIndirectAccess(newSymNode(e), field, n.info) + # elif c.varStates.getOrDefault(s.itemId, localNotSeen) != localNotSeen: + # echo "Not lifting ", s.name.s + + of nkReturnStmt: + if n[0].kind in {nkAsgn, nkFastAsgn, nkSinkAsgn}: + # we have a `result = result` expression produced by the closure + # transform, let's not touch the LHS in order to make the lifting pass + # correct when `result` is lifted + n[0][1] = liftLocals(c, n[0][1]) + else: + n[0] = liftLocals(c, n[0]) + else: + for i in 0 ..< n.safeLen: + n[i] = liftLocals(c, n[i]) + proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n: PNode): PNode = var ctx = Ctx(g: g, fn: fn, idgen: idgen) - if getEnvParam(fn).isNil: - # Lambda lifting was not done yet. Use temporary :state sym, which will - # be handled specially by lambda lifting. Local temp vars (if needed) - # should follow the same logic. - ctx.stateVarSym = newSym(skVar, getIdent(ctx.g.cache, ":state"), idgen, fn, fn.info) - ctx.stateVarSym.typ = g.createClosureIterStateType(fn, idgen) + # The transformation should always happen after at least partial lambdalifting + # is performed, so that the closure iter environment is always created upfront. + doAssert(getEnvParam(fn) != nil, "Env param not created before iter transformation") + ctx.stateLoopLabel = newSym(skLabel, getIdent(ctx.g.cache, ":stateLoop"), idgen, fn, fn.info) var pc = PreprocessContext(finallys: @[], config: g.config, idgen: idgen) var n = preprocess(pc, n.toStmtList) @@ -1466,6 +1516,10 @@ proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n: let caseDispatcher = newTreeI(nkCaseStmt, n.info, ctx.newStateAccess()) + # Lamdalifting will not touch our locals, it is our responsibility to lift those that + # need it. + detectCapturedVars(ctx) + for s in ctx.states: let body = ctx.transformStateAssignments(s.body) caseDispatcher.add newTreeI(nkOfBranch, body.info, g.newIntLit(body.info, s.label), body) @@ -1473,6 +1527,7 @@ proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n: caseDispatcher.add newTreeI(nkElse, n.info, newTreeI(nkReturnStmt, n.info, g.emptyNode)) result = wrapIntoStateLoop(ctx, caseDispatcher) + result = liftLocals(ctx, result) when false: echo "TRANSFORM TO STATES: " @@ -1481,3 +1536,5 @@ proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n: echo "exception table:" for i, e in ctx.exceptionTable: echo i, " -> ", e + + echo "ENV: ", renderTree(getEnvParam(fn).typ.elementType.n) |