summary refs log tree commit diff stats
path: root/lib/pure/sugar.nim
diff options
context:
space:
mode:
authorAntonis Geralis <43617260+planetis-m@users.noreply.github.com>2020-12-03 21:32:18 +0200
committerGitHub <noreply@github.com>2020-12-03 20:32:18 +0100
commit808ab7eae22536167445818c9a4650d36e87d39a (patch)
tree0442fa17095a605eed8cde51d023f25ab5028536 /lib/pure/sugar.nim
parent2220aaeaef74cb6018f4689af8f280db22cb30dd (diff)
downloadNim-808ab7eae22536167445818c9a4650d36e87d39a.tar.gz
add collect with infered init, refs #16078 fixes #14332 (#16089)
* changelog
* add testcase, fixes #14332
Diffstat (limited to 'lib/pure/sugar.nim')
-rw-r--r--lib/pure/sugar.nim100
1 files changed, 66 insertions, 34 deletions
diff --git a/lib/pure/sugar.nim b/lib/pure/sugar.nim
index 047104972..686729515 100644
--- a/lib/pure/sugar.nim
+++ b/lib/pure/sugar.nim
@@ -58,7 +58,7 @@ macro `=>`*(p, b: untyped): untyped =
   runnableExamples:
     proc passTwoAndTwo(f: (int, int) -> int): int =
       f(2, 2)
-  
+
     doAssert passTwoAndTwo((x, y) => x + y) == 4
 
     type
@@ -270,54 +270,79 @@ since (1, 1):
     underscoredCalls(result, calls, tmp)
     result.add tmp
 
-
-proc transLastStmt(n, res, bracketExpr: NimNode): (NimNode, NimNode, NimNode) {.since: (1, 1).} =
+proc trans(n, res, bracketExpr: NimNode): (NimNode, NimNode, NimNode) {.since: (1, 1).} =
   # Looks for the last statement of the last statement, etc...
   case n.kind
-  of nnkIfExpr, nnkIfStmt, nnkTryStmt, nnkCaseStmt:
+  of nnkIfExpr, nnkIfStmt, nnkTryStmt, nnkCaseStmt, nnkWhenStmt:
     result[0] = copyNimTree(n)
     result[1] = copyNimTree(n)
     result[2] = copyNimTree(n)
-    for i in ord(n.kind == nnkCaseStmt)..<n.len:
-      (result[0][i], result[1][^1], result[2][^1]) = transLastStmt(n[i], res, bracketExpr)
+    for i in ord(n.kind == nnkCaseStmt) ..< n.len:
+      (result[0][i], result[1][^1], result[2][^1]) = trans(n[i], res, bracketExpr)
   of nnkStmtList, nnkStmtListExpr, nnkBlockStmt, nnkBlockExpr, nnkWhileStmt,
       nnkForStmt, nnkElifBranch, nnkElse, nnkElifExpr, nnkOfBranch, nnkExceptBranch:
     result[0] = copyNimTree(n)
     result[1] = copyNimTree(n)
     result[2] = copyNimTree(n)
     if n.len >= 1:
-      (result[0][^1], result[1][^1], result[2][^1]) = transLastStmt(n[^1], res, bracketExpr)
+      (result[0][^1], result[1][^1], result[2][^1]) = trans(n[^1],
+          res, bracketExpr)
   of nnkTableConstr:
     result[1] = n[0][0]
     result[2] = n[0][1]
+    if bracketExpr.len == 0:
+      bracketExpr.add(ident"initTable") # don't import tables
     if bracketExpr.len == 1:
-      bracketExpr.add([newCall(bindSym"typeof", newEmptyNode()), newCall(
-          bindSym"typeof", newEmptyNode())])
+      bracketExpr.add([newCall(bindSym"typeof",
+          newEmptyNode()), newCall(bindSym"typeof", newEmptyNode())])
     template adder(res, k, v) = res[k] = v
     result[0] = getAst(adder(res, n[0][0], n[0][1]))
   of nnkCurly:
     result[2] = n[0]
+    if bracketExpr.len == 0:
+      bracketExpr.add(ident"initHashSet")
     if bracketExpr.len == 1:
       bracketExpr.add(newCall(bindSym"typeof", newEmptyNode()))
     template adder(res, v) = res.incl(v)
     result[0] = getAst(adder(res, n[0]))
   else:
     result[2] = n
