diff options
-rw-r--r-- | lib/pure/sugar.nim | 16 | ||||
-rw-r--r-- | tests/stdlib/tsugar.nim | 60 |
2 files changed, 73 insertions, 3 deletions
diff --git a/lib/pure/sugar.nim b/lib/pure/sugar.nim index ff2a3bb58..d8e00d1f8 100644 --- a/lib/pure/sugar.nim +++ b/lib/pure/sugar.nim @@ -231,9 +231,19 @@ macro capture*(locals: varargs[typed], body: untyped): untyped {.since: (1, 1).} let locals = if locals.len == 1 and locals[0].kind == nnkBracket: locals[0] else: locals for arg in locals: - if arg.strVal == "result": - error("The variable name cannot be `result`!", arg) - params.add(newIdentDefs(ident(arg.strVal), freshIdentNodes getTypeInst arg)) + proc getIdent(n: NimNode): NimNode = + case n.kind + of nnkIdent, nnkSym: + let nStr = n.strVal + if nStr == "result": + error("The variable name cannot be `result`!", n) + result = ident(nStr) + of nnkHiddenDeref: result = n[0].getIdent() + else: + error("The argument to be captured `" & n.repr & "` is not a pure identifier. " & + "It is an unsupported `" & $n.kind & "` node.", n) + let argName = getIdent(arg) + params.add(newIdentDefs(argName, freshIdentNodes getTypeInst arg)) result = newNimNode(nnkCall) result.add(newProc(newEmptyNode(), params, body, nnkLambda)) for arg in locals: result.add(arg) diff --git a/tests/stdlib/tsugar.nim b/tests/stdlib/tsugar.nim index 9c1213901..6ef3ae519 100644 --- a/tests/stdlib/tsugar.nim +++ b/tests/stdlib/tsugar.nim @@ -6,6 +6,10 @@ x + y = 30 import std/[sugar, algorithm, random, sets, tables, strutils] import std/[syncio, assertions] +type # for capture test, ref #20679 + FooCapture = ref object + x: int + template main() = block: # `=>` block: @@ -96,6 +100,62 @@ template main() = let foo = i + 1 doAssert p() == foo + block: # issue #20679 + # this should compile. Previously was broken as `var int` is an `nnkHiddenDeref` + # which was not handled correctly + + block: + var x = 5 + var s1 = newSeq[proc (): int](2) + proc function(data: var int) = + for i in 0 ..< 2: + data = (i+1) * data + capture data: + s1[i] = proc(): int = data + function(x) + doAssert s1[0]() == 5 + doAssert s1[1]() == 10 + + + block: + var y = @[5, 10] + var s2 = newSeq[proc (): seq[int]](2) + proc functionS(data: var seq[int]) = + for i in 0 ..< 2: + data.add (i+1) * 5 + capture data: + s2[i] = proc(): seq[int] = data + functionS(y) + doAssert s2[0]() == @[5, 10, 5] + doAssert s2[1]() == @[5, 10, 5, 10] + + + template typeT(typ, val: untyped): untyped = + var x = val + var s = newSeq[proc (): typ](2) + + proc functionT[T](data: var T) = + for i in 0 ..< 2: + if i == 1: + data = default(T) + capture data: + s[i] = proc (): T = data + + functionT(x) + doAssert s[0]() == val + doAssert s[1]() == x # == default + doAssert s[1]() == default(typ) + + block: + var x = 1.1 + typeT(float, x) + block: + var x = "hello" + typeT(string, x) + block: + var f = FooCapture(x: 5) + typeT(FooCapture, f) + block: # dup block dup_with_field: type |