summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--changelog.md5
-rw-r--r--compiler/lowerings.nim4
-rw-r--r--compiler/semexprs.nim8
-rw-r--r--compiler/semstmts.nim59
-rw-r--r--compiler/transf.nim3
-rw-r--r--lib/system/iterators.nim5
-rw-r--r--tests/iter/tmoditer.nim59
7 files changed, 116 insertions, 27 deletions
diff --git a/changelog.md b/changelog.md
index b7c2214c9..38bb412d0 100644
--- a/changelog.md
+++ b/changelog.md
@@ -61,6 +61,11 @@ type
 
 ## Language additions
 
+- Inline iterators returning `lent T` types are now supported, similarly to iterators returning `var T`:
+```nim
+iterator myitems[T](x: openarray[T]): lent T
+iterator mypairs[T](x: openarray[T]): tuple[idx: int, val: lent T]
+```
 
 ## Language changes
 
diff --git a/compiler/lowerings.nim b/compiler/lowerings.nim
index 9ff3ece33..fff6c75ca 100644
--- a/compiler/lowerings.nim
+++ b/compiler/lowerings.nim
@@ -21,8 +21,8 @@ proc newDeref*(n: PNode): PNode {.inline.} =
 
 proc newTupleAccess*(g: ModuleGraph; tup: PNode, i: int): PNode =
   if tup.kind == nkHiddenAddr:
-    result = newNodeIT(nkHiddenAddr, tup.info, tup.typ.skipTypes(abstractInst+{tyPtr, tyVar}))
-    result.addSon(newNodeIT(nkBracketExpr, tup.info, tup.typ.skipTypes(abstractInst+{tyPtr, tyVar}).sons[i]))
+    result = newNodeIT(nkHiddenAddr, tup.info, tup.typ.skipTypes(abstractInst+{tyPtr, tyVar, tyLent}))
+    result.addSon(newNodeIT(nkBracketExpr, tup.info, tup.typ.skipTypes(abstractInst+{tyPtr, tyVar, tyLent}).sons[i]))
     addSon(result[0], tup[0])
     var lit = newNodeIT(nkIntLit, tup.info, getSysType(g, tup.info, tyInt))
     lit.intVal = i
diff --git a/compiler/semexprs.nim b/compiler/semexprs.nim
index 765110d56..a3d92da8c 100644
--- a/compiler/semexprs.nim
+++ b/compiler/semexprs.nim
@@ -1781,21 +1781,21 @@ proc semYieldVarResult(c: PContext, n: PNode, restype: PType) =
   var t = skipTypes(restype, {tyGenericInst, tyAlias, tySink})
   case t.kind
   of tyVar, tyLent:
-    if t.kind == tyVar: t.flags.incl tfVarIsPtr # bugfix for #4048, #4910, #6892
+    t.flags.incl tfVarIsPtr # bugfix for #4048, #4910, #6892
     if n.sons[0].kind in {nkHiddenStdConv, nkHiddenSubConv}:
       n.sons[0] = n.sons[0].sons[1]
     n.sons[0] = takeImplicitAddr(c, n.sons[0], t.kind == tyLent)
   of tyTuple:
     for i in 0..<t.sonsLen:
-      var e = skipTypes(t.sons[i], {tyGenericInst, tyAlias, tySink})
+      let e = skipTypes(t.sons[i], {tyGenericInst, tyAlias, tySink})
       if e.kind in {tyVar, tyLent}:
-        if e.kind == tyVar: e.flags.incl tfVarIsPtr # bugfix for #4048, #4910, #6892
+        e.flags.incl tfVarIsPtr # bugfix for #4048, #4910, #6892
         if n.sons[0].kind in {nkPar, nkTupleConstr}:
           n.sons[0].sons[i] = takeImplicitAddr(c, n.sons[0].sons[i], e.kind == tyLent)
         elif n.sons[0].kind in {nkHiddenStdConv, nkHiddenSubConv} and
              n.sons[0].sons[1].kind in {nkPar, nkTupleConstr}:
           var a = n.sons[0].sons[1]
-          a.sons[i] = takeImplicitAddr(c, a.sons[i], false)
+          a.sons[i] = takeImplicitAddr(c, a.sons[i], e.kind == tyLent)
         else:
           localError(c.config, n.sons[0].info, errXExpected, "tuple constructor")
   else: discard
diff --git a/compiler/semstmts.nim b/compiler/semstmts.nim
index 5a0aac40e..dfa592549 100644
--- a/compiler/semstmts.nim
+++ b/compiler/semstmts.nim
@@ -678,25 +678,30 @@ proc semForVars(c: PContext, n: PNode; flags: TExprFlags): PNode =
   var length = sonsLen(n)
   let iterBase = n.sons[length-2].typ
   var iter = skipTypes(iterBase, {tyGenericInst, tyAlias, tySink})
