summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--lib/pure/collections/tableimpl.nim21
-rw-r--r--lib/pure/collections/tables.nim48
-rw-r--r--tests/stdlib/tmget.nim16
3 files changed, 81 insertions, 4 deletions
diff --git a/lib/pure/collections/tableimpl.nim b/lib/pure/collections/tableimpl.nim
index 630af3970..3542741fa 100644
--- a/lib/pure/collections/tableimpl.nim
+++ b/lib/pure/collections/tableimpl.nim
@@ -45,7 +45,7 @@ template addImpl(enlarge) {.dirty.} =
   rawInsert(t, t.data, key, val, hc, j)
   inc(t.counter)
 
-template maybeRehashPutImpl(enlarge) {.dirty.} =
+template maybeRehashPutImpl(enlarge, val) {.dirty.} =
   checkIfInitialized()
   if mustRehash(t):
     enlarge(t)
@@ -59,7 +59,7 @@ template putImpl(enlarge) {.dirty.} =
   var hc: Hash = default(Hash)
   var index = rawGet(t, key, hc)
   if index >= 0: t.data[index].val = val
-  else: maybeRehashPutImpl(enlarge)
+  else: maybeRehashPutImpl(enlarge, val)
 
 template mgetOrPutImpl(enlarge) {.dirty.} =
   checkIfInitialized()
@@ -67,17 +67,30 @@ template mgetOrPutImpl(enlarge) {.dirty.} =
   var index = rawGet(t, key, hc)
   if index < 0:
     # not present: insert (flipping index)
-    maybeRehashPutImpl(enlarge)
+    when declared(val):
+      maybeRehashPutImpl(enlarge, val)
+    else:
+      maybeRehashPutImpl(enlarge, default(B))
   # either way return modifiable val
   result = t.data[index].val
 
+# template mgetOrPutDefaultImpl(enlarge) {.dirty.} =
+#   checkIfInitialized()
+#   var hc: Hash = default(Hash)
+#   var index = rawGet(t, key, hc)
+#   if index < 0:
+#     # not present: insert (flipping index)
+#     maybeRehashPutImpl(enlarge, default(B))
+#   # either way return modifiable val
+#   result = t.data[index].val
+
 template hasKeyOrPutImpl(enlarge) {.dirty.} =
   checkIfInitialized()
   var hc: Hash = default(Hash)
   var index = rawGet(t, key, hc)
   if index < 0:
     result = false
-    maybeRehashPutImpl(enlarge)
+    maybeRehashPutImpl(enlarge, val)
   else: result = true
 
 # delImplIdx is KnuthV3 Algo6.4R adapted to i=i+1 (from i=i-1) which has come to
diff --git a/lib/pure/collections/tables.nim b/lib/pure/collections/tables.nim
index 0a902ba84..8e4a3c35f 100644
--- a/lib/pure/collections/tables.nim
+++ b/lib/pure/collections/tables.nim
@@ -474,6 +474,18 @@ proc mgetOrPut*[A, B](t: var Table[A, B], key: A, val: B): var B =
 
   mgetOrPutImpl(enlarge)
 
+proc mgetOrPut*[A, B](t: var Table[A, B], key: A): var B =
+  ## Retrieves the value at `t[key]` or puts the
+  ## default initialization value for type `B` (e.g. 0 for any
+  ## integer type).
+  runnableExamples:
+    var a = {'a': 5}.newTable
+    doAssert a.mgetOrPut('a') == 5
+    a.mgetOrPut('z').inc
+    doAssert a == {'a': 5, 'z': 1}.newTable
+
+  mgetOrPutImpl(enlarge)
+
 proc len*[A, B](t: Table[A, B]): int =
   ## Returns the number of keys in `t`.
   runnableExamples:
@@ -1013,6 +1025,18 @@ proc mgetOrPut*[A, B](t: TableRef[A, B], key: A, val: B): var B =
     doAssert t[25] == @[25, 35]
   t[].mgetOrPut(key, val)
 
