summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--lib/pure/collections/sharedtables.nim46
-rw-r--r--lib/pure/collections/tableimpl.nim9
-rw-r--r--tests/collections/ttables.nim26
3 files changed, 77 insertions, 4 deletions
diff --git a/lib/pure/collections/sharedtables.nim b/lib/pure/collections/sharedtables.nim
index 28509caa1..de573bcb2 100644
--- a/lib/pure/collections/sharedtables.nim
+++ b/lib/pure/collections/sharedtables.nim
@@ -130,6 +130,52 @@ proc hasKeyOrPut*[A, B](t: var SharedTable[A, B], key: A, val: B): bool =
   withLock t:
     hasKeyOrPutImpl(enlarge)
 
+proc withKey*[A, B](t: var SharedTable[A, B], key: A,
+                    mapper: proc(key: A, val: var B, pairExists: var bool)) =
+  ## Computes a new mapping for the ``key`` with the specified ``mapper``
+  ## procedure.
+  ##
+  ## The ``mapper`` takes 3 arguments:
+  ##   #. ``key`` - the current key, if it exists, or the key passed to
+  ##      ``withKey`` otherwise;
+  ##   #. ``val`` - the current value, if the key exists, or default value
+  ##      of the type otherwise;
+  ##   #. ``pairExists`` - ``true`` if the key exists, ``false`` otherwise.
+  ## The ``mapper`` can can modify ``val`` and ``pairExists`` values to change
+  ## the mapping of the key or delete it from the table.
+  ## When adding a value, make sure to set ``pairExists`` to ``true`` along
+  ## with modifying the ``val``.
+  ##
+  ## The operation is performed atomically and other operations on the table
+  ## will be blocked while the ``mapper`` is invoked, so it should be short and
+  ## simple.
+  ##
+  ## Example usage:
+  ##
+  ## .. code-block:: nim
+  ##
+  ##   # If value exists, decrement it.
+  ##   # If it becomes zero or less, delete the key
+  ##   t.withKey(1'i64) do (k: int64, v: var int, pairExists: var bool):
+  ##     if pairExists:
+  ##       dec v
+  ##       if v <= 0:
+  ##         pairExists = false
+  withLock t:
+    var hc: Hash
+    var index = rawGet(t, key, hc)
+
+    var pairExists = index >= 0
+    if pairExists:
+      mapper(t.data[index].key, t.data[index].val, pairExists)
+      if not pairExists:
+        delImplIdx(t, index)
+    else:
+      var val: B
+      mapper(key, val, pairExists)
+      if pairExists:
+        maybeRehashPutImpl(enlarge)
+
 proc `[]=`*[A, B](t: var SharedTable[A, B], key: A, val: B) =
   ## puts a (key, value)-pair into `t`.
   withLock t:
diff --git a/lib/pure/collections/tableimpl.nim b/lib/pure/collections/tableimpl.nim
index 5e871f08b..c0d45c392 100644
--- a/lib/pure/collections/tableimpl.nim
+++ b/lib/pure/collections/tableimpl.nim
@@ -120,9 +120,7 @@ template default[T](t: typedesc[T]): T =
   var v: T
   v
 
-template delImpl() {.dirty.} =
-  var hc: Hash
-  var i = rawGet(t, key, hc)
+template delImplIdx(t, i) =
   let msk = maxHash(t)
   if i >= 0:
     dec(t.counter)
@@ -145,6 +143,11 @@ template delImpl() {.dirty.} =
         else:
           shallowCopy(t.data[j], t.data[i]) # data[j] will be marked EMPTY next loop
 
+template delImpl() {.dirty.} =
+  var hc: Hash
+  var i = rawGet(t, key, hc)
+  delImplIdx(t, i)
+
 template clearImpl() {.dirty.} =
   for i in 0 .. <t.data.len:
     when compiles(t.data[i].hcode): # CountTable records don't contain a hcode
diff --git a/tests/collections/ttables.nim b/tests/collections/ttables.nim
index 0e06bc26f..a2988e841 100644
--- a/tests/collections/ttables.nim
+++ b/tests/collections/ttables.nim
@@ -1,8 +1,9 @@
 discard """
+  cmd: "nim c --threads:on $file"
   output: '''true'''
 """
 
-import hashes, tables
+import hashes, tables, sharedtables
 
 const
   data = {
@@ -211,6 +212,29 @@ block clearCountTableTest:
   t.clear()
   assert t.len() == 0
 
+block withKeyTest:
+  var t = initSharedTable[int, int]()
+  t.withKey(1) do (k: int, v: var int, pairExists: var bool):
+    assert(v == 0)
+    pairExists = true
+    v = 42
+  assert(t.mget(1) == 42)
+  t.withKey(1) do (k: int, v: var int, pairExists: var bool):
+    assert(v == 42)
+    pairExists = false
+  try:
+    discard t.mget(1)
+    assert(false, "KeyError expected")
+  except KeyError:
+    discard
+  t.withKey(2) do (k: int, v: var int, pairExists: var bool):
+    pairExists = false
+  try:
+    discard t.mget(2)
+    assert(false, "KeyError expected")
+  except KeyError:
+    discard
+
 proc orderedTableSortTest() =
   var t = initOrderedTable[string, int](2)
   for key, val in items(data): t[key] = val