summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--compiler/lambdalifting.nim24
-rw-r--r--tests/iter/t21737.nim22
-rw-r--r--tests/iter/t22548.nim21
3 files changed, 62 insertions, 5 deletions
diff --git a/compiler/lambdalifting.nim b/compiler/lambdalifting.nim
index ee1e5b797..ee19eec08 100644
--- a/compiler/lambdalifting.nim
+++ b/compiler/lambdalifting.nim
@@ -316,6 +316,7 @@ type
     processed, capturedVars: IntSet
     ownerToType: Table[int, PType]
     somethingToDo: bool
+    inTypeOf: bool
     graph: ModuleGraph
     idgen: IdGenerator
 
@@ -417,6 +418,9 @@ Consider:
 
 """
 
+proc isTypeOf(n: PNode): bool =
+  n.kind == nkSym and n.sym.magic in {mTypeOf, mType}
+
 proc addClosureParam(c: var DetectionPass; fn: PSym; info: TLineInfo) =
   var cp = getEnvParam(fn)
   let owner = if fn.kind == skIterator: fn else: fn.skipGenericOwner
@@ -455,7 +459,8 @@ proc detectCapturedVars(n: PNode; owner: PSym; c: var DetectionPass) =
         c.somethingToDo = true
         addClosureParam(c, owner, n.info)
         if interestingIterVar(s):
-          if not c.capturedVars.containsOrIncl(s.id):
+          if not c.capturedVars.contains(s.id):
+            if not c.inTypeOf: c.capturedVars.incl(s.id)
             let obj = getHiddenParam(c.graph, owner).typ.skipTypes({tyOwned, tyRef, tyPtr})
             #let obj = c.getEnvTypeForOwner(s.owner).skipTypes({tyOwned, tyRef, tyPtr})
 
@@ -481,10 +486,12 @@ proc detectCapturedVars(n: PNode; owner: PSym; c: var DetectionPass) =
       addClosureParam(c, owner, n.info)
       #echo "capturing ", n.info
       # variable 's' is actually captured:
-      if interestingVar(s) and not c.capturedVars.containsOrIncl(s.id):
-        let obj = c.getEnvTypeForOwner(ow, n.info).skipTypes({tyOwned, tyRef, tyPtr})
-        #getHiddenParam(owner).typ.skipTypes({tyOwned, tyRef, tyPtr})
-        discard addField(obj, s, c.graph.cache, c.idgen)
+      if interestingVar(s):
+        if not c.capturedVars.contains(s.id):
+          if not c.inTypeOf: c.capturedVars.incl(s.id)
+          let obj = c.getEnvTypeForOwner(ow, n.info).skipTypes({tyOwned, tyRef, tyPtr})
+          #getHiddenParam(owner).typ.skipTypes({tyOwned, tyRef, tyPtr})
+          discard addField(obj, s, c.graph.cache, c.idgen)
       # create required upFields:
       var w = owner.skipGenericOwner
       if isInnerProc(w) or owner.isIterator:
@@ -516,9 +523,14 @@ proc detectCapturedVars(n: PNode; owner: PSym; c: var DetectionPass) =
       detectCapturedVars(n[namePos], owner, c)
   of nkReturnStmt:
     detectCapturedVars(n[0], owner, c)
+  of nkIdentDefs:
+    detectCapturedVars(n[^1], owner, c)
   else:
+    if n.isCallExpr and n[0].isTypeOf:
+      c.inTypeOf = true
     for i in 0..<n.len:
       detectCapturedVars(n[i], owner, c)
+    c.inTypeOf = false
 
 type
   LiftingPass = object
@@ -798,6 +810,8 @@ proc liftCapturedVars(n: PNode; owner: PSym; d: var DetectionPass;
   of nkTypeOfExpr:
     result = n
   else:
+    if n.isCallExpr and n[0].isTypeOf:
+      return
     if owner.isIterator:
       if nfLL in n.flags:
         # special case 'when nimVm' due to bug #3636:
diff --git a/tests/iter/t21737.nim b/tests/iter/t21737.nim
new file mode 100644
index 000000000..da06faea7
--- /dev/null
+++ b/tests/iter/t21737.nim
@@ -0,0 +1,22 @@
+discard """
+  action: compile
+"""
+
+template mytoSeq*(iter: untyped): untyped =
+  var result: seq[typeof(iter)]# = @[]
+  for x in iter:
+    result.add(x)
+  result
+
+iterator test(dir:int): int =
+  yield 1234
+
+iterator walkGlobKinds (): int =
+  let dir2 = 123
+  let it = mytoSeq(test(dir2))  
+
+proc main()=
+    let it = iterator(): int=
+      for path in walkGlobKinds():
+          yield path
+main()
diff --git a/tests/iter/t22548.nim b/tests/iter/t22548.nim
new file mode 100644
index 000000000..b9abb75d0
--- /dev/null
+++ b/tests/iter/t22548.nim
@@ -0,0 +1,21 @@
+discard """
+  action: compile
+"""
+
+type Xxx[T] = object
+
+iterator x(v: string): char =
+  var v2: Xxx[int]
+
+  var y: v2.T
+
+  echo y
+
+proc bbb(vv: string): proc () =
+  proc xxx() =
+    for c in x(vv):
+      echo c
+
+  return xxx
+
+bbb("test")()