+  var iterAfterVarLent = iter.skipTypes({tyLent, tyVar})
   # length == 3 means that there is one for loop variable
   # and thus no tuple unpacking:
-  if iter.kind != tyTuple or length == 3:
+  if iterAfterVarLent.kind != tyTuple or length == 3:
     if length == 3:
       if n.sons[0].kind == nkVarTuple:
-        var mutable = false
-        if iter.kind == tyVar:
-          iter = iter.skipTypes({tyVar})
-          mutable = true
-        if sonsLen(n[0])-1 != sonsLen(iter):
+        if sonsLen(n[0])-1 != sonsLen(iterAfterVarLent):
           localError(c.config, n[0].info, errWrongNumberOfVariables)
         for i in 0 ..< sonsLen(n[0])-1:
           var v = symForVar(c, n[0][i])
           if getCurrOwner(c).kind == skModule: incl(v.flags, sfGlobal)
-          if mutable:
-            v.typ = newTypeS(tyVar, c)
-            v.typ.sons.add iter[i]
-          else:
-            v.typ = iter.sons[i]
+          case iter.kind
+            of tyVar:
+              v.typ = newTypeS(tyVar, c)
+              v.typ.sons.add iterAfterVarLent[i]
+              if tfVarIsPtr in iter.flags:
+                v.typ.flags.incl tfVarIsPtr
+            of tyLent:
+              v.typ = newTypeS(tyLent, c)
+              v.typ.sons.add iterAfterVarLent[i]
+              if tfVarIsPtr in iter.flags:
+                v.typ.flags.incl tfVarIsPtr
+            else:
+              v.typ = iter.sons[i]
           n.sons[0][i] = newSymNode(v)
           if sfGenSym notin v.flags: addDecl(c, v)
           elif v.owner == nil: v.owner = getCurrOwner(c)
@@ -712,15 +717,22 @@ proc semForVars(c: PContext, n: PNode; flags: TExprFlags): PNode =
         elif v.owner == nil: v.owner = getCurrOwner(c)
     else:
       localError(c.config, n.info, errWrongNumberOfVariables)
-  elif length-2 != sonsLen(iter):
+  elif length-2 != sonsLen(iterAfterVarLent):
     localError(c.config, n.info, errWrongNumberOfVariables)
   else:
     for i in 0 .. length - 3:
       if n.sons[i].kind == nkVarTuple:
         var mutable = false
-        if iter[i].kind == tyVar:
-          iter[i] = iter[i].skipTypes({tyVar})
-          mutable = true
+        var isLent = false
+        iter[i] = case iter[i].kind
+          of tyVar:
+            mutable = true
+            iter[i].skipTypes({tyVar})
+          of tyLent:
+            isLent = true
+            iter[i].skipTypes({tyLent})
+          else: iter[i]
+
         if sonsLen(n[i])-1 != sonsLen(iter[i]):
           localError(c.config, n[i].info, errWrongNumberOfVariables)
         for j in 0 ..< sonsLen(n[i])-1:
@@ -729,6 +741,9 @@ proc semForVars(c: PContext, n: PNode; flags: TExprFlags): PNode =
           if mutable:
             v.typ = newTypeS(tyVar, c)
             v.typ.sons.add iter[i][j]
+          elif isLent:
+            v.typ = newTypeS(tyLent, c)
+            v.typ.sons.add iter[i][j]
           else:
             v.typ = iter[i][j]
           n.sons[i][j] = newSymNode(v)
@@ -737,7 +752,19 @@ proc semForVars(c: PContext, n: PNode; flags: TExprFlags): PNode =
       else:
         var v = symForVar(c, n.sons[i])
         if getCurrOwner(c).kind == skModule: incl(v.flags, sfGlobal)
-        v.typ = iter.sons[i]
+        case iter.kind
+        of tyVar:
+          v.typ = newTypeS(tyVar, c)
+          v.typ.sons.add iterAfterVarLent[i]
+          if tfVarIsPtr in iter.flags:
+            v.typ.flags.incl tfVarIsPtr
+        of tyLent:
+          v.typ = newTypeS(tyLent, c)
+          v.typ.sons.add iterAfterVarLent[i]
+          if tfVarIsPtr in iter.flags:
+            v.typ.flags.incl tfVarIsPtr
+        else:
+          v.typ = iter.sons[i]
         n.sons[i] = newSymNode(v)
         if sfGenSym notin v.flags:
           if not isDiscardUnderscore(v): addDecl(c, v)
