summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--compiler/semdata.nim10
-rw-r--r--compiler/semobjconstr.nim107
-rw-r--r--compiler/semstmts.nim3
-rw-r--r--doc/manual.rst23
-rw-r--r--tests/objvariant/trt_discrim.nim138
-rw-r--r--tests/objvariant/trt_discrim_err0.nim17
-rw-r--r--tests/objvariant/trt_discrim_err1.nim17
-rw-r--r--tests/objvariant/trt_discrim_err2.nim14
-rw-r--r--tests/objvariant/trt_discrim_err3.nim17
9 files changed, 332 insertions, 14 deletions
diff --git a/compiler/semdata.nim b/compiler/semdata.nim
index 05816850a..e411633c2 100644
--- a/compiler/semdata.nim
+++ b/compiler/semdata.nim
@@ -40,6 +40,7 @@ type
     wasForwarded*: bool       # whether the current proc has a separate header
     mappingExists*: bool
     mapping*: TIdTable
+    caseContext*: seq[tuple[n: PNode, idx: int]]
 
   TMatchedConcept* = object
     candidateType*: PType
@@ -416,3 +417,12 @@ proc checkMinSonsLen*(n: PNode, length: int; conf: ConfigRef) =
 
 proc isTopLevel*(c: PContext): bool {.inline.} =
   result = c.currentScope.depthLevel <= 2
+
+proc pushCaseContext*(c: PContext, caseNode: PNode) =
+  add(c.p.caseContext, (caseNode, 0))
+
+proc popCaseContext*(c: PContext) =
+  discard pop(c.p.caseContext)
+
+proc setCaseContextIdx*(c: PContext, idx: int) =
+  c.p.caseContext[^1].idx = idx
diff --git a/compiler/semobjconstr.nim b/compiler/semobjconstr.nim
index a08e9c635..1597faa70 100644
--- a/compiler/semobjconstr.nim
+++ b/compiler/semobjconstr.nim
@@ -84,6 +84,53 @@ proc caseBranchMatchesExpr(branch, matched: PNode): bool =
 
   return false
 
+template processBranchVals(b, op) =
+  if b.kind != nkElifBranch:
+    for i in 0 .. b.len-2:
+      if b[i].kind == nkIntLit:
+        result.op(b[i].intVal.int)
+      elif b[i].kind == nkRange:
+        for i in b[i][0].intVal .. b[i][1].intVal:
+          result.op(i.int)
+
+proc allPossibleValues(c: PContext, t: PType): IntSet =
+  result = initIntSet()
+  if t.kind == tyEnum:
+    for field in t.n.sons:
+      result.incl(field.sym.position)
+  else:
+    for i in firstOrd(c.config, t) .. lastOrd(c.config, t):
+      result.incl(i.int)
+
+proc branchVals(c: PContext, caseNode: PNode, caseIdx: int,
+                isStmtBranch: bool): IntSet =
+  if caseNode[caseIdx].kind == nkOfBranch:
+    result = initIntSet()
+    processBranchVals(caseNode[caseIdx], incl)
+  else:
+    result = allPossibleValues(c, caseNode.sons[0].typ)
+    for i in 1 .. caseNode.len-2:
+      processBranchVals(caseNode[i], excl)
+
+proc formatUnsafeBranchVals(c: PContext, t: PType, diffVals: IntSet): string =
+  if diffVals.len <= 32:
+    var strs: seq[string]
+    if t.kind == tyEnum:
+      var i = 0
+      for val in diffVals:
+        while t.n.sons[i].sym.position < val: inc(i)
+        strs.add(t.n.sons[i].sym.name.s)
+    else:
+      for val in diffVals:
+        strs.add($val)
+    result = "{" & strs.join(", ") & "} "
+
+proc findUsefulCaseContext(c: PContext, discrimator: PNode): (PNode, int) =
+  for i in countdown(c.p.caseContext.high, 0):
+    let (caseNode, index) = c.p.caseContext[i]
+    if caseNode[0].kind == nkSym and caseNode[0].sym == discrimator.sym:
+      return (caseNode, index)
+
 proc pickCaseBranch(caseExpr, matched: PNode): PNode =
   # XXX: Perhaps this proc already exists somewhere
   let endsWithElse = caseExpr[^1].kind == nkElse
