summary refs log tree commit diff stats
path: root/lib/pure/sugar.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/sugar.nim')
-rw-r--r--lib/pure/sugar.nim25
1 files changed, 18 insertions, 7 deletions
diff --git a/lib/pure/sugar.nim b/lib/pure/sugar.nim
index ff2a3bb58..90ba20c13 100644
--- a/lib/pure/sugar.nim
+++ b/lib/pure/sugar.nim
@@ -11,7 +11,7 @@
 ## macro system.
 
 import std/private/since
-import macros
+import std/macros
 
 proc checkPragma(ex, prag: var NimNode) =
   since (1, 3):
@@ -203,7 +203,7 @@ proc freshIdentNodes(ast: NimNode): NimNode =
   # see also https://github.com/nim-lang/Nim/pull/8531#issuecomment-410436458
   proc inspect(node: NimNode): NimNode =
     case node.kind:
-    of nnkIdent, nnkSym:
+    of nnkIdent, nnkSym, nnkOpenSymChoice, nnkClosedSymChoice, nnkOpenSym:
       result = ident($node)
     of nnkEmpty, nnkLiterals:
       result = node
@@ -231,9 +231,19 @@ macro capture*(locals: varargs[typed], body: untyped): untyped {.since: (1, 1).}
   let locals = if locals.len == 1 and locals[0].kind == nnkBracket: locals[0]
                else: locals
   for arg in locals:
-    if arg.strVal == "result":
-      error("The variable name cannot be `result`!", arg)
-    params.add(newIdentDefs(ident(arg.strVal), freshIdentNodes getTypeInst arg))
+    proc getIdent(n: NimNode): NimNode =
+      case n.kind
+      of nnkIdent, nnkSym:
+        let nStr = n.strVal
+        if nStr == "result":
+          error("The variable name cannot be `result`!", n)
+        result = ident(nStr)
+      of nnkHiddenDeref: result = n[0].getIdent()
+      else:
+        error("The argument to be captured `" & n.repr & "` is not a pure identifier. " &
+          "It is an unsupported `" & $n.kind & "` node.", n)
+    let argName = getIdent(arg)
+    params.add(newIdentDefs(argName, freshIdentNodes getTypeInst arg))
   result = newNimNode(nnkCall)
   result.add(newProc(newEmptyNode(), params, body, nnkLambda))
   for arg in locals: result.add(arg)
@@ -335,9 +345,10 @@ proc collectImpl(init, body: NimNode): NimNode {.since: (1, 1).} =
   let res = genSym(nskVar, "collectResult")
   var bracketExpr: NimNode
   if init != nil:
-    expectKind init, {nnkCall, nnkIdent, nnkSym}
+    expectKind init, {nnkCall, nnkIdent, nnkSym, nnkClosedSymChoice, nnkOpenSymChoice, nnkOpenSym}
     bracketExpr = newTree(nnkBracketExpr,
-      if init.kind == nnkCall: freshIdentNodes(init[0]) else: freshIdentNodes(init))
+      if init.kind in {nnkCall, nnkClosedSymChoice, nnkOpenSymChoice, nnkOpenSym}:
+        freshIdentNodes(init[0]) else: freshIdentNodes(init))
   else:
     bracketExpr = newTree(nnkBracketExpr)
   let (resBody, keyType, valueType) = trans(body, res, bracketExpr)