diff options
Diffstat (limited to 'compiler/lambdalifting.nim')
-rw-r--r-- | compiler/lambdalifting.nim | 134 |
1 files changed, 86 insertions, 48 deletions
diff --git a/compiler/lambdalifting.nim b/compiler/lambdalifting.nim index ed92fefb4..2189a1d67 100644 --- a/compiler/lambdalifting.nim +++ b/compiler/lambdalifting.nim @@ -116,9 +116,9 @@ type TDep = tuple[e: PEnv, field: PSym] TEnv {.final.} = object of TObject attachedNode: PNode - closure: PSym # if != nil it is a used environment + closure: PSym # if != nil it is a used environment capturedVars: seq[PSym] # captured variables in this environment - deps: seq[TDep] # dependencies + deps: seq[TDep] # dependencies up: PEnv tup: PType @@ -149,7 +149,19 @@ proc newInnerContext(fn: PSym): PInnerContext = new(result) result.fn = fn initIdNodeTable(result.localsToAccess) - + +proc getStateType(iter: PSym): PType = + var n = newNodeI(nkRange, iter.info) + addSon(n, newIntNode(nkIntLit, -1)) + addSon(n, newIntNode(nkIntLit, 0)) + result = newType(tyRange, iter) + result.n = n + rawAddSon(result, getSysType(tyInt)) + +proc createStateField(iter: PSym): PSym = + result = newSym(skField, getIdent(":state"), iter, iter.info) + result.typ = getStateType(iter) + proc newEnv(outerProc: PSym, up: PEnv, n: PNode): PEnv = new(result) result.deps = @[] @@ -170,6 +182,9 @@ proc addField(tup: PType, s: PSym) = proc addCapturedVar(e: PEnv, v: PSym) = for x in e.capturedVars: if x == v: return + # XXX meh, just add the state field for every closure for now, it's too + # hard to figure out if it comes from a closure iterator: + if e.tup.len == 0: addField(e.tup, createStateField(v.owner)) e.capturedVars.add(v) addField(e.tup, v) @@ -189,6 +204,7 @@ proc indirectAccess(a: PNode, b: PSym, info: TLineInfo): PNode = # returns a[].b as a node var deref = newNodeI(nkHiddenDeref, info) deref.typ = a.typ.sons[0] + assert deref.typ.kind == tyTuple let field = getSymFromList(deref.typ.n, b.name) assert field != nil, b.name.s addSon(deref, a) @@ -220,18 +236,30 @@ proc getHiddenParam(routine: PSym): PSym = assert hidden.kind == nkSym result = hidden.sym +proc getEnvParam(routine: PSym): PSym = + let params = routine.ast.sons[paramsPos] + let hidden = lastSon(params) + if hidden.kind == nkSym and hidden.sym.name.s == paramName: + result = hidden.sym + proc isInnerProc(s, outerProc: PSym): bool {.inline.} = - result = s.kind in {skProc, skMethod, skConverter} and + result = (s.kind in {skProc, skMethod, skConverter} or + s.kind == skIterator and s.typ.callConv == ccClosure) and s.skipGenericOwner == outerProc #s.typ.callConv == ccClosure proc addClosureParam(i: PInnerContext, e: PEnv) = - var cp = newSym(skParam, getIdent(paramName), i.fn, i.fn.info) - incl(cp.flags, sfFromGeneric) - cp.typ = newType(tyRef, i.fn) - rawAddSon(cp.typ, e.tup) + var cp = getEnvParam(i.fn) + if cp == nil: + cp = newSym(skParam, getIdent(paramName), i.fn, i.fn.info) + incl(cp.flags, sfFromGeneric) + cp.typ = newType(tyRef, i.fn) + rawAddSon(cp.typ, e.tup) + addHiddenParam(i.fn, cp) + else: + e.tup = cp.typ.sons[0] + assert e.tup.kind == tyTuple i.closureParam = cp - addHiddenParam(i.fn, i.closureParam) #echo "closure param added for ", i.fn.name.s, " ", i.fn.id proc dummyClosureParam(o: POuterContext, i: PInnerContext) = @@ -344,6 +372,7 @@ proc transformOuterConv(n: PNode): PNode = proc makeClosure(prc, env: PSym, info: TLineInfo): PNode = result = newNodeIT(nkClosure, info, prc.typ) result.add(newSymNode(prc)) + if prc.kind == skIterator: incl(prc.flags, sfClosureCreated) if env == nil: result.add(newNodeIT(nkNilLit, info, getSysType(tyNil))) else: @@ -366,10 +395,10 @@ proc transformInnerProc(o: POuterContext, i: PInnerContext, n: PNode): PNode = else: # captured symbol? result = idNodeTableGet(i.localsToAccess, n.sym) - of nkLambdaKinds: - result = transformInnerProc(o, i, n.sons[namePos]) - of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef, - nkIteratorDef: + of nkLambdaKinds, nkIteratorDef: + if n.typ != nil: + result = transformInnerProc(o, i, n.sons[namePos]) + of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef: # don't recurse here: discard else: @@ -400,8 +429,9 @@ proc searchForInnerProcs(o: POuterContext, n: PNode) = if inner.closureParam != nil: let ti = transformInnerProc(o, inner, body) if ti != nil: n.sym.ast.sons[bodyPos] = ti - of nkLambdaKinds: - searchForInnerProcs(o, n.sons[namePos]) + of nkLambdaKinds, nkIteratorDef: + if n.typ != nil: + searchForInnerProcs(o, n.sons[namePos]) of nkWhileStmt, nkForStmt, nkParForStmt, nkBlockStmt: # some nodes open a new scope, so they are candidates for the insertion # of closure creation; however for simplicity we merge closures between @@ -437,8 +467,7 @@ proc searchForInnerProcs(o: POuterContext, n: PNode) = searchForInnerProcs(o, it.sons[L-1]) else: internalError(it.info, "transformOuter") - of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef, - nkIteratorDef: + of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef: # don't recurse here: # XXX recurse here and setup 'up' pointers discard @@ -535,10 +564,10 @@ proc transformOuterProc(o: POuterContext, n: PNode): PNode = assert result != nil, "cannot find: " & local.name.s # else it is captured by copy and this means that 'outer' should continue # to access the local as a local. - of nkLambdaKinds: - result = transformOuterProc(o, n.sons[namePos]) - of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef, - nkIteratorDef: + of nkLambdaKinds, nkIteratorDef: + if n.typ != nil: + result = transformOuterProc(o, n.sons[namePos]) + of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef: # don't recurse here: discard of nkHiddenStdConv, nkHiddenSubConv, nkConv: @@ -607,11 +636,14 @@ type tup: PType proc newIterResult(iter: PSym): PSym = - result = iter.ast.sons[resultPos].sym - when false: + if resultPos < iter.ast.len: + result = iter.ast.sons[resultPos].sym + else: + # XXX a bit hacky: result = newSym(skResult, getIdent":result", iter, iter.info) result.typ = iter.typ.sons[0] incl(result.flags, sfUsed) + iter.ast.add newSymNode(result) proc interestingIterVar(s: PSym): bool {.inline.} = result = s.kind in {skVar, skLet, skTemp, skForVar} and sfGlobal notin s.flags @@ -663,36 +695,40 @@ proc transfIterBody(c: var TIterContext, n: PNode): PNode = let x = transfIterBody(c, n.sons[i]) if x != nil: n.sons[i] = x -proc getStateType(iter: PSym): PType = - var n = newNodeI(nkRange, iter.info) - addSon(n, newIntNode(nkIntLit, -1)) - addSon(n, newIntNode(nkIntLit, 0)) - result = newType(tyRange, iter) - result.n = n - rawAddSon(result, getSysType(tyInt)) - -proc liftIterator*(iter: PSym, body: PNode): PNode = - var c: TIterContext +proc initIterContext(c: var TIterContext, iter: PSym) = c.iter = iter c.capturedVars = initIntSet() - c.tup = newType(tyTuple, iter) - c.tup.n = newNodeI(nkRecList, iter.info) + var cp = getEnvParam(iter) + if cp == nil: + c.tup = newType(tyTuple, iter) + c.tup.n = newNodeI(nkRecList, iter.info) - var cp = newSym(skParam, getIdent(paramName), iter, iter.info) - incl(cp.flags, sfFromGeneric) - cp.typ = newType(tyRef, iter) - rawAddSon(cp.typ, c.tup) - c.closureParam = cp - addHiddenParam(iter, cp) + cp = newSym(skParam, getIdent(paramName), iter, iter.info) + incl(cp.flags, sfFromGeneric) + cp.typ = newType(tyRef, iter) + rawAddSon(cp.typ, c.tup) + addHiddenParam(iter, cp) - c.state = newSym(skField, getIdent(":state"), iter, iter.info) - c.state.typ = getStateType(iter) - addField(c.tup, c.state) + c.state = createStateField(iter) + addField(c.tup, c.state) + else: + c.tup = cp.typ.sons[0] + assert c.tup.kind == tyTuple + if c.tup.len > 0: + c.state = c.tup.n[0].sym + else: + c.state = createStateField(iter) + addField(c.tup, c.state) + c.closureParam = cp if iter.typ.sons[0] != nil: c.resultSym = newIterResult(iter) - iter.ast.add(newSymNode(c.resultSym)) + #iter.ast.add(newSymNode(c.resultSym)) + +proc liftIterator*(iter: PSym, body: PNode): PNode = + var c: TIterContext + initIterContext c, iter result = newNodeI(nkStmtList, iter.info) var gs = newNodeI(nkGotoState, iter.info) @@ -716,12 +752,14 @@ proc liftIterator*(iter: PSym, body: PNode): PNode = proc liftIterSym*(n: PNode): PNode = # transforms (iter) to (let env = newClosure[iter](); (iter, env)) - result = newNodeIT(nkStmtListExpr, n.info, n.typ) let iter = n.sym assert iter.kind == skIterator + if sfClosureCreated in iter.flags: return n + + result = newNodeIT(nkStmtListExpr, n.info, n.typ) + var env = copySym(getHiddenParam(iter)) env.kind = skLet - var v = newNodeI(nkVarSection, n.info) addVar(v, newSymNode(env)) result.add(v) @@ -766,7 +804,7 @@ proc liftForLoop*(body: PNode): PNode = # static binding? var env: PSym if call[0].kind == nkSym and call[0].sym.kind == skIterator: - # createClose() + # createClosure() let iter = call[0].sym assert iter.kind == skIterator env = copySym(getHiddenParam(iter)) |