@@ -173,28 +220,60 @@ proc semConstructFields(c: PContext, recNode: PNode,
           selectedBranch = i
 
     if selectedBranch != -1:
-      let branchNode = recNode[selectedBranch]
-      let flags = flags*{efAllowDestructor} + {efNeedStatic, efPreferNilResult}
-      let discriminatorVal = semConstrField(c, flags,
-                                            discriminator.sym, initExpr)
-      if discriminatorVal == nil:
+      template badDiscriminatorError =
         let fields = fieldsPresentInBranch(selectedBranch)
         localError(c.config, initExpr.info,
           ("you must provide a compile-time value for the discriminator '$1' " &
           "in order to prove that it's safe to initialize $2.") %
           [discriminator.sym.name.s, fields])
         mergeInitStatus(result, initNone)
-      else:
-        let discriminatorVal = discriminatorVal.skipHidden
 
-        template wrongBranchError(i) =
-          let fields = fieldsPresentInBranch(i)
-          localError(c.config, initExpr.info,
-            "a case selecting discriminator '$1' with value '$2' " &
-            "appears in the object construction, but the field(s) $3 " &
-            "are in conflict with this value.",
-            [discriminator.sym.name.s, discriminatorVal.renderTree, fields])
+      template wrongBranchError(i) =
+        let fields = fieldsPresentInBranch(i)
+        localError(c.config, initExpr.info,
+          "a case selecting discriminator '$1' with value '$2' " &
+          "appears in the object construction, but the field(s) $3 " &
+          "are in conflict with this value.",
+          [discriminator.sym.name.s, discriminatorVal.renderTree, fields])
+
+      let branchNode = recNode[selectedBranch]
+      let flags = flags*{efAllowDestructor} + {efPreferStatic,
+                                               efPreferNilResult}
+      var discriminatorVal = semConstrField(c, flags,
+                                            discriminator.sym, initExpr)
 
+      if discriminatorVal != nil:
+        discriminatorVal = discriminatorVal.skipHidden
+      if discriminatorVal == nil:
+        badDiscriminatorError()
+      elif discriminatorVal.kind == nkSym:
+        let (ctorCase, ctorIdx) = findUsefulCaseContext(c, discriminatorVal)
+        if ctorCase == nil:
+          badDiscriminatorError()
+        elif discriminatorVal.sym.kind != skLet:
+          localError(c.config, discriminatorVal.info,
+            "runtime discriminator must be immutable if branch fields are " &
+            "initialized, a 'let' binding is required.")
+        elif not isOrdinalType(discriminatorVal.sym.typ, true) or
+            lengthOrd(c.config, discriminatorVal.sym.typ) > MaxSetElements:
+          localError(c.config, discriminatorVal.info,
+            "branch initialization with a runtime discriminator only " &
+            "supports ordinal types with 2^16 elements or less.")
+        elif ctorCase[ctorIdx].kind == nkElifBranch:
+          localError(c.config, discriminatorVal.info, "branch initialization " &
+            "with a runtime discriminator is not supported inside of an " &
+            "`elif` branch.")
+        else:
+          var
+            ctorBranchVals = branchVals(c, ctorCase, ctorIdx, true)
+            recBranchVals = branchVals(c, recNode, selectedBranch, false)
+            branchValsDiff = ctorBranchVals - recBranchVals
+          if branchValsDiff.len != 0:
+            localError(c.config, discriminatorVal.info, ("possible values " &
+              "$2are in conflict with discriminator values for " &
+              "selected object branch $1.") % [$selectedBranch,
+              formatUnsafeBranchVals(c, recNode.sons[0].typ, branchValsDiff)])
+      else:
         if branchNode.kind != nkElse:
           if not branchNode.caseBranchMatchesExpr(discriminatorVal):
             wrongBranchError(selectedBranch)
diff --git a/compiler/semstmts.nim b/compiler/semstmts.nim
index 1aa5c9fec..3090bf455 100644
--- a/compiler/semstmts.nim
+++ b/compiler/semstmts.nim
@@ -874,6 +874,7 @@ proc semCase(c: PContext, n: PNode; flags: TExprFlags): PNode =
   result = n
   checkMinSonsLen(n, 2, c.config)
   openScope(c)
+  pushCaseContext(c, n)
   n.sons[0] = semExprWithType(c, n.sons[0])
   var chckCovered = false
   var covered: BiggestInt = 0
@@ -897,6 +898,7 @@ proc semCase(c: PContext, n: PNode; flags: TExprFlags): PNode =
     localError(c.config, n.sons[0].info, errSelectorMustBeOfCertainTypes)
     return
   for i in 1 ..< sonsLen(n):
+    setCaseContextIdx(c, i)
     var x = n.sons[i]
     when defined(nimsuggest):
       if c.config.ideCmd == ideSug and exactEquals(c.config.m.trackPos, x.info) and caseTyp.kind == tyEnum:
@@ -934,6 +936,7 @@ proc semCase(c: PContext, n: PNode; flags: TExprFlags): PNode =
                  formatMissingEnums(n))
     else:
       localError(c.config, n.info, "not all cases are covered")