+proc mgetOrPut*[A, B](t: TableRef[A, B], key: A): var B =
+  ## Retrieves the value at `t[key]` or puts the
+  ## default initialization value for type `B` (e.g. 0 for any
+  ## integer type).
+  runnableExamples:
+    var a = {'a': 5}.newTable
+    doAssert a.mgetOrPut('a') == 5
+    a.mgetOrPut('z').inc
+    doAssert a == {'a': 5, 'z': 1}.newTable
+
+  t[].mgetOrPut(key)
+
 proc len*[A, B](t: TableRef[A, B]): int =
   ## Returns the number of keys in `t`.
   runnableExamples:
@@ -1503,6 +1527,18 @@ proc mgetOrPut*[A, B](t: var OrderedTable[A, B], key: A, val: B): var B =
 
   mgetOrPutImpl(enlarge)
 
+proc mgetOrPut*[A, B](t: var OrderedTable[A, B], key: A): var B =
+  ## Retrieves the value at `t[key]` or puts the
+  ## default initialization value for type `B` (e.g. 0 for any
+  ## integer type).
+  runnableExamples:
+    var a = {'a': 5}.toOrderedTable
+    doAssert a.mgetOrPut('a') == 5
+    a.mgetOrPut('z').inc
+    doAssert a == {'a': 5, 'z': 1}.toOrderedTable
+
+  mgetOrPutImpl(enlarge)
+
 proc len*[A, B](t: OrderedTable[A, B]): int {.inline.} =
   ## Returns the number of keys in `t`.
   runnableExamples:
@@ -1992,6 +2028,18 @@ proc mgetOrPut*[A, B](t: OrderedTableRef[A, B], key: A, val: B): var B =
 
   result = t[].mgetOrPut(key, val)
 
+proc mgetOrPut*[A, B](t: OrderedTableRef[A, B], key: A): var B =
+  ## Retrieves the value at `t[key]` or puts the
+  ## default initialization value for type `B` (e.g. 0 for any
+  ## integer type).
+  runnableExamples:
+    var a = {'a': 5}.toOrderedTable
+    doAssert a.mgetOrPut('a') == 5
+    a.mgetOrPut('z').inc
+    doAssert a == {'a': 5, 'z': 1}.toOrderedTable
+
+  t[].mgetOrPut(key)
+
 proc len*[A, B](t: OrderedTableRef[A, B]): int {.inline.} =
   ## Returns the number of keys in `t`.
   runnableExamples:
diff --git a/tests/stdlib/tmget.nim b/tests/stdlib/tmget.nim
index bf5e53560..f41963f02 100644
--- a/tests/stdlib/tmget.nim
+++ b/tests/stdlib/tmget.nim
@@ -3,15 +3,19 @@ discard """
   output: '''Can't access 6
 10
 11
+2
 Can't access 6
 10
 11
+2
 Can't access 6
 10
 11
+2
 Can't access 6
 10
 11
+2
 0
 10
 11
@@ -41,6 +45,9 @@ block:
   x[5] += 1
   var c = x[5]
   echo c
+  x.mgetOrPut(7).inc
+  x.mgetOrPut(7).inc
+  echo x[7]
 
 block:
   var x = newTable[int, int]()
@@ -53,6 +60,9 @@ block:
   x[5] += 1
   var c = x[5]
   echo c
+  x.mgetOrPut(7).inc
+  x.mgetOrPut(7).inc
+  echo x[7]
 
 block:
   var x = initOrderedTable[int, int]()
@@ -65,6 +75,9 @@ block:
   x[5] += 1
   var c = x[5]
   echo c
+  x.mgetOrPut(7).inc
+  x.mgetOrPut(7).inc
+  echo x[7]
 
 block:
   var x = newOrderedTable[int, int]()
@@ -77,6 +90,9 @@ block:
   x[5] += 1
   var c = x[5]
   echo c
+  x.mgetOrPut(7).inc
+  x.mgetOrPut(7).inc
+  echo x[7]
 
 block:
   var x = initCountTable[int]()