summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--changelog.md4
-rw-r--r--lib/std/wrapnils.nim109
-rw-r--r--tests/stdlib/twrapnils.nim82
-rw-r--r--tools/kochdocs.nim1
4 files changed, 196 insertions, 0 deletions
diff --git a/changelog.md b/changelog.md
index 3838da254..3d58a76c9 100644
--- a/changelog.md
+++ b/changelog.md
@@ -55,6 +55,10 @@
 
 - Added `times.fromUnixFloat,toUnixFloat`, subsecond resolution versions of `fromUnix`,`toUnixFloat`.
 
+- Added `wrapnils` module for chains of field-access and indexing where the LHS can be nil.
+  This simplifies code by reducing need for if-else branches around intermediate maybe nil values.
+  Eg: `echo ?.n.typ.kind`
+
 ## Library changes
 
 - `asyncdispatch.drain` now properly takes into account `selector.hasPendingOperations`
diff --git a/lib/std/wrapnils.nim b/lib/std/wrapnils.nim
new file mode 100644
index 000000000..b9eb70790
--- /dev/null
+++ b/lib/std/wrapnils.nim
@@ -0,0 +1,109 @@
+## This module allows chains of field-access and indexing where the LHS can be nil.
+## This simplifies code by reducing need for if-else branches around intermediate values
+## that maybe be nil.
+##
+## Note: experimental module and relies on {.experimental: "dotOperators".}
+## Unstable API.
+
+runnableExamples:
+  type Foo = ref object
+    x1: string
+    x2: Foo
+    x3: ref int
+
+  var f: Foo
+  assert ?.f.x2.x1 == "" # returns default value since `f` is nil
+
+  var f2 = Foo(x1: "a")
+  f2.x2 = f2
+  assert ?.f2.x1 == "a" # same as f2.x1 (no nil LHS in this chain)
+  assert ?.Foo(x1: "a").x1 == "a" # can use constructor inside
+
+  # when you know a sub-expression is not nil, you can scope it as follows:
+  assert ?.(f2.x2.x2).x3[] == 0 # because `f` is nil
+
+type Wrapnil[T] = object
+  valueImpl: T
+  validImpl: bool
+
+proc wrapnil[T](a: T): Wrapnil[T] =
+  ## See top-level example.
+  Wrapnil[T](valueImpl: a, validImpl: true)
+
+template unwrap(a: Wrapnil): untyped =
+  ## See top-level example.
+  a.valueImpl
+
+{.push experimental: "dotOperators".}
+
+template `.`*(a: Wrapnil, b): untyped =
+  ## See top-level example.
+  let a1 = a # to avoid double evaluations
+  let a2 = a1.valueImpl
+  type T = Wrapnil[type(a2.b)]
+  if a1.validImpl:
+    when type(a2) is ref|ptr:
+      if a2 == nil:
+        default(T)
+      else:
+        wrapnil(a2.b)
+    else:
+      wrapnil(a2.b)
+  else:
+    # nil is "sticky"; this is needed, see tests
+    default(T)
+
+{.pop.}
+
+proc isValid(a: Wrapnil): bool =
+  ## Returns true if `a` didn't contain intermediate `nil` values (note that
+  ## `a.valueImpl` itself can be nil even in that case)
+  a.validImpl
+
+template `[]`*[I](a: Wrapnil, i: I): untyped =
+  ## See top-level example.
+  let a1 = a # to avoid double evaluations
+  if a1.validImpl:
+    # correctly will raise IndexError if a is valid but wraps an empty container
+    wrapnil(a1.valueImpl[i])
+  else:
+    default(Wrapnil[type(a1.valueImpl[i])])
+
+template `[]`*(a: Wrapnil): untyped =
+  ## See top-level example.
+  let a1 = a # to avoid double evaluations
+  let a2 = a1.valueImpl
+  type T = Wrapnil[type(a2[])]
+  if a1.validImpl:
+    if a2 == nil:
+      default(T)
+    else:
+      wrapnil(a2[])
+  else:
+    default(T)
+
+import std/macros
+
+proc replace(n: NimNode): NimNode =
+  if n.kind == nnkPar:
+    doAssert n.len == 1
+    newCall(bindSym"wrapnil", n[0])
+  elif n.kind in {nnkCall, nnkObjConstr}:
+    newCall(bindSym"wrapnil", n)
+  elif n.len == 0:
+    newCall(bindSym"wrapnil", n)
+  else:
+    n[0] = replace(n[0])
+    n
+
+macro `?.`*(a: untyped): untyped =
+  ## Transforms `a` into an expression that can be safely evaluated even in
+  ## presence of intermediate nil pointers/references, in which case a default
+  ## value is produced.
+  #[
+  Using a template like this wouldn't work:
+    template `?.`*(a: untyped): untyped = wrapnil(a)[]
+  ]#
+  result = replace(a)
+  result = quote do:
+    `result`.valueImpl
diff --git a/tests/stdlib/twrapnils.nim b/tests/stdlib/twrapnils.nim
new file mode 100644
index 000000000..b20c67479
--- /dev/null
+++ b/tests/stdlib/twrapnils.nim
@@ -0,0 +1,82 @@
+import std/wrapnils
+
+const wrapnilExtendedExports = declared(wrapnil)
+  # for now, wrapnil, isValid, unwrap are not exported
+
+proc checkNotZero(x: float): float =
+  doAssert x != 0
+  x
+
+var witness = 0
+
+proc main() =
+  type Bar = object
+    b1: int
+    b2: ptr string
+
+  type Foo = ref object
+    x1: float
+    x2: Foo
+    x3: string
+    x4: Bar
+    x5: seq[int]
+    x6: ptr Bar
+    x7: array[2, string]
+    x8: seq[int]
+    x9: ref Bar
+
+  type Gook = ref object
+    foo: Foo
+
+  proc fun(a: Bar): auto = a.b2
+
+  var a: Foo
+  var x6 = create(Bar)
+  x6.b1 = 42
+  var a2 = Foo(x1: 1.0, x5: @[10, 11], x6: x6)
+  var a3 = Foo(x1: 1.2, x3: "abc")
+  a3.x2 = a3
+
+  var gook = Gook(foo: a)
+
+  proc initFoo(x1: float): auto =
+    witness.inc
+    result = Foo(x1: x1)
+
+  doAssert ?.a.x2.x2.x1 == 0.0
+  doAssert ?.a3.x2.x2.x1 == 1.2
+  doAssert ?.a3.x2.x2.x3[1] == 'b'
+
+  doAssert ?.a3.x2.x2.x5.len == 0
+  doAssert a3.x2.x2.x3.len == 3
+
+  when wrapnilExtendedExports:
+    # example calling wrapnil directly, with and without unwrap
+    doAssert a3.wrapnil.x2.x2.x3.len == wrapnil(3)
+    doAssert a3.wrapnil.x2.x2.x3.len.unwrap == 3
+    doAssert a2.wrapnil.x4.isValid
+    doAssert not a.wrapnil.x4.isValid
+
+  doAssert ?.a.x2.x2.x3[1] == default(char)
+  # here we only apply wrapnil around gook.foo, not gook (and assume gook is not nil)
+  doAssert ?.(gook.foo).x2.x2.x1 == 0.0
+
+  doAssert ?.a2.x6[] == Bar(b1: 42) # deref for ptr Bar
+
+  doAssert ?.a2.x1.checkNotZero == 1.0
+  doAssert a == nil
+  # shows that checkNotZero won't be called if a nil is found earlier in chain
+  doAssert ?.a.x1.checkNotZero == 0.0
+
+  # checks that a chain without nil but with an empty seq still throws IndexError
+  doAssertRaises(IndexError): discard ?.a2.x8[3]
+
+  # make sure no double evaluation bug
+  doAssert witness == 0
+  doAssert ?.initFoo(1.3).x1 == 1.3
+  doAssert witness == 1
+
+  # here, it's used twice, to deref `ref Bar` and then `ptr string`
+  doAssert ?.a.x9[].fun[] == ""
+
+main()
diff --git a/tools/kochdocs.nim b/tools/kochdocs.nim
index f99567cd0..3b3df75a7 100644
--- a/tools/kochdocs.nim
+++ b/tools/kochdocs.nim
@@ -155,6 +155,7 @@ lib/pure/strutils.nim
 lib/pure/math.nim
 lib/std/editdistance.nim
 lib/std/wordwrap.nim
+lib/std/wrapnils.nim
 lib/experimental/diff.nim
 lib/pure/algorithm.nim
 lib/pure/stats.nim