summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorJason Beetham <beefers331@gmail.com>2023-04-08 03:40:43 -0600
committerGitHub <noreply@github.com>2023-04-08 11:40:43 +0200
commit686c75cef032481630e27dad3731df98e067b75e (patch)
tree43edad1dd73a05c33b492bf928820205c1bd0413
parenta37a83cbff89867be9cbc5ba2f50c37e99efe0f1 (diff)
downloadNim-686c75cef032481630e27dad3731df98e067b75e.tar.gz
`for` loop expression can now have generated `iterator`'s called (#21627)
A for expression now can have a generated iterator, allowing for more composable iterables
-rw-r--r--compiler/semstmts.nim3
-rw-r--r--tests/iter/tgeniteratorinblock.nim54
2 files changed, 56 insertions, 1 deletions
diff --git a/compiler/semstmts.nim b/compiler/semstmts.nim
index 2aa953d93..10cf65305 100644
--- a/compiler/semstmts.nim
+++ b/compiler/semstmts.nim
@@ -1061,7 +1061,8 @@ proc semFor(c: PContext, n: PNode; flags: TExprFlags): PNode =
   result = n
   n[^2] = semExprNoDeref(c, n[^2], {efWantIterator})
   var call = n[^2]
-  if call.kind == nkStmtListExpr and isTrivalStmtExpr(call):
+
+  if call.kind == nkStmtListExpr and (isTrivalStmtExpr(call) or (call.lastSon.kind in nkCallKinds and call.lastSon[0].sym.kind == skIterator)):
     call = call.lastSon
     n[^2] = call
   let isCallExpr = call.kind in nkCallKinds
diff --git a/tests/iter/tgeniteratorinblock.nim b/tests/iter/tgeniteratorinblock.nim
new file mode 100644
index 000000000..2ab903996
--- /dev/null
+++ b/tests/iter/tgeniteratorinblock.nim
@@ -0,0 +1,54 @@
+discard """
+  output: '''30
+60
+90
+150
+180
+210
+240
+60
+180
+240
+[60, 180, 240]
+[60, 180]'''
+"""
+import std/enumerate
+
+template map[T; Y](i: iterable[T], fn: proc(x: T): Y): untyped =
+  iterator internal(): Y  {.gensym.} =
+    for it in i:
+      yield fn(it)
+  internal()
+
+template filter[T](i: iterable[T], fn: proc(x: T): bool): untyped =
+  iterator internal(): T {.gensym.} =
+    for it in i:
+      if fn(it):
+        yield it
+  internal()
+
+template group[T](i: iterable[T], amount: static int): untyped =
+  iterator internal(): array[amount, T] {.gensym.} =
+    var val: array[amount, T]
+    for ind, it in enumerate i:
+      val[ind mod amount] = it
+      if ind mod amount == amount - 1:
+        yield val
+  internal()
+
+var a = [10, 20, 30, 50, 60, 70, 80]
+
+proc mapFn(x: int): int = x * 3
+proc filterFn(x: int): bool = x mod 20 == 0
+
+for x in a.items.map(mapFn):
+  echo x
+
+for y in a.items.map(mapFn).filter(filterFn):
+  echo y
+
+for y in a.items.map(mapFn).filter(filterFn).group(3):
+  echo y
+
+for y in a.items.map(mapFn).filter(filterFn).group(2):
+  echo y