summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorringabout <43030857+ringabout@users.noreply.github.com>2023-09-01 14:55:19 +0800
committerGitHub <noreply@github.com>2023-09-01 08:55:19 +0200
commitaffd3f78587f1b18a8b4bc2fc51967cd55cb7531 (patch)
tree38c166f7734bb573fc08ae1eb5abfb827559f124
parent3b206ed988765419c5ff02c2b08645f24ed0cbad (diff)
downloadNim-affd3f78587f1b18a8b4bc2fc51967cd55cb7531.tar.gz
fixes #22613; Default value does not work with object's discriminator (#22614)
* fixes #22613; Default value does not work with object's discriminator

fixes #22613

* merge branches

* add a test case

* fixes status

* remove outdated comments

* move collectBranchFields into the global scope
-rw-r--r--compiler/semobjconstr.nim66
-rw-r--r--tests/objects/tobject_default_value.nim17
2 files changed, 59 insertions, 24 deletions
diff --git a/compiler/semobjconstr.nim b/compiler/semobjconstr.nim
index 3e0a6fdb7..2d366d8fc 100644
--- a/compiler/semobjconstr.nim
+++ b/compiler/semobjconstr.nim
@@ -75,7 +75,6 @@ proc locateFieldInInitExpr(c: PContext, field: PSym, initExpr: PNode): PNode =
 
 proc semConstrField(c: PContext, flags: TExprFlags,
                     field: PSym, initExpr: PNode): PNode =
-  result = nil
   let assignment = locateFieldInInitExpr(c, field, initExpr)
   if assignment != nil:
     if nfSem in assignment.flags: return assignment[1]
@@ -93,7 +92,9 @@ proc semConstrField(c: PContext, flags: TExprFlags,
     assignment[0] = newSymNode(field)
     assignment[1] = initValue
     assignment.flags.incl nfSem
-    return initValue
+    result = initValue
+  else:
+    result = nil
 
 proc branchVals(c: PContext, caseNode: PNode, caseIdx: int,
                 isStmtBranch: bool): IntSet =
@@ -192,6 +193,23 @@ proc collectOrAddMissingCaseFields(c: PContext, branchNode: PNode,
     asgnExpr.typ = recTyp
     defaults.add newTree(nkExprColonExpr, newSymNode(sym), asgnExpr)
 
+proc collectBranchFields(c: PContext, n: PNode, discriminatorVal: PNode,
+                          constrCtx: var ObjConstrContext, flags: TExprFlags) =
+  # All bets are off. If any of the branches has a mandatory
+  # fields we must produce an error:
+  for i in 1..<n.len:
+    let branchNode = n[i]
+    if branchNode != nil:
+      let oldCheckDefault = constrCtx.checkDefault
+      constrCtx.checkDefault = true
+      let (_, defaults) = semConstructFields(c, branchNode[^1], constrCtx, flags)
+      constrCtx.checkDefault = oldCheckDefault
+      if len(defaults) > 0:
+        localError(c.config, discriminatorVal.info, "branch initialization " &
+                    "with a runtime discriminator is not supported " &
+                    "for a branch whose fields have default values.")
+    discard collectMissingCaseFields(c, n[i], constrCtx, @[])
+
 proc semConstructFields(c: PContext, n: PNode, constrCtx: var ObjConstrContext,
                         flags: TExprFlags): tuple[status: InitStatus, defaults: seq[PNode]] =
   result = (initUnknown, @[])
@@ -331,13 +349,27 @@ proc semConstructFields(c: PContext, n: PNode, constrCtx: var ObjConstrContext,
                                             discriminator.sym,
                                             constrCtx.initExpr)
       if discriminatorVal == nil:
-        # None of the branches were explicitly selected by the user and no
-        # value was given to the discrimator. We can assume that it will be
-        # initialized to zero and this will select a particular branch as
-        # a result:
-        let defaultValue = newIntLit(c.graph, constrCtx.initExpr.info, 0)
-        let matchedBranch = n.pickCaseBranch defaultValue
-        discard collectMissingCaseFields(c, matchedBranch, constrCtx, @[])
+        if discriminator.sym.ast != nil:
+          # branch is selected by the default field value of discriminator
+          let discriminatorDefaultVal = discriminator.sym.ast
+          result.status = initUnknown
+          result.defaults.add newTree(nkExprColonExpr, n[0], discriminatorDefaultVal)
+          if discriminatorDefaultVal.kind == nkIntLit:
+            let matchedBranch = n.pickCaseBranch discriminatorDefaultVal
+            if matchedBranch != nil:
+              let (_, defaults) = semConstructFields(c, matchedBranch[^1], constrCtx, flags)
+              result.defaults.add defaults
+              collectOrAddMissingCaseFields(c, matchedBranch, constrCtx, result.defaults)
+          else:
+            collectBranchFields(c, n, discriminatorDefaultVal, constrCtx, flags)
+        else:
+          # None of the branches were explicitly selected by the user and no
+          # value was given to the discrimator. We can assume that it will be
+          # initialized to zero and this will select a particular branch as
+          # a result:
+          let defaultValue = newIntLit(c.graph, constrCtx.initExpr.info, 0)
+          let matchedBranch = n.pickCaseBranch defaultValue
+          discard collectMissingCaseFields(c, matchedBranch, constrCtx, @[])
       else:
         result.status = initPartial
         if discriminatorVal.kind == nkIntLit:
@@ -349,20 +381,8 @@ proc semConstructFields(c: PContext, n: PNode, constrCtx: var ObjConstrContext,
             result.defaults.add defaults
             collectOrAddMissingCaseFields(c, matchedBranch, constrCtx, result.defaults)
         else:
-          # All bets are off. If any of the branches has a mandatory
-          # fields we must produce an error:
-          for i in 1..<n.len:
-            let branchNode = n[i]
-            if branchNode != nil:
-              let oldCheckDefault = constrCtx.checkDefault
-              constrCtx.checkDefault = true
-              let (_, defaults) = semConstructFields(c, branchNode[^1], constrCtx, flags)
-              constrCtx.checkDefault = oldCheckDefault
-              if len(defaults) > 0:
-                localError(c.config, discriminatorVal.info, "branch initialization " &
-                            "with a runtime discriminator is not supported " &
-                            "for a branch whose fields have default values.")
-            discard collectMissingCaseFields(c, n[i], constrCtx, @[])
+          collectBranchFields(c, n, discriminatorVal, constrCtx, flags)
+
   of nkSym:
     let field = n.sym
     let e = semConstrField(c, flags, field, constrCtx.initExpr)
diff --git a/tests/objects/tobject_default_value.nim b/tests/objects/tobject_default_value.nim
index 2d86dce11..3af790da6 100644
--- a/tests/objects/tobject_default_value.nim
+++ b/tests/objects/tobject_default_value.nim
@@ -377,7 +377,7 @@ template main {.dirty.} =
         Red, Blue, Yellow
   
     type
-      ObjectVarint3 = object # fixme it doesn't work with static
+      ObjectVarint3 = object
         case kind: Color = Blue
         of Red:
           data1: int = 10
@@ -703,5 +703,20 @@ template main {.dirty.} =
     var foo = new Container
     doAssert int(foo.thing[0].x) == 1
 
+  block: # bug #22613
+    type
+      K = enum
+        A = "a"
+        B = "b"
+      T = object
+        case kind: K = B
+        of A:
+          a: int
+        of B:
+          b: float
+
+    doAssert T().kind == B
+
+
 static: main()
 main()