summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorringabout <43030857+ringabout@users.noreply.github.com>2023-03-10 16:28:51 +0800
committerGitHub <noreply@github.com>2023-03-10 09:28:51 +0100
commit03198243227a076d6d758bdcf004d007dc8f9980 (patch)
treefe5d1a4f35ae2b42805064fae99e441f86ae8b6f
parent72e262666bdf2bb3c239183dd32b48bb05d113aa (diff)
downloadNim-03198243227a076d6d758bdcf004d007dc8f9980.tar.gz
fixes #21023; Segfault when mixing seqs, orc, variants and futures (#21497)
* fixes #21023; Segfault when mixing seqs, orc, variants and futures

* fixes none of the branches were explicitly selected

* add one more test

* one more test
-rw-r--r--compiler/semobjconstr.nim61
-rw-r--r--tests/arc/tcaseobj.nim61
-rw-r--r--tests/objects/tobject_default_value.nim32
3 files changed, 136 insertions, 18 deletions
diff --git a/compiler/semobjconstr.nim b/compiler/semobjconstr.nim
index 15e53a639..29067cfa4 100644
--- a/compiler/semobjconstr.nim
+++ b/compiler/semobjconstr.nim
@@ -142,15 +142,45 @@ proc fieldsPresentInInitExpr(c: PContext, fieldsRecList, initExpr: PNode): strin
       if result.len != 0: result.add ", "
       result.add field.sym.name.s.quoteStr
 
+proc locateFieldInDefaults(sym: PSym, defaults: seq[PNode]): bool =
+  result = false
+  for d in defaults:
+    if sym.id == d[0].sym.id:
+      return true
+
 proc collectMissingFields(c: PContext, fieldsRecList: PNode,
-                          constrCtx: var ObjConstrContext) =
-  for r in directFieldsInRecList(fieldsRecList):
-    if constrCtx.needsFullInit or
-       sfRequiresInit in r.sym.flags or
-       r.sym.typ.requiresInit:
+                          constrCtx: var ObjConstrContext, defaults: seq[PNode]
+                          ): seq[PSym] =
+    for r in directFieldsInRecList(fieldsRecList):
       let assignment = locateFieldInInitExpr(c, r.sym, constrCtx.initExpr)
-      if assignment == nil:
-        constrCtx.missingFields.add r.sym
+      if assignment == nil and not locateFieldInDefaults(r.sym, defaults):
+        if constrCtx.needsFullInit or
+          sfRequiresInit in r.sym.flags or
+            r.sym.typ.requiresInit:
+          constrCtx.missingFields.add r.sym
+        else:
+          result.add r.sym
+
+proc collectMissingCaseFields(c: PContext, branchNode: PNode,
+                          constrCtx: var ObjConstrContext, defaults: seq[PNode]): seq[PSym] =
+  if branchNode != nil:
+    let fieldsRecList = branchNode[^1]
+    result = collectMissingFields(c, fieldsRecList, constrCtx, defaults)
+
+proc collectOrAddMissingCaseFields(c: PContext, branchNode: PNode,
+                          constrCtx: var ObjConstrContext, defaults: var seq[PNode]) =
+  let res = collectMissingCaseFields(c, branchNode, constrCtx, defaults)
+  for sym in res:
+    let asgnType = newType(tyTypeDesc, nextTypeId(c.idgen), sym.typ.owner)
+    let recTyp = sym.typ.skipTypes(defaultFieldsSkipTypes)
+    rawAddSon(asgnType, recTyp)
+    let asgnExpr = newTree(nkCall,
+          newSymNode(getSysMagic(c.graph, constrCtx.initExpr.info, "zeroDefault", mZeroDefault)),
+          newNodeIT(nkType, constrCtx.initExpr.info, asgnType)
+        )
+    asgnExpr.flags.incl nfUseDefaultField
+    asgnExpr.typ = recTyp
+    defaults.add newTree(nkExprColonExpr, newSymNode(sym), asgnExpr)
 
 proc semConstructFields(c: PContext, n: PNode, constrCtx: var ObjConstrContext,
                         flags: TExprFlags): tuple[status: InitStatus, defaults: seq[PNode]] =
