summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--lib/std/jsonutils.nim88
-rw-r--r--tests/stdlib/tjsonutils.nim53
2 files changed, 126 insertions, 15 deletions
diff --git a/lib/std/jsonutils.nim b/lib/std/jsonutils.nim
index be3d7e7c8..22f2a7a89 100644
--- a/lib/std/jsonutils.nim
+++ b/lib/std/jsonutils.nim
@@ -30,10 +30,59 @@ add a way to customize serialization, for eg:
 * handle cyclic references, using a cache of already visited addresses
 ]#
 
+import std/macros
+
 proc isNamedTuple(T: typedesc): bool {.magic: "TypeTrait".}
 proc distinctBase(T: typedesc): typedesc {.magic: "TypeTrait".}
 template distinctBase[T](a: T): untyped = distinctBase(type(a))(a)
 
+macro getDiscriminants(a: typedesc): seq[string] =
+  ## return the discriminant keys
+  # candidate for std/typetraits
+  var a = a.getTypeImpl
+  doAssert a.kind == nnkBracketExpr
+  let sym = a[1]
+  let t = sym.getTypeImpl
+  let t2 = t[2]
+  doAssert t2.kind == nnkRecList
+  result = newTree(nnkBracket)
+  for ti in t2:
+    if ti.kind == nnkRecCase:
+      let key = ti[0][0]
+      let typ = ti[0][1]
+      result.add newLit key.strVal
+  if result.len > 0:
+    result = quote do:
+      @`result`
+  else:
+    result = quote do:
+      seq[string].default
+
+macro initCaseObject(a: typedesc, fun: untyped): untyped =
+  ## does the minimum to construct a valid case object, only initializing
+  ## the discriminant fields; see also `getDiscriminants`
+  # maybe candidate for std/typetraits
+  var a = a.getTypeImpl
+  doAssert a.kind == nnkBracketExpr
+  let sym = a[1]
+  let t = sym.getTypeImpl
+  var t2: NimNode
+  case t.kind
+  of nnkObjectTy: t2 = t[2]
+  of nnkRefTy: t2 = t[0].getTypeImpl[2]
+  else: doAssert false, $t.kind # xxx `nnkPtrTy` could be handled too
+  doAssert t2.kind == nnkRecList
+  result = newTree(nnkObjConstr)
+  result.add sym
+  for ti in t2:
+    if ti.kind == nnkRecCase:
+      let key = ti[0][0]
+      let typ = ti[0][1]
+      let key2 = key.strVal
+      let val = quote do:
+        `fun`(`key2`, typedesc[`typ`])
+      result.add newTree(nnkExprColonExpr, key, val)
+
 proc checkJsonImpl(cond: bool, condStr: string, msg = "") =
   if not cond:
     # just pick 1 exception type for simplicity; other choices would be:
@@ -43,6 +92,19 @@ proc checkJsonImpl(cond: bool, condStr: string, msg = "") =
 template checkJson(cond: untyped, msg = "") =
   checkJsonImpl(cond, astToStr(cond), msg)
 
+template fromJsonFields(a, b, T, keys) =
+  checkJson b.kind == JObject, $(b.kind) # we could customize whether to allow JNull
+  var num = 0
+  for key, val in fieldPairs(a):
+    num.inc
+    when key notin keys:
+      if b.hasKey key:
+        fromJson(val, b[key])
+      else:
+        # we could customize to allow this
+        checkJson false, $($T, key, b)
+  checkJson b.len == num, $(b.len, num, $T, b) # could customize
+
 proc fromJson*[T](a: var T, b: JsonNode) =
   ## inplace version of `jsonTo`
   #[
@@ -85,25 +147,22 @@ proc fromJson*[T](a: var T, b: JsonNode) =
     a.setLen b.len
     for i, val in b.getElems:
       fromJson(a[i], val)
-  elif T is object | tuple:
-    const isNamed = T is object or isNamedTuple(T)
-    when isNamed:
-      checkJson b.kind == JObject, $(b.kind) # we could customize whether to allow JNull
-      var num = 0
-      for key, val in fieldPairs(a):
-        num.inc
-        if b.hasKey key:
-          fromJson(val, b[key])
-        else:
-          # we could customize to allow this
-          checkJson false, $($T, key, b)
-      checkJson b.len == num, $(b.len, num, $T, b) # could customize
+  elif T is object:
+    template fun(key, typ): untyped =
+      jsonTo(b[key], typ)
+    a = initCaseObject(T, fun)
+    const keys = getDiscriminants(T)
+    fromJsonFields(a, b, T, keys)
+  elif T is tuple:
+    when isNamedTuple(T):
+      fromJsonFields(a, b, T, seq[string].default)
     else:
       checkJson b.kind == JArray, $(b.kind) # we could customize whether to allow JNull
       var i = 0
       for val in fields(a):
         fromJson(val, b[i])
         i.inc
