summary refs log tree commit diff stats
path: root/compiler/sempass2.nim
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/sempass2.nim')
-rw-r--r--compiler/sempass2.nim87
1 files changed, 68 insertions, 19 deletions
diff --git a/compiler/sempass2.nim b/compiler/sempass2.nim
index 0954f42fd..823699a8c 100644
--- a/compiler/sempass2.nim
+++ b/compiler/sempass2.nim
@@ -352,20 +352,25 @@ proc useVar(a: PEffects, n: PNode) =
       a.init.add s.id
   useVarNoInitCheck(a, n, s)
 
+type
+  BreakState = enum
+    bsNone
+    bsBreakOrReturn
+    bsNoReturn
 
 type
   TIntersection = seq[tuple[id, count: int]] # a simple count table
 
-proc addToIntersection(inter: var TIntersection, s: int, initOnly: bool) =
+proc addToIntersection(inter: var TIntersection, s: int, state: BreakState) =
   for j in 0..<inter.len:
     if s == inter[j].id:
-      if not initOnly:
+      if state == bsNone:
         inc inter[j].count
       return
-  if initOnly:
-    inter.add((id: s, count: 0))
-  else:
+  if state == bsNone:
     inter.add((id: s, count: 1))
+  else:
+    inter.add((id: s, count: 0))
 
 proc throws(tracked, n, orig: PNode) =
   if n.typ == nil or n.typ.kind != tyError:
@@ -469,7 +474,7 @@ proc trackTryStmt(tracked: PEffects, n: PNode) =
   track(tracked, n[0])
   dec tracked.inTryStmt
   for i in oldState..<tracked.init.len:
-    addToIntersection(inter, tracked.init[i], false)
+    addToIntersection(inter, tracked.init[i], bsNone)
 
   var branches = 1
   var hasFinally = false
@@ -504,7 +509,7 @@ proc trackTryStmt(tracked: PEffects, n: PNode) =
           tracked.init.add b[j][2].sym.id
       track(tracked, b[^1])
       for i in oldState..<tracked.init.len:
-        addToIntersection(inter, tracked.init[i], false)
+        addToIntersection(inter, tracked.init[i], bsNone)
     else:
       setLen(tracked.init, oldState)
       track(tracked, b[^1])