+  popCaseContext(c)
   closeScope(c)
   if isEmptyType(typ) or typ.kind in {tyNil, tyUntyped} or
       (not hasElse and efInTypeof notin flags):
diff --git a/doc/manual.rst b/doc/manual.rst
index 00635fefc..615b8e302 100644
--- a/doc/manual.rst
+++ b/doc/manual.rst
@@ -1543,6 +1543,29 @@ branch switch ``system.reset`` has to be used. Also, when the fields of a
 particular branch are specified during object construction, the corresponding
 discriminator value must be specified as a constant expression.
 
+As a special rule, the discriminator kind can also be bounded using a ``case``
+statement. If possible values of the discriminator variable in a
+``case`` statement branch are a subset of discriminator values for the selected
+object branch, the initialization is considered valid. This analysis only works
+for immutable discriminators of an ordinal type and disregards ``elif``
+branches.
+
+A small example:
+
+.. code-block:: nim
+
+  let unknownKind = nkSub
+
+  # invalid: unsafe initialization because the kind field is not statically known:
+  var y = Node(kind: unknownKind, strVal: "y")
+
+  var z = Node()
+  case unknownKind
+  of nkAdd, nkSub:
+    # valid: possible values of this branch are a subset of nkAdd/nkSub object branch:
+    z = Node(kind: unknownKind, leftOp: Node(), rightOp: Node())
+  else:
+    echo "ignoring: ", unknownKind
 
 Set type
 --------
