summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorBung <crc32@qq.com>2023-08-09 15:43:39 +0800
committerGitHub <noreply@github.com>2023-08-09 09:43:39 +0200
commit989da75b84e15a6d585e8b03d2bcd0c42a90c2fb (patch)
tree348122cf414a84389dbbfa676927c1e1fdec642d
parentc622e58db968da29e743472d6619137598cc546a (diff)
downloadNim-989da75b84e15a6d585e8b03d2bcd0c42a90c2fb.tar.gz
fix #20891 Illegal capture error of env its self (#22414)
* fix #20891 Illegal capture error of env its self

* fix innerClosure too earlier, make condition shorter
-rw-r--r--compiler/ast.nim6
-rw-r--r--compiler/lambdalifting.nim10
-rw-r--r--tests/iter/t20891.nim28
3 files changed, 41 insertions, 3 deletions
diff --git a/compiler/ast.nim b/compiler/ast.nim
index 55ee0e20d..fc3f52378 100644
--- a/compiler/ast.nim
+++ b/compiler/ast.nim
@@ -2083,6 +2083,12 @@ proc isClosureIterator*(typ: PType): bool {.inline.} =
 proc isClosure*(typ: PType): bool {.inline.} =
   typ.kind == tyProc and typ.callConv == ccClosure
 
+proc isNimcall*(s: PSym): bool {.inline.} =
+  s.typ.callConv == ccNimCall
+
+proc isExplicitCallConv*(s: PSym): bool {.inline.} =
+  tfExplicitCallConv in s.typ.flags
+
 proc isSinkParam*(s: PSym): bool {.inline.} =
   s.kind == skParam and (s.typ.kind == tySink or tfHasOwned in s.typ.flags)
 
diff --git a/compiler/lambdalifting.nim b/compiler/lambdalifting.nim
index ac4c160f9..37c913fe2 100644
--- a/compiler/lambdalifting.nim
+++ b/compiler/lambdalifting.nim
@@ -297,17 +297,19 @@ proc freshVarForClosureIter*(g: ModuleGraph; s: PSym; idgen: IdGenerator; owner:
 
 proc markAsClosure(g: ModuleGraph; owner: PSym; n: PNode) =
   let s = n.sym
+  let isEnv = s.name.id == getIdent(g.cache, ":env").id
   if illegalCapture(s):
     localError(g.config, n.info,
       ("'$1' is of type <$2> which cannot be captured as it would violate memory" &
        " safety, declared here: $3; using '-d:nimNoLentIterators' helps in some cases." &
        " Consider using a <ref $2> which can be captured.") %
       [s.name.s, typeToString(s.typ), g.config$s.info])
-  elif not (owner.typ.callConv == ccClosure or owner.typ.callConv == ccNimCall and tfExplicitCallConv notin owner.typ.flags):
+  elif not (owner.typ.isClosure or owner.isNimcall and not owner.isExplicitCallConv or isEnv):
     localError(g.config, n.info, "illegal capture '$1' because '$2' has the calling convention: <$3>" %
       [s.name.s, owner.name.s, $owner.typ.callConv])
   incl(owner.typ.flags, tfCapturesEnv)
-  owner.typ.callConv = ccClosure
+  if not isEnv:
+    owner.typ.callConv = ccClosure
 
 type
   DetectionPass = object
@@ -448,6 +450,8 @@ proc detectCapturedVars(n: PNode; owner: PSym; c: var DetectionPass) =
         let body = transformBody(c.graph, c.idgen, s, useCache)
         detectCapturedVars(body, s, c)
     let ow = s.skipGenericOwner
+    let innerClosure = innerProc and s.typ.callConv == ccClosure and not s.isIterator
+    let interested = interestingVar(s)
     if ow == owner:
       if owner.isIterator:
         c.somethingToDo = true
@@ -462,7 +466,7 @@ proc detectCapturedVars(n: PNode; owner: PSym; c: var DetectionPass) =
             else:
               discard addField(obj, s, c.graph.cache, c.idgen)
     # direct or indirect dependency:
-    elif (innerProc and not s.isIterator and s.typ.callConv == ccClosure) or interestingVar(s):
+    elif innerClosure or interested:
       discard """
         proc outer() =
           var x: int
diff --git a/tests/iter/t20891.nim b/tests/iter/t20891.nim
new file mode 100644
index 000000000..34deec41b
--- /dev/null
+++ b/tests/iter/t20891.nim
@@ -0,0 +1,28 @@
+import macros, tables
+
+var mapping {.compileTime.}: Table[string, NimNode]
+
+macro register(a: static[string], b: typed): untyped =
+  mapping[a] = b
+
+macro getPtr(a: static[string]): untyped =
+  result = mapping[a]
+
+proc foo() =
+  iterator it() {.closure.} =
+    discard
+  proc getIterPtr(): pointer {.nimcall.} =
+    rawProc(it)
+  register("foo", getIterPtr())
+  discard getIterPtr() # Comment either this to make it work
+foo() # or this
+
+proc bar() =
+  iterator it() {.closure.} =
+    discard getPtr("foo") # Or this
+    discard
+  proc getIterPtr(): pointer {.nimcall.} =
+    rawProc(it)
+  register("bar", getIterPtr())
+  discard getIterPtr()
+bar()