summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--compiler/lambdalifting.nim16
-rw-r--r--compiler/transf.nim49
-rw-r--r--tests/iter/tkeep_state_between_yield.nim36
3 files changed, 80 insertions, 21 deletions
diff --git a/compiler/lambdalifting.nim b/compiler/lambdalifting.nim
index f75be0ed1..8320292ea 100644
--- a/compiler/lambdalifting.nim
+++ b/compiler/lambdalifting.nim
@@ -221,7 +221,7 @@ proc interestingIterVar(s: PSym): bool {.inline.} =
   # closure iterators quite a bit.
   result = s.kind in {skVar, skLet, skTemp, skForVar} and sfGlobal notin s.flags
 
-template isIterator(owner: PSym): bool =
+template isIterator*(owner: PSym): bool =
   owner.kind == skIterator and owner.typ.callConv == ccClosure
 
 proc liftIterSym*(n: PNode; owner: PSym): PNode =
@@ -243,6 +243,20 @@ proc liftIterSym*(n: PNode; owner: PSym): PNode =
   result.add newCall(getSysSym"internalNew", envAsNode)
   result.add makeClosure(iter, envAsNode, n.info)
 
+proc freshVarForClosureIter*(s, owner: PSym): PNode =
+  let envParam = getHiddenParam(owner)
+  let obj = envParam.typ.lastSon
+  addField(obj, s)
+
+  var access = newSymNode(envParam)
+  assert obj.kind == tyObject
+  let field = getFieldFromObj(obj, s)
+  if field != nil:
+    result = rawIndirectAccess(access, field, s.info)
+  else:
+    localError(s.info, "internal error: cannot generate fresh variable")
+    result = access
+
 # ------------------ new stuff -------------------------------------------
 
 proc markAsClosure(owner: PSym; n: PNode) =
diff --git a/compiler/transf.nim b/compiler/transf.nim
index dd8dd1519..296ea0a0d 100644
--- a/compiler/transf.nim
+++ b/compiler/transf.nim
@@ -93,10 +93,15 @@ proc getCurrOwner(c: PTransf): PSym =
   if c.transCon != nil: result = c.transCon.owner
   else: result = c.module
 
-proc newTemp(c: PTransf, typ: PType, info: TLineInfo): PSym =
-  result = newSym(skTemp, getIdent(genPrefix), getCurrOwner(c), info)
-  result.typ = skipTypes(typ, {tyGenericInst})
-  incl(result.flags, sfFromGeneric)
+proc newTemp(c: PTransf, typ: PType, info: TLineInfo): PNode =
+  let r = newSym(skTemp, getIdent(genPrefix), getCurrOwner(c), info)
+  r.typ = skipTypes(typ, {tyGenericInst})
+  incl(r.flags, sfFromGeneric)
+  let owner = getCurrOwner(c)
+  if owner.isIterator and not c.tooEarly:
+    result = freshVarForClosureIter(r, owner)
+  else:
+    result = newSymNode(r)
 
 proc transform(c: PTransf, n: PNode): PTransNode
 
@@ -141,6 +146,16 @@ proc transformSymAux(c: PTransf, n: PNode): PNode =
 proc transformSym(c: PTransf, n: PNode): PTransNode =
   result = PTransNode(transformSymAux(c, n))
 
+proc freshVar(c: PTransf; v: PSym): PNode =
+  let owner = getCurrOwner(c)
+  if owner.isIterator and not c.tooEarly:
+    result = freshVarForClosureIter(v, owner)
+  else:
+    var newVar = copySym(v)
+    incl(newVar.flags, sfFromGeneric)
+    newVar.owner = owner
+    result = newSymNode(newVar)
+
 proc transformVarSection(c: PTransf, v: PNode): PTransNode =
   result = newTransNode(v)
   for i in countup(0, sonsLen(v)-1):
@@ -150,20 +165,16 @@ proc transformVarSection(c: PTransf, v: PNode): PTransNode =
     elif it.kind == nkIdentDefs:
       if it.sons[0].kind == nkSym:
         internalAssert(it.len == 3)