diff --git a/tests/objvariant/trt_discrim.nim b/tests/objvariant/trt_discrim.nim
new file mode 100644
index 000000000..612647fbe
--- /dev/null
+++ b/tests/objvariant/trt_discrim.nim
@@ -0,0 +1,138 @@
+template accept(x) =
+  static: assert(compiles(x))
+
+template reject(x) =
+  static: assert(not compiles(x))
+
+type
+  Kind = enum k1 = 2, k2 = 33, k3 = 84, k4 = 278, k5 = 1000 # Holed enum work!
+  KindObj = object
+    case kind: Kind
+    of k1, k2..k3: i32: int32
+    of k4: f32: float32
+    else: str: string
+
+  IntObj = object
+    case kind: int16
+    of low(int16) .. -1: bad: string
+    of 0: neutral: string
+    of 1 .. high(int16): good: string
+
+  OtherKind = enum ok1, ok2, ok3, ok4, ok5
+  NestedKindObj = object
+    case kind: Kind
+    of k3, k5: discard
+    of k2: str: string
+    of k1, k4:
+      case otherKind: OtherKind
+      of ok1, ok2..ok3: i32: int32
+      of ok4: f32: float32
+      else: nestedStr: string
+
+let kind = k4 # actual value should have no impact on the analysis.
+
+accept: # Mimics the structure of the type. The optimial case.
+  case kind
+  of k1, k2, k3: discard KindObj(kind: kind, i32: 1)
+  of k4: discard KindObj(kind: kind, f32: 2.0)
+  else: discard KindObj(kind: kind, str: "3")
+
+accept: # Specifying the else explicitly is fine too.
+  case kind
+  of k1, k2, k3: discard KindObj(kind: kind, i32: 1)
+  of k4: discard KindObj(kind: kind, f32: 2.0)
+  of k5: discard KindObj(kind: kind, str: "3")
+
+accept:
+  case kind
+  of k1..k3, k5: discard
+  else: discard KindObj(kind: kind, f32: 2.0)
+
+accept:
+  case kind
+  of k4, k5: discard
+  else: discard KindObj(kind: kind, i32: 1)
+
+accept: # elif branches are ignored
+  case kind
+  of k1, k2, k3: discard KindObj(kind: kind, i32: 1)
+  of k4: discard KindObj(kind: kind, f32: 2.0)
+  elif kind in {k1..k5}: discard
+  else: discard KindObj(kind: kind, str: "3")
+
+reject: # k4 conflicts with i32
+  case kind
+  of k1, k2, k3, k4: discard KindObj(kind: kind, i32: 1)
+  else: discard KindObj(kind: kind, str: "3")
+
+reject: # k4 is not caught, conflicts with str in the else branch
+  case kind
+  of k1, k2, k3: discard KindObj(kind: kind, i32: 1)
+  else: discard KindObj(kind: kind, str: "3")
+
+reject: # elif branches are ignored
+  case kind
+  of k1, k2, k3: discard KindObj(kind: kind, i32: 1)
+  elif kind == k4: discard
+  else: discard KindObj(kind: kind, str: "3")
+
+let intKind = 29'i16
+
+accept:
+  case intKind
+  of low(int16) .. -1: discard IntObj(kind: intKind, bad: "bad")
+  of 0: discard IntObj(kind: intKind, neutral: "neutral")
+  of 1 .. high(int16): discard IntObj(kind: intKind, good: "good")
+
+reject: # 0 leaks to else
+  case intKind
+  of low(int16) .. -1: discard IntObj(kind: intKind, bad: "bad")
+  of 1 .. high(int16): discard IntObj(kind: intKind, good: "good")
+
+accept:
+  case intKind
+  of low(int16) .. -1: discard IntObj(kind: intKind, bad: "bad")
+  of 0: discard IntObj(kind: intKind, neutral: "neutral")
+  of 10, 11 .. high(int16), 1 .. 9: discard IntObj(kind: intKind, good: "good")
+
+accept:
+  case kind
+  of {k1, k2}, [k3]: discard KindObj(kind: kind, i32: 1)
+  of k4: discard KindObj(kind: kind, f32: 2.0)
+  else: discard KindObj(kind: kind, str: "3")
+
+reject:
+  case kind
+  of {k1, k2, k3}, [k4]: discard KindObj(kind: kind, i32: 1)
+  else: discard KindObj(kind: kind, str: "3")
+
+accept:
+  case kind
+  of k3, k5: discard NestedKindObj(kind: kind)
+  of k2: discard NestedKindObj(kind: kind, str: "not nested")
+  of k1, k4:
+    let otherKind = ok5
+    case otherKind
+    of ok1..ok3: discard NestedKindObj(kind: kind, otherKind: otherKind, i32: 3)
+    of ok4: discard NestedKindObj(kind: kind, otherKind: otherKind, f32: 5.0)
+    else: discard NestedKindObj(kind: kind, otherKind: otherKind,
+                                nestedStr: "nested")
+
+reject:
+  case kind
+  of k3, k5: discard NestedKindObj(kind: kind)
+  of k2: discard NestedKindObj(kind: kind, str: "not nested")
+  of k1, k4:
+    let otherKind = ok5
+    case otherKind
+    of ok1..ok3: discard NestedKindObj(kind: kind, otherKind: otherKind, i32: 3)
+    else: discard NestedKindObj(kind: kind, otherKind: otherKind,
+                                nestedStr: "nested")
+
+var varkind = k4
+
+reject: # not immutable.
+  case varkind
+  of k1, k2, k3: discard KindObj(varkind: kind, i32: 1)
+  of k4: discard KindObj(varkind: kind, f32: 2.0)
+  else: discard KindObj(varkind: kind, str: "3")
diff --git a/tests/objvariant/trt_discrim_err0.nim b/tests/objvariant/trt_discrim_err0.nim
new file mode 100644
index 000000000..02b551cbc
--- /dev/null
+++ b/tests/objvariant/trt_discrim_err0.nim
@@ -0,0 +1,17 @@
+discard """
+  errormsg: "possible values {k1, k3, k4} are in conflict with discriminator values for selected object branch 3"
+  line: 17
+"""
+
+type
+  Kind = enum k1, k2, k3, k4, k5
+  KindObj = object
+    case kind: Kind
+    of k1, k2..k3: i32: int32
+    of k4: f32: float32
+    else: str: string
+
+let kind = k3
+case kind
+of k2: discard KindObj(kind: kind, i32: 1)
+else: discard KindObj(kind: kind, str: "3")
diff --git a/tests/objvariant/trt_discrim_err1.nim b/tests/objvariant/trt_discrim_err1.nim
new file mode 100644
index 000000000..de29420a2
--- /dev/null
+++ b/tests/objvariant/trt_discrim_err1.nim
@@ -0,0 +1,17 @@
+discard """
+  errormsg: "branch initialization with a runtime discriminator is not supported inside of an `elif` branch."
+  line: 16
+"""
+type
+  Color = enum Red, Green, Blue
+  ColorObj = object
+    case colorKind: Color
+    of Red: red: string
+    of Green: green: string
+    of Blue: blue: string
+
+let colorKind = Blue
+case colorKind
+of Red: echo ColorObj(colorKind: colorKind, red: "red")
+elif colorKind == Green: echo ColorObj(colorKind: colorKind, green: "green")
+else: echo ColorObj(colorKind: colorKind, blue: "blue")
diff --git a/tests/objvariant/trt_discrim_err2.nim b/tests/objvariant/trt_discrim_err2.nim
new file mode 100644
index 000000000..c5352014e
--- /dev/null
+++ b/tests/objvariant/trt_discrim_err2.nim
@@ -0,0 +1,14 @@
+discard """
+  errormsg: "branch initialization with a runtime discriminator only supports ordinal types with 2^16 elements or less."
+  line: 13
+"""
+type
+  HoledObj = object
+    case kind: int
+    of 0: a: int
+    else: discard
+
+let someInt = low(int)
+case someInt
+of 938: echo HoledObj(kind: someInt, a: 1)
+else: discard
diff --git a/tests/objvariant/trt_discrim_err3.nim b/tests/objvariant/trt_discrim_err3.nim
new file mode 100644
index 000000000..e739c3d50
--- /dev/null
+++ b/tests/objvariant/trt_discrim_err3.nim
@@ -0,0 +1,17 @@
+discard """
+  errormsg: "runtime discriminator must be immutable if branch fields are initialized, a 'let' binding is required."
+  line: 16
+"""
+
+type
+  Kind = enum k1, k2, k3, k4, k5
+  KindObj = object
+    case kind: Kind
+    of k1, k2..k3: i32: int32
+    of k4: f32: float32
+    else: str: string
+
+var kind = k3
+case kind
+of k2: discard KindObj(kind: kind, i32: 1)
+else: discard KindObj(kind: kind, str: "3")