summary refs log tree commit diff stats
path: root/compiler/varpartitions.nim
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/varpartitions.nim')
-rw-r--r--compiler/varpartitions.nim30
1 files changed, 26 insertions, 4 deletions
diff --git a/compiler/varpartitions.nim b/compiler/varpartitions.nim
index 6b2f677a7..1711fea46 100644
--- a/compiler/varpartitions.nim
+++ b/compiler/varpartitions.nim
@@ -106,6 +106,7 @@ type
     unanalysableMutation: bool
     inAsgnSource, inConstructor, inNoSideEffectSection: int
     inConditional, inLoop: int
+    inConvHasDestructor: int
     owner: PSym
     g: ModuleGraph
 
@@ -427,16 +428,28 @@ proc destMightOwn(c: var Partitions; dest: var VarIndex; n: PNode) =
     # primitive literals including the empty are harmless:
     discard
 
-  of nkExprEqExpr, nkExprColonExpr, nkHiddenStdConv, nkHiddenSubConv, nkCast, nkConv:
+  of nkExprEqExpr, nkExprColonExpr, nkHiddenStdConv, nkHiddenSubConv, nkCast:
     destMightOwn(c, dest, n[1])
 
+  of nkConv:
+    if hasDestructor(n.typ):
+      inc c.inConvHasDestructor
+      destMightOwn(c, dest, n[1])
+      dec c.inConvHasDestructor
+    else:
+      destMightOwn(c, dest, n[1])
+
   of nkIfStmt, nkIfExpr:
     for i in 0..<n.len:
+      inc c.inConditional
       destMightOwn(c, dest, n[i].lastSon)
+      dec c.inConditional
 
   of nkCaseStmt:
     for i in 1..<n.len:
+      inc c.inConditional
       destMightOwn(c, dest, n[i].lastSon)
+      dec c.inConditional
 
   of nkStmtList, nkStmtListExpr:
     if n.len > 0:
@@ -481,7 +494,7 @@ proc destMightOwn(c: var Partitions; dest: var VarIndex; n: PNode) =
 
   of nkCallKinds:
     if n.typ != nil:
-      if hasDestructor(n.typ):
+      if hasDestructor(n.typ) or c.inConvHasDestructor > 0:
         # calls do construct, what we construct must be destroyed,
         # so dest cannot be a cursor:
         dest.flags.incl ownsData
@@ -489,8 +502,17 @@ proc destMightOwn(c: var Partitions; dest: var VarIndex; n: PNode) =
         # we know the result is derived from the first argument:
         var roots: seq[(PSym, int)] = @[]
         allRoots(n[1], roots, RootEscapes)
-        for r in roots:
-          connect(c, dest.sym, r[0], n[1].info)
+        if roots.len == 0 and c.inConditional > 0:
+          # when in a conditional expression,
+          # to ensure that the first argument isn't outlived
+          # by the lvalue, we need find the root, otherwise
+          # it is probably a local temporary
+          # (e.g. a return value from a call),
+          # we should prevent cursorfication
+          dest.flags.incl preventCursor
+        else:
+          for r in roots:
+            connect(c, dest.sym, r[0], n[1].info)
 
       else:
         let magic = if n[0].kind == nkSym: n[0].sym.magic else: mNone