summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorringabout <43030857+ringabout@users.noreply.github.com>2023-09-04 20:36:45 +0800
committerGitHub <noreply@github.com>2023-09-04 14:36:45 +0200
commitd13aab50cf465a7f2edf9c37a4fa30e128892e72 (patch)
tree042c3f6fce69a3dfebfaa088a5ca7c924bdb4ec6
parentc5495f40d5d881e6bd155c9e6c9c6e5e49b749a7 (diff)
downloadNim-d13aab50cf465a7f2edf9c37a4fa30e128892e72.tar.gz
fixes branches interacting with break, raise etc. in strictdefs (#22627)
```nim
{.experimental: "strictdefs".}

type Test = object
  id: int

proc test(): Test =
  if true:
    return Test()
  else:
    return
echo test()
```

I will tackle https://github.com/nim-lang/Nim/issues/16735 and #21615 in
the following PR.


The old code just premises that in branches ended with returns, raise
statements etc. , all variables including the result variable are
initialized for that branch. It's true for noreturn statements. But it
is false for the result variable in a branch tailing with a return
statement, in which the result variable is not initialized. The solution
is not perfect for usages below branch statements with the result
variable uninitialized, but it should suffice for now, which gives a
proper warning.

It also fixes

```nim

{.experimental: "strictdefs".}

type Test = object
  id: int

proc foo {.noreturn.} = discard

proc test9(x: bool): Test =
  if x:
    foo()
  else:
    foo()
```
which gives a warning, but shouldn't
-rw-r--r--compiler/lookups.nim2
-rw-r--r--compiler/sempass2.nim87
-rw-r--r--tests/init/tcompiles.nim64
-rw-r--r--tests/init/treturns.nim93
4 files changed, 226 insertions, 20 deletions
diff --git a/compiler/lookups.nim b/compiler/lookups.nim
index 90f9a9b2b..2bdf3a1e0 100644
--- a/compiler/lookups.nim
+++ b/compiler/lookups.nim
@@ -596,7 +596,7 @@ proc lookUp*(c: PContext, n: PNode): PSym =
     if result == nil: result = errorUndeclaredIdentifierHint(c, n, ident)
   else:
     internalError(c.config, n.info, "lookUp")
-    return
+    return nil
   if amb:
     #contains(c.ambiguousSymbols, result.id):
     result = errorUseQualifier(c, n.info, result, amb)
diff --git a/compiler/sempass2.nim b/compiler/sempass2.nim
index 0954f42fd..823699a8c 100644
--- a/compiler/sempass2.nim
+++ b/compiler/sempass2.nim
@@ -352,20 +352,25 @@ proc useVar(a: PEffects, n: PNode) =
       a.init.add s.id
   useVarNoInitCheck(a, n, s)
 
+type
+  BreakState = enum
+    bsNone
+    bsBreakOrReturn
+    bsNoReturn
 
 type
   TIntersection = seq[tuple[id, count: int]] # a simple count table
 
-proc addToIntersection(inter: var TIntersection, s: int, initOnly: bool) =
+proc addToIntersection(inter: var TIntersection, s: int, state: BreakState) =
   for j in 0..<inter.len:
     if s == inter[j].id:
-      if not initOnly:
+      if state == bsNone:
         inc inter[j].count
       return
-  if initOnly:
-    inter.add((id: s, count: 0))
-  else:
+  if state == bsNone:
     inter.add((id: s, count: 1))
+  else:
+    inter.add((id: s, count: 0))
 
 proc throws(tracked, n, orig: PNode) =
   if n.typ == nil or n.typ.kind != tyError:
@@ -469,7 +474,7 @@ proc trackTryStmt(tracked: PEffects, n: PNode) =
   track(tracked, n[0])
   dec tracked.inTryStmt
   for i in oldState..<tracked.init.len:
-    addToIntersection(inter, tracked.init[i], false)
+    addToIntersection(inter, tracked.init[i], bsNone)
 
   var branches = 1
   var hasFinally = false
@@ -504,7 +509,7 @@ proc trackTryStmt(tracked: PEffects, n: PNode) =
           tracked.init.add b[j][2].sym.id
       track(tracked, b[^1])
       for i in oldState..<tracked.init.len:
-        addToIntersection(inter, tracked.init[i], false)
+        addToIntersection(inter, tracked.init[i], bsNone)
     else:
       setLen(tracked.init, oldState)
       track(tracked, b[^1])
@@ -673,15 +678,50 @@ proc trackOperandForIndirectCall(tracked: PEffects, n: PNode, formals: PType; ar
       localError(tracked.config, n.info, $n & " is not GC safe")
   notNilCheck(tracked, n, paramType)
 
-proc breaksBlock(n: PNode): bool =
+
+proc breaksBlock(n: PNode): BreakState =
   # semantic check doesn't allow statements after raise, break, return or
   # call to noreturn proc, so it is safe to check just the last statements
   var it = n
   while it.kind in {nkStmtList, nkStmtListExpr} and it.len > 0:
     it = it.lastSon
 
-  result = it.kind in {nkBreakStmt, nkReturnStmt, nkRaiseStmt} or
-    it.kind in nkCallKinds and it[0].kind == nkSym and sfNoReturn in it[0].sym.flags
+  case it.kind
+  of nkBreakStmt, nkReturnStmt:
+    result = bsBreakOrReturn
+  of nkRaiseStmt:
+    result = bsNoReturn
+  of nkCallKinds:
+    if it[0].kind == nkSym and sfNoReturn in it[0].sym.flags:
+      result = bsNoReturn
+    else:
+      result = bsNone
+  else:
+    result = bsNone
+
+proc addIdToIntersection(tracked: PEffects, inter: var TIntersection, resCounter: var int,
+            hasBreaksBlock: BreakState, oldState: int, resSym: PSym, hasResult: bool) =
+  if hasResult:
+    var alreadySatisfy = false
+
+    if hasBreaksBlock == bsNoReturn:
+      alreadySatisfy = true
+      inc resCounter
+
+    for i in oldState..<tracked.init.len:
+      if tracked.init[i] == resSym.id:
+        if not alreadySatisfy:
+          inc resCounter
+          alreadySatisfy = true
+      else:
+        addToIntersection(inter, tracked.init[i], hasBreaksBlock)
+  else:
+    for i in oldState..<tracked.init.len:
+      addToIntersection(inter, tracked.init[i], hasBreaksBlock)
+
+template hasResultSym(s: PSym): bool =
+  s != nil and s.kind in {skProc, skFunc, skConverter, skMethod} and
+    not isEmptyType(s.typ[0])
 
 proc trackCase(tracked: PEffects, n: PNode) =
   track(tracked, n[0])
@@ -694,6 +734,10 @@ proc trackCase(tracked: PEffects, n: PNode) =
         (tracked.config.hasWarn(warnProveField) or strictCaseObjects in tracked.c.features)
   var inter: TIntersection = @[]
   var toCover = 0
+  let hasResult = hasResultSym(tracked.owner)
+  let resSym = if hasResult: tracked.owner.ast[resultPos].sym else: nil
+  var resCounter = 0
+
   for i in 1..<n.len:
     let branch = n[i]
     setLen(tracked.init, oldState)
@@ -703,13 +747,14 @@ proc trackCase(tracked: PEffects, n: PNode) =
     for i in 0..<branch.len:
       track(tracked, branch[i])
     let hasBreaksBlock = breaksBlock(branch.lastSon)
-    if not hasBreaksBlock:
+    if hasBreaksBlock == bsNone:
       inc toCover
-    for i in oldState..<tracked.init.len:
-      addToIntersection(inter, tracked.init[i], hasBreaksBlock)
+    addIdToIntersection(tracked, inter, resCounter, hasBreaksBlock, oldState, resSym, hasResult)
 
   setLen(tracked.init, oldState)
   if not stringCase or lastSon(n).kind == nkElse:
+    if hasResult and resCounter == n.len-1:
+        tracked.init.add resSym.id
     for id, count in items(inter):
       if count >= toCover: tracked.init.add id
     # else we can't merge
@@ -723,14 +768,17 @@ proc trackIf(tracked: PEffects, n: PNode) =
   addFact(tracked.guards, n[0][0])
   let oldState = tracked.init.len
 
+  let hasResult = hasResultSym(tracked.owner)
+  let resSym = if hasResult: tracked.owner.ast[resultPos].sym else: nil
+  var resCounter = 0
+
   var inter: TIntersection = @[]
   var toCover = 0
   track(tracked, n[0][1])
   let hasBreaksBlock = breaksBlock(n[0][1])
-  if not hasBreaksBlock:
+  if hasBreaksBlock == bsNone:
     inc toCover
-  for i in oldState..<tracked.init.len:
-    addToIntersection(inter, tracked.init[i], hasBreaksBlock)
+  addIdToIntersection(tracked, inter, resCounter, hasBreaksBlock, oldState, resSym, hasResult)
 
   for i in 1..<n.len:
     let branch = n[i]
@@ -743,13 +791,14 @@ proc trackIf(tracked: PEffects, n: PNode) =
     for i in 0..<branch.len:
       track(tracked, branch[i])
     let hasBreaksBlock = breaksBlock(branch.lastSon)
-    if not hasBreaksBlock:
+    if hasBreaksBlock == bsNone:
       inc toCover
-    for i in oldState..<tracked.init.len:
-      addToIntersection(inter, tracked.init[i], hasBreaksBlock)
+    addIdToIntersection(tracked, inter, resCounter, hasBreaksBlock, oldState, resSym, hasResult)
 
   setLen(tracked.init, oldState)
   if lastSon(n).len == 1:
+    if hasResult and resCounter == n.len:
+        tracked.init.add resSym.id
     for id, count in items(inter):
       if count >= toCover: tracked.init.add id
     # else we can't merge as it is not exhaustive
diff --git a/tests/init/tcompiles.nim b/tests/init/tcompiles.nim
new file mode 100644
index 000000000..2072702ad
--- /dev/null
+++ b/tests/init/tcompiles.nim
@@ -0,0 +1,64 @@
+discard """
+  matrix: "--warningAsError:ProveInit --warningAsError:Uninit"
+"""
+
+{.experimental: "strictdefs".}
+
+type Test = object
+  id: int
+
+proc foo {.noreturn.} = discard
+
+block:
+  proc test(x: bool): Test =
+    if x:
+      foo()
+    else:
+      foo()
+
+block:
+  proc test(x: bool): Test =
+    if x:
+      result = Test()
+    else:
+      foo()
+
+  discard test(true)
+
+block:
+  proc test(x: bool): Test =
+    if x:
+      result = Test()
+    else:
+      return Test()
+
+  discard test(true)
+
+block:
+  proc test(x: bool): Test =
+    if x:
+      return Test()
+    else:
+      return Test()
+
+  discard test(true)
+
+block:
+  proc test(x: bool): Test =
+    if x:
+      result = Test()
+    else:
+      result = Test()
+      return
+
+  discard test(true)
+
+block:
+  proc test(x: bool): Test =
+    if x:
+      result = Test()
+      return
+    else:
+      raise newException(ValueError, "unreachable")
+
+  discard test(true)
diff --git a/tests/init/treturns.nim b/tests/init/treturns.nim
new file mode 100644
index 000000000..77469472a
--- /dev/null
+++ b/tests/init/treturns.nim
@@ -0,0 +1,93 @@
+{.experimental: "strictdefs".}
+
+type Test = object
+  id: int
+
+proc foo {.noreturn.} = discard
+
+proc test1(): Test =
+  if true: #[tt.Warning
+  ^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
+    return Test()
+  else:
+    return
+
+proc test0(): Test =
+  if true: #[tt.Warning
+  ^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
+    return
+  else:
+    foo()
+
+proc test2(): Test =
+  if true: #[tt.Warning
+  ^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
+    return
+  else:
+    return
+
+proc test3(): Test =
+  if true: #[tt.Warning
+  ^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
+    return
+  else:
+    return Test()
+
+proc test4(): Test =
+  if true: #[tt.Warning
+  ^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
+    return
+  else:
+    result = Test()
+    return
+
+proc test5(x: bool): Test =
+  case x: #[tt.Warning
+  ^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
+  of true:
+    return
+  else:
+    return Test()
+
+proc test6(x: bool): Test =
+  case x: #[tt.Warning
+  ^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
+  of true:
+    return
+  else:
+    return
+
+proc test7(x: bool): Test =
+  case x: #[tt.Warning
+  ^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
+  of true:
+    return
+  else:
+    discard
+
+proc test8(x: bool): Test =
+  case x: #[tt.Warning
+  ^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
+  of true:
+    discard
+  else:
+    raise
+
+proc hasImportStmt(): bool =
+  if false: #[tt.Warning
+  ^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
+    return true
+  else:
+    discard
+
+discard hasImportStmt()
+
+block:
+  proc hasImportStmt(): bool =
+    if false: #[tt.Warning
+    ^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
+      return true
+    else:
+      return
+
+  discard hasImportStmt()