summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--lib/pure/sugar.nim16
-rw-r--r--tests/stdlib/tsugar.nim60
2 files changed, 73 insertions, 3 deletions
diff --git a/lib/pure/sugar.nim b/lib/pure/sugar.nim
index ff2a3bb58..d8e00d1f8 100644
--- a/lib/pure/sugar.nim
+++ b/lib/pure/sugar.nim
@@ -231,9 +231,19 @@ macro capture*(locals: varargs[typed], body: untyped): untyped {.since: (1, 1).}
   let locals = if locals.len == 1 and locals[0].kind == nnkBracket: locals[0]
                else: locals
   for arg in locals:
-    if arg.strVal == "result":
-      error("The variable name cannot be `result`!", arg)
-    params.add(newIdentDefs(ident(arg.strVal), freshIdentNodes getTypeInst arg))
+    proc getIdent(n: NimNode): NimNode =
+      case n.kind
+      of nnkIdent, nnkSym:
+        let nStr = n.strVal
+        if nStr == "result":
+          error("The variable name cannot be `result`!", n)
+        result = ident(nStr)
+      of nnkHiddenDeref: result = n[0].getIdent()
+      else:
+        error("The argument to be captured `" & n.repr & "` is not a pure identifier. " &
+          "It is an unsupported `" & $n.kind & "` node.", n)
+    let argName = getIdent(arg)
+    params.add(newIdentDefs(argName, freshIdentNodes getTypeInst arg))
   result = newNimNode(nnkCall)
   result.add(newProc(newEmptyNode(), params, body, nnkLambda))
   for arg in locals: result.add(arg)
diff --git a/tests/stdlib/tsugar.nim b/tests/stdlib/tsugar.nim
index 9c1213901..6ef3ae519 100644
--- a/tests/stdlib/tsugar.nim
+++ b/tests/stdlib/tsugar.nim
@@ -6,6 +6,10 @@ x + y = 30
 import std/[sugar, algorithm, random, sets, tables, strutils]
 import std/[syncio, assertions]
 
+type # for capture test, ref #20679
+  FooCapture = ref object
+    x: int
+
 template main() =
   block: # `=>`
     block:
@@ -96,6 +100,62 @@ template main() =
         let foo = i + 1
         doAssert p() == foo
 
+    block: # issue #20679
+      # this should compile. Previously was broken as `var int` is an `nnkHiddenDeref`
+      # which was not handled correctly
+
+      block:
+        var x = 5
+        var s1 = newSeq[proc (): int](2)
+        proc function(data: var int) =
+          for i in 0 ..< 2:
+            data = (i+1) * data
+            capture data:
+              s1[i] = proc(): int = data
+        function(x)
+        doAssert s1[0]() == 5
+        doAssert s1[1]() == 10
+
+
+      block:
+        var y = @[5, 10]
+        var s2 = newSeq[proc (): seq[int]](2)
+        proc functionS(data: var seq[int]) =
+          for i in 0 ..< 2:
+            data.add (i+1) * 5
+            capture data:
+              s2[i] = proc(): seq[int] = data
+        functionS(y)
+        doAssert s2[0]() == @[5, 10, 5]
+        doAssert s2[1]() == @[5, 10, 5, 10]
+
+
+      template typeT(typ, val: untyped): untyped =
+        var x = val
+        var s = newSeq[proc (): typ](2)
+
+        proc functionT[T](data: var T) =
+          for i in 0 ..< 2:
+            if i == 1:
+              data = default(T)
+            capture data:
+              s[i] = proc (): T = data
+
+        functionT(x)
+        doAssert s[0]() == val
+        doAssert s[1]() == x # == default
+        doAssert s[1]() == default(typ)
+
+      block:
+        var x = 1.1
+        typeT(float, x)
+      block:
+        var x = "hello"
+        typeT(string, x)
+      block:
+        var f = FooCapture(x: 5)
+        typeT(FooCapture, f)
+
   block: # dup
     block dup_with_field:
       type