summary refs log tree commit diff stats
path: root/compiler/nilcheck.nim
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/nilcheck.nim')
-rw-r--r--compiler/nilcheck.nim36
1 files changed, 18 insertions, 18 deletions
diff --git a/compiler/nilcheck.nim b/compiler/nilcheck.nim
index 96e096770..7e0efc34b 100644
--- a/compiler/nilcheck.nim
+++ b/compiler/nilcheck.nim
@@ -7,8 +7,8 @@
 #    distribution, for details about the copyright.
 #
 
-import ast, renderer, intsets, tables, msgs, options, lineinfos, strformat, idents, treetab, hashes
-import sequtils, strutils, sets
+import ast, renderer, msgs, options, lineinfos, idents, treetab
+import std/[intsets, tables, sequtils, strutils, sets, strformat, hashes]
 
 when defined(nimPreviewSlimSystem):
   import std/assertions
@@ -498,7 +498,7 @@ proc checkCall(n, ctx, map): Check =
   # check args and handle possible mutations
 
   var isNew = false
-  result.map = map
+  result = Check(map: map)
   for i, child in n:
     discard check(child, ctx, map)
 
@@ -507,7 +507,7 @@ proc checkCall(n, ctx, map): Check =
       # as it might have been mutated
       # TODO similar for normal refs and fields: find dependent exprs: brackets
 
-      if child.kind == nkHiddenAddr and not child.typ.isNil and child.typ.kind == tyVar and child.typ[0].kind == tyRef:
+      if child.kind == nkHiddenAddr and not child.typ.isNil and child.typ.kind == tyVar and child.typ.elementType.kind == tyRef:
         if not isNew:
           result.map = newNilMap(map)
           isNew = true
@@ -753,6 +753,7 @@ proc checkReturn(n, ctx, map): Check =
 
 proc checkIf(n, ctx, map): Check =
   ## check branches based on condition
+  result = default(Check)
   var mapIf: NilMap = map
 
   # first visit the condition
@@ -825,7 +826,7 @@ proc checkFor(n, ctx, map): Check =
   var check2 = check(n.sons[2], ctx, m)
   var map2 = check2.map
 
-  result.map = ctx.union(map0, m)
+  result = Check(map: ctx.union(map0, m))
   result.map = ctx.union(result.map, map2)
   result.nilability = Safe
 
@@ -853,7 +854,7 @@ proc checkWhile(n, ctx, map): Check =
   var check2 = check(n.sons[1], ctx, m)
   var map2 = check2.map
 
-  result.map = ctx.union(map0, map1)
+  result = Check(map: ctx.union(map0, map1))
   result.map = ctx.union(result.map, map2)
   result.nilability = Safe
 
@@ -899,7 +900,7 @@ proc checkInfix(n, ctx, map): Check =
 proc checkIsNil(n, ctx, map; isElse: bool = false): Check =
   ## check isNil calls
   ## update the map depending on if it is not isNil or isNil
-  result.map = newNilMap(map)
+  result = Check(map: newNilMap(map))
   let value = n[1]
   result.map.store(ctx, ctx.index(n[1]), if not isElse: Nil else: Safe, TArg, n.info, n)
 
@@ -918,7 +919,7 @@ proc infix(ctx: NilCheckerContext, l: PNode, r: PNode, magic: TMagic): PNode =
     newSymNode(op, r.info),
     l,
     r)
-  result.typ = newType(tyBool, nextTypeId ctx.idgen, nil)
+  result.typ = newType(tyBool, ctx.idgen, nil)
 
 proc prefixNot(ctx: NilCheckerContext, node: PNode): PNode =
   var cache = newIdentCache()
@@ -928,7 +929,7 @@ proc prefixNot(ctx: NilCheckerContext, node: PNode): PNode =
   result = nkPrefix.newTree(
     newSymNode(op, node.info),
     node)
-  result.typ = newType(tyBool, nextTypeId ctx.idgen, nil)
+  result.typ = newType(tyBool, ctx.idgen, nil)
 
 proc infixEq(ctx: NilCheckerContext, l: PNode, r: PNode): PNode =
   infix(ctx, l, r, mEqRef)
@@ -947,7 +948,7 @@ proc checkCase(n, ctx, map): Check =
   #   c2
   # also a == true is a , a == false is not a
   let base = n[0]
-  result.map = map.copyMap()
+  result = Check(map: map.copyMap())
   result.nilability = Safe
   var a: PNode = nil
   for child in n:
@@ -1219,7 +1220,7 @@ proc check(n: PNode, ctx: NilCheckerContext, map: NilMap): Check =
     result = check(n.sons[1], ctx, map)
   of nkStmtList, nkStmtListExpr, nkChckRangeF, nkChckRange64, nkChckRange,
      nkBracket, nkCurly, nkPar, nkTupleConstr, nkClosure, nkObjConstr, nkElse:
-    result.map = map
+    result = Check(map: map)
     if n.kind in {nkObjConstr, nkTupleConstr}:
       # TODO deeper nested elements?
       # A(field: B()) #
@@ -1246,10 +1247,10 @@ proc check(n: PNode, ctx: NilCheckerContext, map: NilMap): Check =
     result = checkIf(n, ctx, map)
   of nkAsgn, nkFastAsgn, nkSinkAsgn:
     result = checkAsgn(n[0], n[1], ctx, map)
-  of nkVarSection:
-    result.map = map
+  of nkVarSection, nkLetSection:
+    result = Check(map: map)
     for child in n:
-      result = checkAsgn(child[0], child[2], ctx, result.map)
+      result = checkAsgn(child[0].skipPragmaExpr, child[2], ctx, result.map)
   of nkForStmt:
     result = checkFor(n, ctx, map)
   of nkCaseStmt:
@@ -1273,8 +1274,7 @@ proc check(n: PNode, ctx: NilCheckerContext, map: NilMap): Check =
   else:
 
     var elementMap = map.copyMap()
-    var elementCheck: Check
-    elementCheck.map = elementMap
+    var elementCheck = Check(map: elementMap)
     for element in n:
       elementCheck = check(element, ctx, elementCheck.map)
 
@@ -1367,7 +1367,7 @@ proc checkNil*(s: PSym; body: PNode; conf: ConfigRef, idgen: IdGenerator) =
         continue
       map.store(context, context.index(child), typeNilability(child.typ), TArg, child.info, child)
 
-  map.store(context, resultExprIndex, if not s.typ[0].isNil and s.typ[0].kind == tyRef: Nil else: Safe, TResult, s.ast.info)
+  map.store(context, resultExprIndex, if not s.typ.returnType.isNil and s.typ.returnType.kind == tyRef: Nil else: Safe, TResult, s.ast.info)
 
   # echo "checking ", s.name.s, " ", filename
 
@@ -1383,5 +1383,5 @@ proc checkNil*(s: PSym; body: PNode; conf: ConfigRef, idgen: IdGenerator) =
   # (ANotNil, BNotNil) :
   # do we check on asgn nilability at all?
 
-  if not s.typ[0].isNil and s.typ[0].kind == tyRef and tfNotNil in s.typ[0].flags:
+  if not s.typ.returnType.isNil and s.typ.returnType.kind == tyRef and tfNotNil in s.typ.returnType.flags:
     checkResult(s.ast, context, res.map)