summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorb3liever <43617260+b3liever@users.noreply.github.com>2019-11-22 15:40:50 +0200
committerAndreas Rumpf <rumpf_a@web.de>2019-11-22 14:40:50 +0100
commit5bb6c67a45045f41c39b11c085f5ff072e1615d7 (patch)
tree8633167dab3dcacce876ba6d95a5826d979a8015
parente5478b32a891077972c53b0715481fb7c6b4390c (diff)
downloadNim-5bb6c67a45045f41c39b11c085f5ff072e1615d7.tar.gz
add collect macro (#12708)
* add collect macro

* Add to changelog
-rw-r--r--changelog.md2
-rw-r--r--lib/pure/sugar.nim97
2 files changed, 98 insertions, 1 deletions
diff --git a/changelog.md b/changelog.md
index 6320ec158..a1b6b6166 100644
--- a/changelog.md
+++ b/changelog.md
@@ -39,7 +39,7 @@
 - Added `sugar.outplace` for turning in-place algorithms like `sort` and `shuffle` into
   operations that work on a copy of the data and return the mutated copy. As the existing
   `sorted` does.
-
+- Added `sugar.collect` that does comprehension for seq/set/table collections.
 
 ## Library changes
 
diff --git a/lib/pure/sugar.nim b/lib/pure/sugar.nim
index 264f3749c..c811f115c 100644
--- a/lib/pure/sugar.nim
+++ b/lib/pure/sugar.nim
@@ -265,6 +265,89 @@ when (NimMajor, NimMinor) >= (1, 1):
       copyNimNode(call).add callsons,
       tmp)
 
+  proc transLastStmt(n, res, bracketExpr: NimNode): (NimNode, NimNode, NimNode) =
+    # Looks for the last statement of the last statement, etc...
+    case n.kind
+    of nnkStmtList, nnkStmtListExpr, nnkBlockStmt, nnkBlockExpr, nnkWhileStmt,
+        nnkForStmt, nnkIfExpr, nnkIfStmt, nnkTryStmt, nnkCaseStmt,
+        nnkElifBranch, nnkElse, nnkElifExpr:
+      result[0] = copyNimTree(n)
+      result[1] = copyNimTree(n)
+      result[2] = copyNimTree(n)
+      if n.len >= 1:
+        (result[0][^1], result[1][^1], result[2][^1]) = transLastStmt(n[^1], res,
+            bracketExpr)
+    of nnkTableConstr:
+      result[1] = n[0][0]
+      result[2] = n[0][1]
+      bracketExpr.add([newCall(bindSym"typeof", newEmptyNode()), newCall(
+          bindSym"typeof", newEmptyNode())])
+      template adder(res, k, v) = res[k] = v
+      result[0] = getAst(adder(res, n[0][0], n[0][1]))
+    of nnkCurly:
+      result[2] = n[0]
+      bracketExpr.add(newCall(bindSym"typeof", newEmptyNode()))
+      template adder(res, v) = res.incl(v)
+      result[0] = getAst(adder(res, n[0]))
+    else:
+      result[2] = n
+      bracketExpr.add(newCall(bindSym"typeof", newEmptyNode()))
+      template adder(res, v) = res.add(v)
+      result[0] = getAst(adder(res, n))
+
+  macro collect*(init, body: untyped): untyped =
+    ## Comprehension for seq/set/table collections. ``init`` is
+    ## the init call, and so custom collections are supported.
+    ##
+    ## The last statement of ``body`` has special syntax that specifies
+    ## the collection's add operation. Use ``{e}`` for set's ``incl``,
+    ## ``{k: v}`` for table's ``[]=`` and ``e`` for seq's ``add``.
+    ##
+    ## The ``init`` proc can be called with any number of arguments,
+    ## i.e. ``initTable(initialSize)``.
+    runnableExamples:
+      import sets, tables
+      let data = @["bird", "word"]
+      ## seq:
+      let k = collect(newSeq):
+        for i, d in data.pairs:
+          if i mod 2 == 0: d
+
+      assert k == @["bird"]
+      ## seq with initialSize:
+      let x = collect(newSeqOfCap(4)):
+        for i, d in data.pairs:
+          if i mod 2 == 0: d
+
+      assert x == @["bird"]
+      ## HashSet:
+      let y = initHashSet.collect:
+        for d in data.items: {d}
+
+      assert y == data.toHashSet
+      ## Table:
+      let z = collect(initTable(2)):
+        for i, d in data.pairs: {i: d}
+
+      assert z == {1: "word", 0: "bird"}.toTable
+    # analyse the body, find the deepest expression 'it' and replace it via
+    # 'result.add it'
+    let res = genSym(nskVar, "collectResult")
+    expectKind init, {nnkCall, nnkIdent, nnkSym}
+    let bracketExpr = newTree(nnkBracketExpr,
+      if init.kind == nnkCall: init[0] else: init)
+    let (resBody, keyType, valueType) = transLastStmt(body, res, bracketExpr)
+    if bracketExpr.len == 3:
+      bracketExpr[1][1] = keyType
+      bracketExpr[2][1] = valueType
+    else:
+      bracketExpr[1][1] = valueType
+    let call = newTree(nnkCall, bracketExpr)
+    if init.kind == nnkCall:
+      for i in 1 ..< init.len:
+        call.add init[i]
+    result = newTree(nnkStmtListExpr, newVarStmt(res, call), resBody, res)
+
   when isMainModule:
     import algorithm
 
@@ -282,3 +365,17 @@ when (NimMajor, NimMinor) >= (1, 1):
     let c = b.outplace shuffle()
     doAssert c[0] == 1
     doAssert c[1] == 0
+
+    #test collect
+    import sets, tables
+
+    let data = @["bird", "word"] # if this gets stuck in your head, its not my fault
+    assert collect(newSeq, for (i, d) in data.pairs: (if i mod 2 == 0: d)) == @["bird"]
+    assert collect(initTable(2), for (i, d) in data.pairs: {i: d}) == {1: "word",
+          0: "bird"}.toTable
+    assert initHashSet.collect(for d in data.items: {d}) == data.toHashSet
+
+    let x = collect(newSeqOfCap(4)):
+        for (i, d) in data.pairs:
+          if i mod 2 == 0: d
+    assert x == @["bird"]