@@ -673,15 +678,50 @@ proc trackOperandForIndirectCall(tracked: PEffects, n: PNode, formals: PType; ar
       localError(tracked.config, n.info, $n & " is not GC safe")
   notNilCheck(tracked, n, paramType)
 
-proc breaksBlock(n: PNode): bool =
+
+proc breaksBlock(n: PNode): BreakState =
   # semantic check doesn't allow statements after raise, break, return or
   # call to noreturn proc, so it is safe to check just the last statements
   var it = n
   while it.kind in {nkStmtList, nkStmtListExpr} and it.len > 0:
     it = it.lastSon
 
-  result = it.kind in {nkBreakStmt, nkReturnStmt, nkRaiseStmt} or
-    it.kind in nkCallKinds and it[0].kind == nkSym and sfNoReturn in it[0].sym.flags
+  case it.kind
+  of nkBreakStmt, nkReturnStmt:
+    result = bsBreakOrReturn
+  of nkRaiseStmt:
+    result = bsNoReturn
+  of nkCallKinds:
+    if it[0].kind == nkSym and sfNoReturn in it[0].sym.flags:
+      result = bsNoReturn
+    else:
+      result = bsNone
+  else:
+    result = bsNone
+
+proc addIdToIntersection(tracked: PEffects, inter: var TIntersection, resCounter: var int,
+            hasBreaksBlock: BreakState, oldState: int, resSym: PSym, hasResult: bool) =
+  if hasResult:
+    var alreadySatisfy = false
+
+    if hasBreaksBlock == bsNoReturn:
+      alreadySatisfy = true
+      inc resCounter
+
+    for i in oldState..<tracked.init.len:
+      if tracked.init[i] == resSym.id:
+        if not alreadySatisfy:
+          inc resCounter
+          alreadySatisfy = true
+      else:
+        addToIntersection(inter, tracked.init[i], hasBreaksBlock)
+  else:
+    for i in oldState..<tracked.init.len:
+      addToIntersection(inter, tracked.init[i], hasBreaksBlock)
+
+template hasResultSym(s: PSym): bool =
+  s != nil and s.kind in {skProc, skFunc, skConverter, skMethod} and
+    not isEmptyType(s.typ[0])
 
 proc trackCase(tracked: PEffects, n: PNode) =
   track(tracked, n[0])
@@ -694,6 +734,10 @@ proc trackCase(tracked: PEffects, n: PNode) =
         (tracked.config.hasWarn(warnProveField) or strictCaseObjects in tracked.c.features)
   var inter: TIntersection = @[]
   var toCover = 0
+  let hasResult = hasResultSym(tracked.owner)
+  let resSym = if hasResult: tracked.owner.ast[resultPos].sym else: nil
+  var resCounter = 0
+
   for i in 1..<n.len:
     let branch = n[i]
     setLen(tracked.init, oldState)
@@ -703,13 +747,14 @@ proc trackCase(tracked: PEffects, n: PNode) =
     for i in 0..<branch.len:
       track(tracked, branch[i])
     let hasBreaksBlock = breaksBlock(branch.lastSon)
-    if not hasBreaksBlock:
+    if hasBreaksBlock == bsNone:
       inc toCover
-    for i in oldState..<tracked.init.len:
-      addToIntersection(inter, tracked.init[i], hasBreaksBlock)
+    addIdToIntersection(tracked, inter, resCounter, hasBreaksBlock, oldState, resSym, hasResult)
 
   setLen(tracked.init, oldState)
   if not stringCase or lastSon(n).kind == nkElse:
+    if hasResult and resCounter == n.len-1:
+        tracked.init.add resSym.id
     for id, count in items(inter):
       if count >= toCover: tracked.init.add id
     # else we can't merge
@@ -723,14 +768,17 @@ proc trackIf(tracked: PEffects, n: PNode) =
   addFact(tracked.guards, n[0][0])
   let oldState = tracked.init.len
 
+  let hasResult = hasResultSym(tracked.owner)
+  let resSym = if hasResult: tracked.owner.ast[resultPos].sym else: nil
+  var resCounter = 0
+
   var inter: TIntersection = @[]
   var toCover = 0
   track(tracked, n[0][1])
   let hasBreaksBlock = breaksBlock(n[0][1])
-  if not hasBreaksBlock:
+  if hasBreaksBlock == bsNone:
     inc toCover
-  for i in oldState..<tracked.init.len:
-    addToIntersection(inter, tracked.init[i], hasBreaksBlock)
+  addIdToIntersection(tracked, inter, resCounter, hasBreaksBlock, oldState, resSym, hasResult)
 
   for i in 1..<n.len:
     let branch = n[i]
@@ -743,13 +791,14 @@ proc trackIf(tracked: PEffects, n: PNode) =
     for i in 0..<branch.len:
       track(tracked, branch[i])
     let hasBreaksBlock = breaksBlock(branch.lastSon)
-    if not hasBreaksBlock:
+    if hasBreaksBlock == bsNone:
       inc toCover
-    for i in oldState..<tracked.init.len:
-      addToIntersection(inter, tracked.init[i], hasBreaksBlock)
+    addIdToIntersection(tracked, inter, resCounter, hasBreaksBlock, oldState, resSym, hasResult)
 
   setLen(tracked.init, oldState)
   if lastSon(n).len == 1:
+    if hasResult and resCounter == n.len:
+        tracked.init.add resSym.id
     for id, count in items(inter):
       if count >= toCover: tracked.init.add id
     # else we can't merge as it is not exhaustive