diff --git a/compiler/transf.nim b/compiler/transf.nim
index 555583685..836a6154c 100644
--- a/compiler/transf.nim
+++ b/compiler/transf.nim
@@ -371,8 +371,7 @@ proc transformYield(c: PTransf, n: PNode): PTransNode =
   # c.transCon.forStmt.len == 3 means that there is one for loop variable
   # and thus no tuple unpacking:
   if e.typ.isNil: return result # can happen in nimsuggest for unknown reasons
-  if skipTypes(e.typ, {tyGenericInst, tyAlias, tySink}).kind == tyTuple and
-      c.transCon.forStmt.len != 3:
+  if c.transCon.forStmt.len != 3:
     e = skipConv(e)
     if e.kind in {nkPar, nkTupleConstr}:
       for i in 0 ..< sonsLen(e):
diff --git a/lib/system/iterators.nim b/lib/system/iterators.nim
index 117ec123d..549aa5886 100644
--- a/lib/system/iterators.nim
+++ b/lib/system/iterators.nim
@@ -85,7 +85,7 @@ iterator pairs*[T](a: openArray[T]): tuple[key: int, val: T] {.inline.} =
     yield (i, a[i])
     inc(i)
 
-iterator mpairs*[T](a: var openArray[T]): tuple[key:int, val:var T]{.inline.} =
+iterator mpairs*[T](a: var openArray[T]): tuple[key: int, val: var T]{.inline.} =
   ## Iterates over each item of `a`. Yields ``(index, a[index])`` pairs.
   ## ``a[index]`` can be modified.
   var i = 0
@@ -102,7 +102,7 @@ iterator pairs*[IX, T](a: array[IX, T]): tuple[key: IX, val: T] {.inline.} =
       if i >= high(IX): break
       inc(i)
 
-iterator mpairs*[IX, T](a:var array[IX, T]):tuple[key:IX,val:var T] {.inline.} =
+iterator mpairs*[IX, T](a: var array[IX, T]): tuple[key: IX, val: var T] {.inline.} =
   ## Iterates over each item of `a`. Yields ``(index, a[index])`` pairs.
   ## ``a[index]`` can be modified.
   var i = low(IX)
@@ -179,7 +179,6 @@ iterator mpairs*(a: var cstring): tuple[key: int, val: var char] {.inline.} =
       yield (i, a[i])
       inc(i)
 
-
 iterator items*[T](a: seq[T]): T {.inline.} =
   ## Iterates over each item of `a`.
   var i = 0
diff --git a/tests/iter/tmoditer.nim b/tests/iter/tmoditer.nim
index 1e6be37e4..34c6321ce 100644
--- a/tests/iter/tmoditer.nim
+++ b/tests/iter/tmoditer.nim
@@ -27,3 +27,62 @@ for a in items(arr):
 
 echo ""
 
+#--------------------------------------------------------------------
+# Lent iterators
+#--------------------------------------------------------------------
+type
+  NonCopyable = object
+    x: int
+
+
+proc `=destroy`(o: var NonCopyable) =
+  discard
+
+proc `=copy`(dst: var NonCopyable, src: NonCopyable) {.error.}
+
+proc `=sink`(dst: var NonCopyable, src: NonCopyable) =
+  dst.x = src.x
+
+iterator lentItems[T](a: openarray[T]): lent T =
+  for i in 0..a.high:
+    yield a[i]
+
+iterator lentPairs[T](a: array[0..1, T]): tuple[key: int, val: lent T] =
+  for i in 0..a.high:
+    yield (i, a[i])
+
+
+let arr1 = [1, 2, 3]
+let arr2 = @["a", "b", "c"]
+let arr3 = [NonCopyable(x: 1), NonCopyable(x: 2)]
+let arr4 = @[(1, "a"), (2, "b"), (3, "c")]
+
+var accum: string
+for x in lentItems(arr1):
+  accum &= $x
+doAssert(accum == "123")
+
+accum = ""
+for x in lentItems(arr2):
+  accum &= $x
+doAssert(accum == "abc")
+
+accum = ""
+for val in lentItems(arr3):
+  accum &= $val.x
+doAssert(accum == "12")
+
+accum = ""
+for i, val in lentPairs(arr3):
+  accum &= $i & "-" & $val.x & " "
+doAssert(accum == "0-1 1-2 ")
+
+accum = ""
+for i, val in lentItems(arr4):
+  accum &= $i & "-" & $val & " "
+doAssert(accum == "1-a 2-b 3-c ")
+
+accum = ""
+for (i, val) in lentItems(arr4):
+  accum &= $i & "-" & $val & " "
+doAssert(accum == "1-a 2-b 3-c ")