diff options
author | Antonis Geralis <43617260+planetis-m@users.noreply.github.com> | 2020-12-03 21:32:18 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-12-03 20:32:18 +0100 |
commit | 808ab7eae22536167445818c9a4650d36e87d39a (patch) | |
tree | 0442fa17095a605eed8cde51d023f25ab5028536 /lib/pure/sugar.nim | |
parent | 2220aaeaef74cb6018f4689af8f280db22cb30dd (diff) | |
download | Nim-808ab7eae22536167445818c9a4650d36e87d39a.tar.gz |
add collect with infered init, refs #16078 fixes #14332 (#16089)
* changelog * add testcase, fixes #14332
Diffstat (limited to 'lib/pure/sugar.nim')
-rw-r--r-- | lib/pure/sugar.nim | 100 |
1 files changed, 66 insertions, 34 deletions
diff --git a/lib/pure/sugar.nim b/lib/pure/sugar.nim index 047104972..686729515 100644 --- a/lib/pure/sugar.nim +++ b/lib/pure/sugar.nim @@ -58,7 +58,7 @@ macro `=>`*(p, b: untyped): untyped = runnableExamples: proc passTwoAndTwo(f: (int, int) -> int): int = f(2, 2) - + doAssert passTwoAndTwo((x, y) => x + y) == 4 type @@ -270,54 +270,79 @@ since (1, 1): underscoredCalls(result, calls, tmp) result.add tmp - -proc transLastStmt(n, res, bracketExpr: NimNode): (NimNode, NimNode, NimNode) {.since: (1, 1).} = +proc trans(n, res, bracketExpr: NimNode): (NimNode, NimNode, NimNode) {.since: (1, 1).} = # Looks for the last statement of the last statement, etc... case n.kind - of nnkIfExpr, nnkIfStmt, nnkTryStmt, nnkCaseStmt: + of nnkIfExpr, nnkIfStmt, nnkTryStmt, nnkCaseStmt, nnkWhenStmt: result[0] = copyNimTree(n) result[1] = copyNimTree(n) result[2] = copyNimTree(n) - for i in ord(n.kind == nnkCaseStmt)..<n.len: - (result[0][i], result[1][^1], result[2][^1]) = transLastStmt(n[i], res, bracketExpr) + for i in ord(n.kind == nnkCaseStmt) ..< n.len: + (result[0][i], result[1][^1], result[2][^1]) = trans(n[i], res, bracketExpr) of nnkStmtList, nnkStmtListExpr, nnkBlockStmt, nnkBlockExpr, nnkWhileStmt, nnkForStmt, nnkElifBranch, nnkElse, nnkElifExpr, nnkOfBranch, nnkExceptBranch: result[0] = copyNimTree(n) result[1] = copyNimTree(n) result[2] = copyNimTree(n) if n.len >= 1: - (result[0][^1], result[1][^1], result[2][^1]) = transLastStmt(n[^1], res, bracketExpr) + (result[0][^1], result[1][^1], result[2][^1]) = trans(n[^1], + res, bracketExpr) of nnkTableConstr: result[1] = n[0][0] result[2] = n[0][1] + if bracketExpr.len == 0: + bracketExpr.add(ident"initTable") # don't import tables if bracketExpr.len == 1: - bracketExpr.add([newCall(bindSym"typeof", newEmptyNode()), newCall( - bindSym"typeof", newEmptyNode())]) + bracketExpr.add([newCall(bindSym"typeof", + newEmptyNode()), newCall(bindSym"typeof", newEmptyNode())]) template adder(res, k, v) = res[k] = v result[0] = getAst(adder(res, n[0][0], n[0][1])) of nnkCurly: result[2] = n[0] + if bracketExpr.len == 0: + bracketExpr.add(ident"initHashSet") if bracketExpr.len == 1: bracketExpr.add(newCall(bindSym"typeof", newEmptyNode())) template adder(res, v) = res.incl(v) result[0] = getAst(adder(res, n[0])) else: result[2] = n + if bracketExpr.len == 0: + bracketExpr.add(bindSym"newSeq") if bracketExpr.len == 1: bracketExpr.add(newCall(bindSym"typeof", newEmptyNode())) template adder(res, v) = res.add(v) result[0] = getAst(adder(res, n)) +proc collectImpl(init, body: NimNode): NimNode {.since: (1, 1).} = + let res = genSym(nskVar, "collectResult") + var bracketExpr: NimNode + if init != nil: + expectKind init, {nnkCall, nnkIdent, nnkSym} + bracketExpr = newTree(nnkBracketExpr, + if init.kind == nnkCall: freshIdentNodes(init[0]) else: freshIdentNodes(init)) + else: + bracketExpr = newTree(nnkBracketExpr) + let (resBody, keyType, valueType) = trans(body, res, bracketExpr) + if bracketExpr.len == 3: + bracketExpr[1][1] = keyType + bracketExpr[2][1] = valueType + else: + bracketExpr[1][1] = valueType + let call = newTree(nnkCall, bracketExpr) + if init != nil and init.kind == nnkCall: + for i in 1 ..< init.len: + call.add init[i] + result = newTree(nnkStmtListExpr, newVarStmt(res, call), resBody, res) + macro collect*(init, body: untyped): untyped {.since: (1, 1).} = - ## Comprehension for seq/set/table collections. ``init`` is - ## the init call, and so custom collections are supported. + ## Comprehension for seqs/sets/tables. ## - ## The last statement of ``body`` has special syntax that specifies - ## the collection's add operation. Use ``{e}`` for set's ``incl``, - ## ``{k: v}`` for table's ``[]=`` and ``e`` for seq's ``add``. - ## - ## The ``init`` proc can be called with any number of arguments, - ## i.e. ``initTable(initialSize)``. + ## The last expression of `body` has special syntax that specifies + ## the collection's add operation. Use `{e}` for set's `incl`, + ## `{k: v}` for table's `[]=` and `e` for seq's `add`. + # analyse the body, find the deepest expression 'it' and replace it via + # 'result.add it' runnableExamples: import sets, tables let data = @["bird", "word"] @@ -343,20 +368,27 @@ macro collect*(init, body: untyped): untyped {.since: (1, 1).} = for i, d in data.pairs: {i: d} assert z == {0: "bird", 1: "word"}.toTable - # analyse the body, find the deepest expression 'it' and replace it via - # 'result.add it' - let res = genSym(nskVar, "collectResult") - expectKind init, {nnkCall, nnkIdent, nnkSym} - let bracketExpr = newTree(nnkBracketExpr, - if init.kind == nnkCall: init[0] else: init) - let (resBody, keyType, valueType) = transLastStmt(body, res, bracketExpr) - if bracketExpr.len == 3: - bracketExpr[1][1] = keyType - bracketExpr[2][1] = valueType - else: - bracketExpr[1][1] = valueType - let call = newTree(nnkCall, bracketExpr) - if init.kind == nnkCall: - for i in 1 ..< init.len: - call.add init[i] - result = newTree(nnkStmtListExpr, newVarStmt(res, call), resBody, res) + result = collectImpl(init, body) + +macro collect*(body: untyped): untyped {.since: (1, 5).} = + ## Same as `collect` but without an `init` parameter. + runnableExamples: + import sets, tables + # Seq: + let data = @["bird", "word"] + let k = collect: + for i, d in data.pairs: + if i mod 2 == 0: d + + assert k == @["bird"] + ## HashSet: + let n = collect: + for d in data.items: {d} + + assert n == data.toHashSet + ## Table: + let m = collect: + for i, d in data.pairs: {i: d} + + assert m == {0: "bird", 1: "word"}.toTable + result = collectImpl(nil, body) \ No newline at end of file |