summary refs log tree commit diff stats
path: root/compiler/semobjconstr.nim
diff options
context:
space:
mode:
authorJasper Jenkins <jasper.vs.jenkins@gmail.com>2019-05-26 12:22:02 -0700
committerAndreas Rumpf <rumpf_a@web.de>2019-05-26 21:22:02 +0200
commitf7744260959bbdaefdc5172aaf7ff4770f8f8c03 (patch)
tree071d38af9c1e9134292a17ef197d2bd7df81b28c /compiler/semobjconstr.nim
parent16aa10dfe101da99c402657a420dbf2785ca4a2a (diff)
downloadNim-f7744260959bbdaefdc5172aaf7ff4770f8f8c03.tar.gz
Smarter variant object construction (#11273)
Diffstat (limited to 'compiler/semobjconstr.nim')
-rw-r--r--compiler/semobjconstr.nim107
1 files changed, 93 insertions, 14 deletions
diff --git a/compiler/semobjconstr.nim b/compiler/semobjconstr.nim
index a08e9c635..1597faa70 100644
--- a/compiler/semobjconstr.nim
+++ b/compiler/semobjconstr.nim
@@ -84,6 +84,53 @@ proc caseBranchMatchesExpr(branch, matched: PNode): bool =
 
   return false
 
+template processBranchVals(b, op) =
+  if b.kind != nkElifBranch:
+    for i in 0 .. b.len-2:
+      if b[i].kind == nkIntLit:
+        result.op(b[i].intVal.int)
+      elif b[i].kind == nkRange:
+        for i in b[i][0].intVal .. b[i][1].intVal:
+          result.op(i.int)
+
+proc allPossibleValues(c: PContext, t: PType): IntSet =
+  result = initIntSet()
+  if t.kind == tyEnum:
+    for field in t.n.sons:
+      result.incl(field.sym.position)
+  else:
+    for i in firstOrd(c.config, t) .. lastOrd(c.config, t):
+      result.incl(i.int)
+
+proc branchVals(c: PContext, caseNode: PNode, caseIdx: int,
+                isStmtBranch: bool): IntSet =
+  if caseNode[caseIdx].kind == nkOfBranch:
+    result = initIntSet()
+    processBranchVals(caseNode[caseIdx], incl)
+  else:
+    result = allPossibleValues(c, caseNode.sons[0].typ)
+    for i in 1 .. caseNode.len-2:
+      processBranchVals(caseNode[i], excl)
+
+proc formatUnsafeBranchVals(c: PContext, t: PType, diffVals: IntSet): string =
+  if diffVals.len <= 32:
+    var strs: seq[string]
+    if t.kind == tyEnum:
+      var i = 0
+      for val in diffVals:
+        while t.n.sons[i].sym.position < val: inc(i)
+        strs.add(t.n.sons[i].sym.name.s)
+    else:
+      for val in diffVals:
+        strs.add($val)
+    result = "{" & strs.join(", ") & "} "
+
+proc findUsefulCaseContext(c: PContext, discrimator: PNode): (PNode, int) =
+  for i in countdown(c.p.caseContext.high, 0):
+    let (caseNode, index) = c.p.caseContext[i]
+    if caseNode[0].kind == nkSym and caseNode[0].sym == discrimator.sym:
+      return (caseNode, index)
+
 proc pickCaseBranch(caseExpr, matched: PNode): PNode =
   # XXX: Perhaps this proc already exists somewhere
   let endsWithElse = caseExpr[^1].kind == nkElse
@@ -173,28 +220,60 @@ proc semConstructFields(c: PContext, recNode: PNode,
           selectedBranch = i
 
     if selectedBranch != -1:
-      let branchNode = recNode[selectedBranch]
-      let flags = flags*{efAllowDestructor} + {efNeedStatic, efPreferNilResult}
-      let discriminatorVal = semConstrField(c, flags,
-                                            discriminator.sym, initExpr)
-      if discriminatorVal == nil:
+      template badDiscriminatorError =
         let fields = fieldsPresentInBranch(selectedBranch)
         localError(c.config, initExpr.info,
           ("you must provide a compile-time value for the discriminator '$1' " &
           "in order to prove that it's safe to initialize $2.") %
           [discriminator.sym.name.s, fields])
         mergeInitStatus(result, initNone)
-      else:
-        let discriminatorVal = discriminatorVal.skipHidden
 
-        template wrongBranchError(i) =
-          let fields = fieldsPresentInBranch(i)
-          localError(c.config, initExpr.info,
-            "a case selecting discriminator '$1' with value '$2' " &
-            "appears in the object construction, but the field(s) $3 " &
-            "are in conflict with this value.",
-            [discriminator.sym.name.s, discriminatorVal.renderTree, fields])
+      template wrongBranchError(i) =
+        let fields = fieldsPresentInBranch(i)
+        localError(c.config, initExpr.info,
+          "a case selecting discriminator '$1' with value '$2' " &
+          "appears in the object construction, but the field(s) $3 " &
+          "are in conflict with this value.",
+          [discriminator.sym.name.s, discriminatorVal.renderTree, fields])
+
+      let branchNode = recNode[selectedBranch]
+      let flags = flags*{efAllowDestructor} + {efPreferStatic,
+                                               efPreferNilResult}
+      var discriminatorVal = semConstrField(c, flags,
+                                            discriminator.sym, initExpr)
 
+      if discriminatorVal != nil:
+        discriminatorVal = discriminatorVal.skipHidden
+      if discriminatorVal == nil:
+        badDiscriminatorError()
+      elif discriminatorVal.kind == nkSym:
+        let (ctorCase, ctorIdx) = findUsefulCaseContext(c, discriminatorVal)
+        if ctorCase == nil:
+          badDiscriminatorError()
+        elif discriminatorVal.sym.kind != skLet:
+          localError(c.config, discriminatorVal.info,
+            "runtime discriminator must be immutable if branch fields are " &
+            "initialized, a 'let' binding is required.")
+        elif not isOrdinalType(discriminatorVal.sym.typ, true) or
+            lengthOrd(c.config, discriminatorVal.sym.typ) > MaxSetElements:
+          localError(c.config, discriminatorVal.info,
+            "branch initialization with a runtime discriminator only " &
+            "supports ordinal types with 2^16 elements or less.")
+        elif ctorCase[ctorIdx].kind == nkElifBranch:
+          localError(c.config, discriminatorVal.info, "branch initialization " &
+            "with a runtime discriminator is not supported inside of an " &
+            "`elif` branch.")
+        else:
+          var
+            ctorBranchVals = branchVals(c, ctorCase, ctorIdx, true)
+            recBranchVals = branchVals(c, recNode, selectedBranch, false)
+            branchValsDiff = ctorBranchVals - recBranchVals
+          if branchValsDiff.len != 0:
+            localError(c.config, discriminatorVal.info, ("possible values " &
+              "$2are in conflict with discriminator values for " &
+              "selected object branch $1.") % [$selectedBranch,
+              formatUnsafeBranchVals(c, recNode.sons[0].typ, branchValsDiff)])
+      else:
         if branchNode.kind != nkElse:
           if not branchNode.caseBranchMatchesExpr(discriminatorVal):
             wrongBranchError(selectedBranch)