+      checkJson b.len == i, $(b.len, i, $T, b) # could customize
   else:
     # checkJson not appropriate here
     static: doAssert false, "not yet implemented: " & $T
@@ -120,8 +179,7 @@ proc toJson*[T](a: T): JsonNode =
     result = newJObject()
     for k, v in pairs(a): result[k] = toJson(v)
   elif T is object | tuple:
-    const isNamed = T is object or isNamedTuple(T)
-    when isNamed:
+    when T is object or isNamedTuple(T):
       result = newJObject()
       for k, v in a.fieldPairs: result[k] = toJson(v)
     else:
diff --git a/tests/stdlib/tjsonutils.nim b/tests/stdlib/tjsonutils.nim
index fca980dc9..0b2ec7179 100644
--- a/tests/stdlib/tjsonutils.nim
+++ b/tests/stdlib/tjsonutils.nim
@@ -9,6 +9,9 @@ proc testRoundtrip[T](t: T, expected: string) =
   let j = t.toJson
   doAssert $j == expected, $j
   doAssert j.jsonTo(T).toJson == j
+  var t2: T
+  t2.fromJson(j)
+  doAssert t2.toJson == j
 
 import tables
 import strtabs
@@ -66,5 +69,55 @@ template fn() =
     doAssert b2.ord == 1 # explains the `1`
     testRoundtrip(a): """[1,2,3]"""
 
+  block: # case object
+    type Foo = object
+      x0: float
+      case t1: bool
+      of true: z1: int8
+      of false: z2: uint16
+      x1: string
+    testRoundtrip(Foo(t1: true, z1: 5, x1: "bar")): """{"x0":0.0,"t1":true,"z1":5,"x1":"bar"}"""
+    testRoundtrip(Foo(x0: 1.5, t1: false, z2: 6)): """{"x0":1.5,"t1":false,"z2":6,"x1":""}"""
+    type PFoo = ref Foo
+    testRoundtrip(PFoo(x0: 1.5, t1: false, z2: 6)): """{"x0":1.5,"t1":false,"z2":6,"x1":""}"""
+
+  block: # ref case object
+    type Foo = ref object
+      x0: float
+      case t1: bool
+      of true: z1: int8
+      of false: z2: uint16
+      x1: string
+    testRoundtrip(Foo(t1: true, z1: 5, x1: "bar")): """{"x0":0.0,"t1":true,"z1":5,"x1":"bar"}"""
+    testRoundtrip(Foo(x0: 1.5, t1: false, z2: 6)): """{"x0":1.5,"t1":false,"z2":6,"x1":""}"""
+
+  block: # generic case object
+    type Foo[T] = ref object
+      x0: float
+      case t1: bool
+      of true: z1: int8
+      of false: z2: uint16
+      x1: string
+    testRoundtrip(Foo[float](t1: true, z1: 5, x1: "bar")): """{"x0":0.0,"t1":true,"z1":5,"x1":"bar"}"""
+    testRoundtrip(Foo[int](x0: 1.5, t1: false, z2: 6)): """{"x0":1.5,"t1":false,"z2":6,"x1":""}"""
+    # sanity check: nesting inside a tuple
+    testRoundtrip((Foo[int](x0: 1.5, t1: false, z2: 6), "foo")): """[{"x0":1.5,"t1":false,"z2":6,"x1":""},"foo"]"""
+
+  block: # case object: 2 discriminants, `when` branch, range discriminant
+    type Foo[T] = object
+      case t1: bool
+      of true:
+        z1: int8
+      of false:
+        z2: uint16
+      when T is float:
+        case t2: range[0..3]
+        of 0: z3: int8
+        of 2,3: z4: uint16
+        else: discard
+    testRoundtrip(Foo[float](t1: true, z1: 5, t2: 3, z4: 12)): """{"t1":true,"z1":5,"t2":3,"z4":12}"""
+    testRoundtrip(Foo[int](t1: false, z2: 7)): """{"t1":false,"z2":7}"""
+    # pending https://github.com/nim-lang/Nim/issues/14698, test with `type Foo[T] = ref object`
+
 static: fn()
 fn()