+    if bracketExpr.len == 0:
+      bracketExpr.add(bindSym"newSeq")
     if bracketExpr.len == 1:
       bracketExpr.add(newCall(bindSym"typeof", newEmptyNode()))
     template adder(res, v) = res.add(v)
     result[0] = getAst(adder(res, n))
 
+proc collectImpl(init, body: NimNode): NimNode {.since: (1, 1).} =
+  let res = genSym(nskVar, "collectResult")
+  var bracketExpr: NimNode
+  if init != nil:
+    expectKind init, {nnkCall, nnkIdent, nnkSym}
+    bracketExpr = newTree(nnkBracketExpr,
+      if init.kind == nnkCall: freshIdentNodes(init[0]) else: freshIdentNodes(init))
+  else:
+    bracketExpr = newTree(nnkBracketExpr)
+  let (resBody, keyType, valueType) = trans(body, res, bracketExpr)
+  if bracketExpr.len == 3:
+    bracketExpr[1][1] = keyType
+    bracketExpr[2][1] = valueType
+  else:
+    bracketExpr[1][1] = valueType
+  let call = newTree(nnkCall, bracketExpr)
+  if init != nil and init.kind == nnkCall:
+    for i in 1 ..< init.len:
+      call.add init[i]
+  result = newTree(nnkStmtListExpr, newVarStmt(res, call), resBody, res)
+
 macro collect*(init, body: untyped): untyped {.since: (1, 1).} =
-  ## Comprehension for seq/set/table collections. ``init`` is
-  ## the init call, and so custom collections are supported.
+  ## Comprehension for seqs/sets/tables.
   ##
-  ## The last statement of ``body`` has special syntax that specifies
-  ## the collection's add operation. Use ``{e}`` for set's ``incl``,
-  ## ``{k: v}`` for table's ``[]=`` and ``e`` for seq's ``add``.
-  ##
-  ## The ``init`` proc can be called with any number of arguments,
-  ## i.e. ``initTable(initialSize)``.
+  ## The last expression of `body` has special syntax that specifies
+  ## the collection's add operation. Use `{e}` for set's `incl`,
+  ## `{k: v}` for table's `[]=` and `e` for seq's `add`.
+  # analyse the body, find the deepest expression 'it' and replace it via
+  # 'result.add it'
   runnableExamples:
     import sets, tables
     let data = @["bird", "word"]
@@ -343,20 +368,27 @@ macro collect*(init, body: untyped): untyped {.since: (1, 1).} =
       for i, d in data.pairs: {i: d}
 
     assert z == {0: "bird", 1: "word"}.toTable
-  # analyse the body, find the deepest expression 'it' and replace it via
-  # 'result.add it'
-  let res = genSym(nskVar, "collectResult")
-  expectKind init, {nnkCall, nnkIdent, nnkSym}
-  let bracketExpr = newTree(nnkBracketExpr,
-    if init.kind == nnkCall: init[0] else: init)
-  let (resBody, keyType, valueType) = transLastStmt(body, res, bracketExpr)
-  if bracketExpr.len == 3:
-    bracketExpr[1][1] = keyType
-    bracketExpr[2][1] = valueType
-  else:
-    bracketExpr[1][1] = valueType
-  let call = newTree(nnkCall, bracketExpr)
-  if init.kind == nnkCall:
-    for i in 1 ..< init.len:
-      call.add init[i]
-  result = newTree(nnkStmtListExpr, newVarStmt(res, call), resBody, res)
+  result = collectImpl(init, body)
+
+macro collect*(body: untyped): untyped {.since: (1, 5).} =
+  ## Same as `collect` but without an `init` parameter.
+  runnableExamples:
+    import sets, tables
+    # Seq:
+    let data = @["bird", "word"]
+    let k = collect:
+      for i, d in data.pairs:
+        if i mod 2 == 0: d
+
+    assert k == @["bird"]
+    ## HashSet:
+    let n = collect:
+      for d in data.items: {d}
+
+    assert n == data.toHashSet
+    ## Table:
+    let m = collect:
+      for i, d in data.pairs: {i: d}
+
+    assert m == {0: "bird", 1: "word"}.toTable
+  result = collectImpl(nil, body)
\ No newline at end of file