-        var newVar = copySym(it.sons[0].sym)
-        incl(newVar.flags, sfFromGeneric)
-        # fixes a strange bug for rodgen:
-        #include(it.sons[0].sym.flags, sfFromGeneric);
-        newVar.owner = getCurrOwner(c)
-        idNodeTablePut(c.transCon.mapping, it.sons[0].sym, newSymNode(newVar))
+        let x = freshVar(c, it.sons[0].sym)
+        idNodeTablePut(c.transCon.mapping, it.sons[0].sym, x)
         var defs = newTransNode(nkIdentDefs, it.info, 3)
         if importantComments():
           # keep documentation information:
           PNode(defs).comment = it.comment
-        defs[0] = newSymNode(newVar).PTransNode
+        defs[0] = x.PTransNode
         defs[1] = it.sons[1].PTransNode
         defs[2] = transform(c, it.sons[2])
-        newVar.ast = defs[2].PNode
+        if x.kind == nkSym: x.sym.ast = defs[2].PNode
         result[i] = defs
       else:
         # has been transformed into 'param.x' for closure iterators, so just
@@ -175,11 +186,9 @@ proc transformVarSection(c: PTransf, v: PNode): PTransNode =
       var L = sonsLen(it)
       var defs = newTransNode(it.kind, it.info, L)
       for j in countup(0, L-3):
-        var newVar = copySym(it.sons[j].sym)
-        incl(newVar.flags, sfFromGeneric)
-        newVar.owner = getCurrOwner(c)
-        idNodeTablePut(c.transCon.mapping, it.sons[j].sym, newSymNode(newVar))
-        defs[j] = newSymNode(newVar).PTransNode
+        let x = freshVar(c, it.sons[j].sym)
+        idNodeTablePut(c.transCon.mapping, it.sons[j].sym, x)
+        defs[j] = x.PTransNode
       assert(it.sons[L-2].kind == nkEmpty)
       defs[L-2] = ast.emptyNode.PTransNode
       defs[L-1] = transform(c, it.sons[L-1])
@@ -549,9 +558,9 @@ proc transformFor(c: PTransf, n: PNode): PTransNode =
     of paFastAsgn:
       # generate a temporary and produce an assignment statement:
       var temp = newTemp(c, formal.typ, formal.info)
-      addVar(v, newSymNode(temp))
-      add(stmtList, newAsgnStmt(c, newSymNode(temp), arg.PTransNode))
-      idNodeTablePut(newC.mapping, formal, newSymNode(temp))
+      addVar(v, temp)
+      add(stmtList, newAsgnStmt(c, temp, arg.PTransNode))
+      idNodeTablePut(newC.mapping, formal, temp)
     of paVarAsgn:
       assert(skipTypes(formal.typ, abstractInst).kind == tyVar)
       idNodeTablePut(newC.mapping, formal, arg)
diff --git a/tests/iter/tkeep_state_between_yield.nim b/tests/iter/tkeep_state_between_yield.nim
new file mode 100644
index 000000000..f4f0ee363
--- /dev/null
+++ b/tests/iter/tkeep_state_between_yield.nim
@@ -0,0 +1,36 @@
+discard """
+  output: '''@[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 18, 20, 21, 24, 27, 30, 36, 40, 42]
+1002'''
+"""
+
+import strutils
+
+proc slice[T](iter: iterator(): T {.closure.}, sl: auto): seq[T] =
+  var res: seq[int64] = @[]
+  var i = 0
+  for n in iter():
+    if i > sl.b:
+      break
+    if i >= sl.a:
+      res.add(n)
+    inc i
+  res
+
+iterator harshad(): int64 {.closure.} =
+  for n in 1 .. < int64.high:
+    var sum = 0
+    for ch in string($n):
+      sum += parseInt("" & ch)
+    if n mod sum == 0:
+      yield n
+
+echo harshad.slice 0 .. <20
+
+for n in harshad():
+  if n > 1000:
+    echo n
+    break
+
+
+# bug #3499 last snippet fixed
+# bug 705  last snippet fixed