summary refs log tree commit diff stats
path: root/lib/pure/collections/tables.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/collections/tables.nim')
-rw-r--r--lib/pure/collections/tables.nim72
1 files changed, 68 insertions, 4 deletions
diff --git a/lib/pure/collections/tables.nim b/lib/pure/collections/tables.nim
index 5c4ac0401..a9357ce67 100644
--- a/lib/pure/collections/tables.nim
+++ b/lib/pure/collections/tables.nim
@@ -819,15 +819,18 @@ proc enlarge[A](t: var CountTable[A]) =
   swap(t.data, n)
 
 proc `[]=`*[A](t: var CountTable[A], key: A, val: int) =
-  ## puts a (key, value)-pair into `t`. `val` has to be positive.
+  ## puts a (key, value)-pair into `t`.
   assert val > 0
   var h = rawGet(t, key)
   if h >= 0:
     t.data[h].val = val
   else:
-    h = -1 - h
-    t.data[h].key = key
-    t.data[h].val = val
+    if mustRehash(len(t.data), t.counter): enlarge(t)
+    rawInsert(t, t.data, key, val)
+    inc(t.counter)
+    #h = -1 - h
+    #t.data[h].key = key
+    #t.data[h].val = val
 
 proc initCountTable*[A](initialSize=64): CountTable[A] =
   ## creates a new count table that is empty.
@@ -984,6 +987,22 @@ proc sort*[A](t: CountTableRef[A]) =
   ## `t` in the sorted order.
   t[].sort
 
+proc merge*[A](s: var CountTable[A], t: CountTable[A]) =
+  ## merges the second table into the first one
+  for key, value in t:
+    s.inc(key, value)
+
+proc merge*[A](s, t: CountTable[A]): CountTable[A] =
+  ## merges the two tables into a new one
+  result = initCountTable[A](nextPowerOfTwo(max(s.len, t.len)))
+  for table in @[s, t]:
+    for key, value in table:
+      result.inc(key, value)
+
+proc merge*[A](s, t: CountTableRef[A]) =
+  ## merges the second table into the first one
+  s[].merge(t[])
+
 when isMainModule:
   type
     Person = object
@@ -1012,3 +1031,48 @@ when isMainModule:
   s2[p2] = 45_000
   s3[p1] = 30_000
   s3[p2] = 45_000
+
+  var
+    t1 = initCountTable[string]()
+    t2 = initCountTable[string]()
+  t1.inc("foo")
+  t1.inc("bar", 2)
+  t1.inc("baz", 3)
+  t2.inc("foo", 4)
+  t2.inc("bar")
+  t2.inc("baz", 11)
+  merge(t1, t2)
+  assert(t1["foo"] == 5)
+  assert(t1["bar"] == 3)
+  assert(t1["baz"] == 14)
+
+  let
+    t1r = newCountTable[string]()
+    t2r = newCountTable[string]()
+  t1r.inc("foo")
+  t1r.inc("bar", 2)
+  t1r.inc("baz", 3)
+  t2r.inc("foo", 4)
+  t2r.inc("bar")
+  t2r.inc("baz", 11)
+  merge(t1r, t2r)
+  assert(t1r["foo"] == 5)
+  assert(t1r["bar"] == 3)
+  assert(t1r["baz"] == 14)
+
+  var
+    t1l = initCountTable[string]()
+    t2l = initCountTable[string]()
+  t1l.inc("foo")
+  t1l.inc("bar", 2)
+  t1l.inc("baz", 3)
+  t2l.inc("foo", 4)
+  t2l.inc("bar")
+  t2l.inc("baz", 11)
+  let
+    t1merging = t1l
+    t2merging = t2l
+  let merged = merge(t1merging, t2merging)
+  assert(merged["foo"] == 5)
+  assert(merged["bar"] == 3)
+  assert(merged["baz"] == 14)