@@ -166,11 +196,6 @@ proc semConstructFields(c: PContext, n: PNode, constrCtx: var ObjConstrContext,
       let fields = branch[^1]
       fieldsPresentInInitExpr(c, fields, constrCtx.initExpr)
 
-    template collectMissingFields(branchNode: PNode) =
-      if branchNode != nil:
-        let fields = branchNode[^1]
-        collectMissingFields(c, fields, constrCtx)
-
     let discriminator = n[0]
     internalAssert c.config, discriminator.kind == nkSym
     var selectedBranch = -1
@@ -288,8 +313,7 @@ proc semConstructFields(c: PContext, n: PNode, constrCtx: var ObjConstrContext,
       # When a branch is selected with a partial match, some of the fields
       # that were not initialized may be mandatory. We must check for this:
       if result.status == initPartial:
-        collectMissingFields branchNode
-
+        collectOrAddMissingCaseFields(c, branchNode, constrCtx, result.defaults)
     else:
       result.status = initNone
       let discriminatorVal = semConstrField(c, flags + {efPreferStatic},
@@ -302,7 +326,7 @@ proc semConstructFields(c: PContext, n: PNode, constrCtx: var ObjConstrContext,
         # a result:
         let defaultValue = newIntLit(c.graph, constrCtx.initExpr.info, 0)
         let matchedBranch = n.pickCaseBranch defaultValue
-        collectMissingFields matchedBranch
+        discard collectMissingCaseFields(c, matchedBranch, constrCtx, @[])
       else:
         result.status = initPartial
         if discriminatorVal.kind == nkIntLit:
@@ -312,11 +336,12 @@ proc semConstructFields(c: PContext, n: PNode, constrCtx: var ObjConstrContext,
           if matchedBranch != nil:
             let (_, defaults) = semConstructFields(c, matchedBranch[^1], constrCtx, flags)
             result.defaults.add defaults
-            collectMissingFields matchedBranch
+            collectOrAddMissingCaseFields(c, matchedBranch, constrCtx, result.defaults)
         else:
           # All bets are off. If any of the branches has a mandatory
           # fields we must produce an error:
-          for i in 1..<n.len: collectMissingFields n[i]
+          for i in 1..<n.len:
+            discard collectMissingCaseFields(c, n[i], constrCtx, @[])
   of nkSym:
     let field = n.sym
     let e = semConstrField(c, flags, field, constrCtx.initExpr)
@@ -348,7 +373,7 @@ proc semConstructTypeAux(c: PContext,
     result.status.mergeInitStatus status
     result.defaults.add defaults
     if status in {initPartial, initNone, initUnknown}:
-      collectMissingFields c, t.n, constrCtx
+      discard collectMissingFields(c, t.n, constrCtx, result.defaults)
     let base = t[0]
     if base == nil: break
     t = skipTypes(base, skipPtrs)
diff --git a/tests/arc/tcaseobj.nim b/tests/arc/tcaseobj.nim
index d52833e4d..26c122384 100644
--- a/tests/arc/tcaseobj.nim
+++ b/tests/arc/tcaseobj.nim
@@ -269,3 +269,64 @@ proc bug20305 =
   echo x.pChildren
 
 bug20305()
+
+# bug #21023
+block:
+  block:
+    type
+      MGErrorKind = enum
+        mgeUnexpected, mgeNotFound
+
+    type Foo = object
+      kind: MGErrorKind
+      ex: Exception
+
+    type Boo = object
+      a: seq[int]
+
+    type
+      Result2 = object
+        case o: bool
+        of false:
+          e: Foo
+        of true:
+          v: Boo
+
+    proc startSessionSync(): Result2 =
+      return Result2(o: true)
+
+    proc mainSync =
+      let ff = startSessionSync()
+      doAssert ff.o == true
+
+    mainSync()
+
+  block:
+    type
+      MGErrorKind = enum
+        mgeUnexpected, mgeNotFound
+
+    type Foo = object
+      kind: MGErrorKind
+      ex: Exception
+
+    type Boo = object
+      a: seq[int]
+
+    type
+      Result2 = object
+        case o: bool
+        of false:
+          e: Foo
+        of true:
+          v: Boo
+          s: int
+
+    proc startSessionSync(): Result2 =
+      return Result2(o: true, s: 12)
+
+    proc mainSync =
+      let ff = startSessionSync()
+      doAssert ff.s == 12
+
+    mainSync()
diff --git a/tests/objects/tobject_default_value.nim b/tests/objects/tobject_default_value.nim
index 150fb0876..59af943e0 100644
--- a/tests/objects/tobject_default_value.nim
+++ b/tests/objects/tobject_default_value.nim
@@ -559,6 +559,38 @@ template main {.dirty.} =
     let x = default(Default)
     doAssert x.data is DjangoDateTime
 
+  block:
+    type
+      Result2 = object
+        case o: bool
+        of false:
+          e: float
+        of true:
+          v {.requiresInit.} : int = 1
+
+    proc startSessionSync(): Result2 =
+      return Result2(o: true)
+
+    proc mainSync =
+      let ff = startSessionSync()
+      doAssert ff.v == 1
+
+    mainSync()
+
+  block:
+    type
+      Result2 = object
+        v {.requiresInit.} : int = 1
+
+    proc startSessionSync(): Result2 =
+      return Result2()
+
+    proc mainSync =
+      let ff = startSessionSync()
+      doAssert ff.v == 1
+
+    mainSync()
+
 
 static: main()
 main()