diff options
author | Araq <rumpf_a@web.de> | 2014-01-23 01:41:26 +0100 |
---|---|---|
committer | Araq <rumpf_a@web.de> | 2014-01-23 01:41:26 +0100 |
commit | 3f87326247b142df4eff99a92c6529b33bb79b81 (patch) | |
tree | 632dc70d2d73e51b97fd9830a9a7ff42014df412 /compiler | |
parent | 37229df7fc044fe108d2f4d88f127141cabeb6a6 (diff) | |
download | Nim-3f87326247b142df4eff99a92c6529b33bb79b81.tar.gz |
closure iterators almost work
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/lambdalifting.nim | 368 | ||||
-rw-r--r-- | compiler/transf.nim | 12 |
2 files changed, 189 insertions, 191 deletions
diff --git a/compiler/lambdalifting.nim b/compiler/lambdalifting.nim index 2189a1d67..352b40693 100644 --- a/compiler/lambdalifting.nim +++ b/compiler/lambdalifting.nim @@ -130,12 +130,99 @@ type TOuterContext {.final.} = object fn: PSym # may also be a module! currentEnv: PEnv + isIter: bool # first class iterator? capturedVars, processed: TIntSet localsToEnv: TIdTable # PSym->PEnv mapping localsToAccess: TIdNodeTable lambdasToEnv: TIdTable # PSym->PEnv mapping up: POuterContext + closureParam, state, resultSym: PSym # only if isIter + tup: PType # only if isIter + + +proc getStateType(iter: PSym): PType = + var n = newNodeI(nkRange, iter.info) + addSon(n, newIntNode(nkIntLit, -1)) + addSon(n, newIntNode(nkIntLit, 0)) + result = newType(tyRange, iter) + result.n = n + rawAddSon(result, getSysType(tyInt)) + +proc createStateField(iter: PSym): PSym = + result = newSym(skField, getIdent(":state"), iter, iter.info) + result.typ = getStateType(iter) + +proc newIterResult(iter: PSym): PSym = + if resultPos < iter.ast.len: + result = iter.ast.sons[resultPos].sym + else: + # XXX a bit hacky: + result = newSym(skResult, getIdent":result", iter, iter.info) + result.typ = iter.typ.sons[0] + incl(result.flags, sfUsed) + iter.ast.add newSymNode(result) + +proc addHiddenParam(routine: PSym, param: PSym) = + var params = routine.ast.sons[paramsPos] + # -1 is correct here as param.position is 0 based but we have at position 0 + # some nkEffect node: + param.position = params.len-1 + addSon(params, newSymNode(param)) + incl(routine.typ.flags, tfCapturesEnv) + #echo "produced environment: ", param.id, " for ", routine.name.s + +proc getHiddenParam(routine: PSym): PSym = + let params = routine.ast.sons[paramsPos] + let hidden = lastSon(params) + assert hidden.kind == nkSym + result = hidden.sym + +proc getEnvParam(routine: PSym): PSym = + let params = routine.ast.sons[paramsPos] + let hidden = lastSon(params) + if hidden.kind == nkSym and hidden.sym.name.s == paramName: + result = hidden.sym + +proc addField(tup: PType, s: PSym) = + var field = newSym(skField, s.name, s.owner, s.info) + let t = skipIntLit(s.typ) + field.typ = t + field.position = sonsLen(tup) + addSon(tup.n, newSymNode(field)) + rawAddSon(tup, t) + +proc initIterContext(c: POuterContext, iter: PSym) = + c.fn = iter + c.capturedVars = initIntSet() + + var cp = getEnvParam(iter) + if cp == nil: + c.tup = newType(tyTuple, iter) + c.tup.n = newNodeI(nkRecList, iter.info) + + cp = newSym(skParam, getIdent(paramName), iter, iter.info) + incl(cp.flags, sfFromGeneric) + cp.typ = newType(tyRef, iter) + rawAddSon(cp.typ, c.tup) + addHiddenParam(iter, cp) + + c.state = createStateField(iter) + addField(c.tup, c.state) + else: + c.tup = cp.typ.sons[0] + assert c.tup.kind == tyTuple + if c.tup.len > 0: + c.state = c.tup.n[0].sym + else: + c.state = createStateField(iter) + addField(c.tup, c.state) + + c.closureParam = cp + if iter.typ.sons[0] != nil: + c.resultSym = newIterResult(iter) + #iter.ast.add(newSymNode(c.resultSym)) + proc newOuterContext(fn: PSym, up: POuterContext = nil): POuterContext = new(result) result.fn = fn @@ -144,24 +231,14 @@ proc newOuterContext(fn: PSym, up: POuterContext = nil): POuterContext = initIdNodeTable(result.localsToAccess) initIdTable(result.localsToEnv) initIdTable(result.lambdasToEnv) + result.isIter = fn.kind == skIterator and fn.typ.callConv == ccClosure + if result.isIter: initIterContext(result, fn) proc newInnerContext(fn: PSym): PInnerContext = new(result) result.fn = fn initIdNodeTable(result.localsToAccess) -proc getStateType(iter: PSym): PType = - var n = newNodeI(nkRange, iter.info) - addSon(n, newIntNode(nkIntLit, -1)) - addSon(n, newIntNode(nkIntLit, 0)) - result = newType(tyRange, iter) - result.n = n - rawAddSon(result, getSysType(tyInt)) - -proc createStateField(iter: PSym): PSym = - result = newSym(skField, getIdent(":state"), iter, iter.info) - result.typ = getStateType(iter) - proc newEnv(outerProc: PSym, up: PEnv, n: PNode): PEnv = new(result) result.deps = @[] @@ -171,14 +248,6 @@ proc newEnv(outerProc: PSym, up: PEnv, n: PNode): PEnv = result.up = up result.attachedNode = n -proc addField(tup: PType, s: PSym) = - var field = newSym(skField, s.name, s.owner, s.info) - let t = skipIntLit(s.typ) - field.typ = t - field.position = sonsLen(tup) - addSon(tup.n, newSymNode(field)) - rawAddSon(tup, t) - proc addCapturedVar(e: PEnv, v: PSym) = for x in e.capturedVars: if x == v: return @@ -221,27 +290,6 @@ proc newCall(a, b: PSym): PNode = result.add newSymNode(a) result.add newSymNode(b) -proc addHiddenParam(routine: PSym, param: PSym) = - var params = routine.ast.sons[paramsPos] - # -1 is correct here as param.position is 0 based but we have at position 0 - # some nkEffect node: - param.position = params.len-1 - addSon(params, newSymNode(param)) - incl(routine.typ.flags, tfCapturesEnv) - #echo "produced environment: ", param.id, " for ", routine.name.s - -proc getHiddenParam(routine: PSym): PSym = - let params = routine.ast.sons[paramsPos] - let hidden = lastSon(params) - assert hidden.kind == nkSym - result = hidden.sym - -proc getEnvParam(routine: PSym): PSym = - let params = routine.ast.sons[paramsPos] - let hidden = lastSon(params) - if hidden.kind == nkSym and hidden.sym.name.s == paramName: - result = hidden.sym - proc isInnerProc(s, outerProc: PSym): bool {.inline.} = result = (s.kind in {skProc, skMethod, skConverter} or s.kind == skIterator and s.typ.callConv == ccClosure) and @@ -334,7 +382,9 @@ proc gatherVars(o: POuterContext, i: PInnerContext, n: PNode) = var s = n.sym if interestingVar(s) and i.fn.id != s.owner.id: captureVar(o, i, s, n.info) - elif isInnerProc(s, o.fn) and tfCapturesEnv in s.typ.flags and s != i.fn: + elif s.kind in {skProc, skMethod, skConverter} and + s.skipGenericOwner == o.fn and + tfCapturesEnv in s.typ.flags and s != i.fn: # call to some other inner proc; we need to track the dependencies for # this: let env = PEnv(idTableGet(o.lambdasToEnv, i.fn)) @@ -342,7 +392,7 @@ proc gatherVars(o: POuterContext, i: PInnerContext, n: PNode) = if o.currentEnv != env: discard addDep(o.currentEnv, env, i.fn) internalError(n.info, "too complex environment handling required") - of nkEmpty..pred(nkSym), succ(nkSym)..nkNilLit: discard + of nkEmpty..pred(nkSym), succ(nkSym)..nkNilLit, nkClosure: discard else: for k in countup(0, sonsLen(n) - 1): gatherVars(o, i, n.sons[k]) @@ -398,7 +448,8 @@ proc transformInnerProc(o: POuterContext, i: PInnerContext, n: PNode): PNode = of nkLambdaKinds, nkIteratorDef: if n.typ != nil: result = transformInnerProc(o, i, n.sons[namePos]) - of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef: + of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef, + nkClosure: # don't recurse here: discard else: @@ -467,7 +518,8 @@ proc searchForInnerProcs(o: POuterContext, n: PNode) = searchForInnerProcs(o, it.sons[L-1]) else: internalError(it.info, "transformOuter") - of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef: + of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef, + nkClosure: # don't recurse here: # XXX recurse here and setup 'up' pointers discard @@ -526,12 +578,61 @@ proc generateClosureCreation(o: POuterContext, scope: PEnv): PNode = result.add(newAsgnStmt(indirectAccess(env, field, env.info), newSymNode(getClosureVar(o, e)), env.info)) +proc interestingIterVar(s: PSym): bool {.inline.} = + result = s.kind in {skVar, skLet, skTemp, skForVar} and sfGlobal notin s.flags + +proc transformOuterProc(o: POuterContext, n: PNode): PNode + +proc transformYield(c: POuterContext, n: PNode): PNode = + inc c.state.typ.n.sons[1].intVal + let stateNo = c.state.typ.n.sons[1].intVal + + var stateAsgnStmt = newNodeI(nkAsgn, n.info) + stateAsgnStmt.add(indirectAccess(newSymNode(c.closureParam),c.state,n.info)) + stateAsgnStmt.add(newIntTypeNode(nkIntLit, stateNo, getSysType(tyInt))) + + var retStmt = newNodeI(nkReturnStmt, n.info) + if n.sons[0].kind != nkEmpty: + var a = newNodeI(nkAsgn, n.sons[0].info) + var retVal = transformOuterProc(c, n.sons[0]) + addSon(a, newSymNode(c.resultSym)) + addSon(a, if retVal.isNil: n.sons[0] else: retVal) + retStmt.add(a) + else: + retStmt.add(emptyNode) + + var stateLabelStmt = newNodeI(nkState, n.info) + stateLabelStmt.add(newIntTypeNode(nkIntLit, stateNo, getSysType(tyInt))) + + result = newNodeI(nkStmtList, n.info) + result.add(stateAsgnStmt) + result.add(retStmt) + result.add(stateLabelStmt) + +proc transformReturn(c: POuterContext, n: PNode): PNode = + result = newNodeI(nkStmtList, n.info) + var stateAsgnStmt = newNodeI(nkAsgn, n.info) + stateAsgnStmt.add(indirectAccess(newSymNode(c.closureParam),c.state,n.info)) + stateAsgnStmt.add(newIntTypeNode(nkIntLit, -1, getSysType(tyInt))) + result.add(stateAsgnStmt) + result.add(n) + +proc outerProcSons(o: POuterContext, n: PNode) = + for i in countup(0, sonsLen(n) - 1): + let x = transformOuterProc(o, n.sons[i]) + if x != nil: n.sons[i] = x + proc transformOuterProc(o: POuterContext, n: PNode): PNode = if n == nil: return nil case n.kind of nkEmpty..pred(nkSym), succ(nkSym)..nkNilLit: discard of nkSym: var local = n.sym + + if o.isIter and interestingIterVar(local) and o.fn.id == local.owner.id: + if not containsOrIncl(o.capturedVars, local.id): addField(o.tup, local) + return indirectAccess(newSymNode(o.closureParam), local, n.info) + var closure = PEnv(idTableGet(o.lambdasToEnv, local)) if closure != nil: # we need to replace the lambda with '(lambda, env)': @@ -567,17 +668,44 @@ proc transformOuterProc(o: POuterContext, n: PNode): PNode = of nkLambdaKinds, nkIteratorDef: if n.typ != nil: result = transformOuterProc(o, n.sons[namePos]) - of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef: + of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef, + nkClosure: # don't recurse here: discard of nkHiddenStdConv, nkHiddenSubConv, nkConv: let x = transformOuterProc(o, n.sons[1]) if x != nil: n.sons[1] = x result = transformOuterConv(n) + of nkYieldStmt: + if o.isIter: result = transformYield(o, n) + else: outerProcSons(o, n) + of nkReturnStmt: + if o.isIter: result = transformReturn(o, n) + else: outerProcSons(o, n) else: - for i in countup(0, sonsLen(n) - 1): - let x = transformOuterProc(o, n.sons[i]) - if x != nil: n.sons[i] = x + outerProcSons(o, n) + +proc liftIterator(c: POuterContext, body: PNode): PNode = + let iter = c.fn + result = newNodeI(nkStmtList, iter.info) + var gs = newNodeI(nkGotoState, iter.info) + gs.add(indirectAccess(newSymNode(c.closureParam), c.state, iter.info)) + result.add(gs) + var state0 = newNodeI(nkState, iter.info) + state0.add(newIntNode(nkIntLit, 0)) + result.add(state0) + + let newBody = transformOuterProc(c, body) + if newBody != nil: + result.add(newBody) + else: + result.add(body) + + var stateAsgnStmt = newNodeI(nkAsgn, iter.info) + stateAsgnStmt.add(indirectAccess(newSymNode(c.closureParam), + c.state,iter.info)) + stateAsgnStmt.add(newIntTypeNode(nkIntLit, -1, getSysType(tyInt))) + result.add(stateAsgnStmt) proc liftLambdas*(fn: PSym, body: PNode): PNode = # XXX gCmd == cmdCompileToJS does not suffice! The compiletime stuff needs @@ -601,8 +729,11 @@ proc liftLambdas*(fn: PSym, body: PNode): PNode = if resultPos < sonsLen(ast) and ast.sons[resultPos].kind == nkSym: idTablePut(o.localsToEnv, ast.sons[resultPos].sym, o.currentEnv) searchForInnerProcs(o, body) - discard transformOuterProc(o, body) - result = ex + if o.isIter: + result = liftIterator(o, ex) + else: + discard transformOuterProc(o, body) + result = ex proc liftLambdasForTopLevel*(module: PSym, body: PNode): PNode = if body.kind == nkEmpty or gCmd == cmdCompileToJS: @@ -617,144 +748,11 @@ proc liftLambdasForTopLevel*(module: PSym, body: PNode): PNode = # ------------------- iterator transformation -------------------------------- -discard """ - iterator chain[S, T](a, b: *S->T, args: *S): T = - for x in a(args): yield x - for x in b(args): yield x - - let c = chain(f, g) - for x in c: echo x - - # translated to: - let c = chain( (f, newClosure(f)), (g, newClosure(g)), newClosure(chain)) -""" - -type - TIterContext {.final, pure.} = object - iter, closureParam, state, resultSym: PSym - capturedVars: TIntSet - tup: PType - -proc newIterResult(iter: PSym): PSym = - if resultPos < iter.ast.len: - result = iter.ast.sons[resultPos].sym - else: - # XXX a bit hacky: - result = newSym(skResult, getIdent":result", iter, iter.info) - result.typ = iter.typ.sons[0] - incl(result.flags, sfUsed) - iter.ast.add newSymNode(result) - -proc interestingIterVar(s: PSym): bool {.inline.} = - result = s.kind in {skVar, skLet, skTemp, skForVar} and sfGlobal notin s.flags - -proc transfIterBody(c: var TIterContext, n: PNode): PNode = - # gather used vars for closure generation - if n == nil: return nil - case n.kind - of nkSym: - var s = n.sym - if interestingIterVar(s) and c.iter.id == s.owner.id: - if not containsOrIncl(c.capturedVars, s.id): addField(c.tup, s) - result = indirectAccess(newSymNode(c.closureParam), s, n.info) - of nkEmpty..pred(nkSym), succ(nkSym)..nkNilLit: discard - of nkYieldStmt: - inc c.state.typ.n.sons[1].intVal - let stateNo = c.state.typ.n.sons[1].intVal - - var stateAsgnStmt = newNodeI(nkAsgn, n.info) - stateAsgnStmt.add(indirectAccess(newSymNode(c.closureParam),c.state,n.info)) - stateAsgnStmt.add(newIntTypeNode(nkIntLit, stateNo, getSysType(tyInt))) - - var retStmt = newNodeI(nkReturnStmt, n.info) - if n.sons[0].kind != nkEmpty: - var a = newNodeI(nkAsgn, n.sons[0].info) - var retVal = transfIterBody(c, n.sons[0]) - addSon(a, newSymNode(c.resultSym)) - addSon(a, if retVal.isNil: n.sons[0] else: retVal) - retStmt.add(a) - else: - retStmt.add(emptyNode) - - var stateLabelStmt = newNodeI(nkState, n.info) - stateLabelStmt.add(newIntTypeNode(nkIntLit, stateNo, getSysType(tyInt))) - - result = newNodeI(nkStmtList, n.info) - result.add(stateAsgnStmt) - result.add(retStmt) - result.add(stateLabelStmt) - of nkReturnStmt: - result = newNodeI(nkStmtList, n.info) - var stateAsgnStmt = newNodeI(nkAsgn, n.info) - stateAsgnStmt.add(indirectAccess(newSymNode(c.closureParam),c.state,n.info)) - stateAsgnStmt.add(newIntTypeNode(nkIntLit, -1, getSysType(tyInt))) - result.add(stateAsgnStmt) - result.add(n) - else: - for i in countup(0, sonsLen(n)-1): - let x = transfIterBody(c, n.sons[i]) - if x != nil: n.sons[i] = x - -proc initIterContext(c: var TIterContext, iter: PSym) = - c.iter = iter - c.capturedVars = initIntSet() - - var cp = getEnvParam(iter) - if cp == nil: - c.tup = newType(tyTuple, iter) - c.tup.n = newNodeI(nkRecList, iter.info) - - cp = newSym(skParam, getIdent(paramName), iter, iter.info) - incl(cp.flags, sfFromGeneric) - cp.typ = newType(tyRef, iter) - rawAddSon(cp.typ, c.tup) - addHiddenParam(iter, cp) - - c.state = createStateField(iter) - addField(c.tup, c.state) - else: - c.tup = cp.typ.sons[0] - assert c.tup.kind == tyTuple - if c.tup.len > 0: - c.state = c.tup.n[0].sym - else: - c.state = createStateField(iter) - addField(c.tup, c.state) - - c.closureParam = cp - if iter.typ.sons[0] != nil: - c.resultSym = newIterResult(iter) - #iter.ast.add(newSymNode(c.resultSym)) - -proc liftIterator*(iter: PSym, body: PNode): PNode = - var c: TIterContext - initIterContext c, iter - - result = newNodeI(nkStmtList, iter.info) - var gs = newNodeI(nkGotoState, iter.info) - gs.add(indirectAccess(newSymNode(c.closureParam), c.state, iter.info)) - result.add(gs) - var state0 = newNodeI(nkState, iter.info) - state0.add(newIntNode(nkIntLit, 0)) - result.add(state0) - - let newBody = transfIterBody(c, body) - if newBody != nil: - result.add(newBody) - else: - result.add(body) - - var stateAsgnStmt = newNodeI(nkAsgn, iter.info) - stateAsgnStmt.add(indirectAccess(newSymNode(c.closureParam), - c.state,iter.info)) - stateAsgnStmt.add(newIntTypeNode(nkIntLit, -1, getSysType(tyInt))) - result.add(stateAsgnStmt) - proc liftIterSym*(n: PNode): PNode = # transforms (iter) to (let env = newClosure[iter](); (iter, env)) let iter = n.sym assert iter.kind == skIterator - if sfClosureCreated in iter.flags: return n + #if sfClosureCreated in iter.flags: return n result = newNodeIT(nkStmtListExpr, n.info, n.typ) diff --git a/compiler/transf.nim b/compiler/transf.nim index 973e8848a..cda611005 100644 --- a/compiler/transf.nim +++ b/compiler/transf.nim @@ -113,8 +113,8 @@ proc newAsgnStmt(c: PTransf, le: PNode, ri: PTransNode): PTransNode = result[1] = ri proc transformSymAux(c: PTransf, n: PNode): PNode = - if n.sym.kind == skIterator and n.sym.typ.callConv == ccClosure: - return liftIterSym(n) + #if n.sym.kind == skIterator and n.sym.typ.callConv == ccClosure: + # return liftIterSym(n) var b: PNode var tc = c.transCon if sfBorrow in n.sym.flags: @@ -636,8 +636,8 @@ proc transform(c: PTransf, n: PNode): PTransNode = s.ast.sons[bodyPos] = n.sons[bodyPos] #n.sons[bodyPos] = liftLambdas(s, n) #if n.kind == nkMethodDef: methodDef(s, false) - if n.kind == nkIteratorDef and n.typ != nil: - return liftIterSym(n.sons[namePos]).PTransNode + #if n.kind == nkIteratorDef and n.typ != nil: + # return liftIterSym(n.sons[namePos]).PTransNode result = PTransNode(n) of nkMacroDef: # XXX no proper closure support yet: @@ -741,8 +741,8 @@ proc transformBody*(module: PSym, n: PNode, prc: PSym): PNode = var c = openTransf(module, "") result = processTransf(c, n, prc) result = liftLambdas(prc, result) - if prc.kind == skIterator and prc.typ.callConv == ccClosure: - result = lambdalifting.liftIterator(prc, result) + #if prc.kind == skIterator and prc.typ.callConv == ccClosure: + # result = lambdalifting.liftIterator(prc, result) incl(result.flags, nfTransf) when useEffectSystem: trackProc(prc, result) |