summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--lib/pure/collections/intsets.nim22
-rw-r--r--lib/pure/collections/setimpl.nim6
-rw-r--r--lib/pure/collections/sharedtables.nim5
-rw-r--r--lib/pure/collections/tableimpl.nim22
-rw-r--r--lib/pure/collections/tables.nim14
-rw-r--r--tests/collections/ttables.nim45
-rw-r--r--tests/sets/tsets_various.nim72
-rw-r--r--tests/stdlib/tintsets.nim65
-rw-r--r--tests/stdlib/tsharedtable.nim86
9 files changed, 296 insertions, 41 deletions
diff --git a/lib/pure/collections/intsets.nim b/lib/pure/collections/intsets.nim
index 7ca379783..1967a0149 100644
--- a/lib/pure/collections/intsets.nim
+++ b/lib/pure/collections/intsets.nim
@@ -330,6 +330,15 @@ proc excl*(s: var IntSet, other: IntSet) =
 
   for item in other: excl(s, item)
 
+proc len*(s: IntSet): int {.inline.} =
+  ## Returns the number of elements in `s`.
+  if s.elems < s.a.len:
+    result = s.elems
+  else:
+    result = 0
+    for _ in s:
+      inc(result)
+
 proc missingOrExcl*(s: var IntSet, key: int): bool =
   ## Excludes `key` in the set `s` and tells if `key` was already missing from `s`.
   ##
@@ -348,9 +357,9 @@ proc missingOrExcl*(s: var IntSet, key: int): bool =
     assert a.missingOrExcl(5) == false
     assert a.missingOrExcl(5) == true
 
-  var count = s.elems
+  var count = s.len
   exclImpl(s, key)
-  result = count == s.elems
+  result = count == s.len
 
 proc clear*(result: var IntSet) =
   ## Clears the IntSet back to an empty state.
@@ -514,15 +523,6 @@ proc disjoint*(s1, s2: IntSet): bool =
       return false
   return true
 
-proc len*(s: IntSet): int {.inline.} =
-  ## Returns the number of elements in `s`.
-  if s.elems < s.a.len:
-    result = s.elems
-  else:
-    result = 0
-    for _ in s:
-      inc(result)
-
 proc card*(s: IntSet): int {.inline.} =
   ## Alias for `len() <#len,IntSet>`_.
   result = s.len()
diff --git a/lib/pure/collections/setimpl.nim b/lib/pure/collections/setimpl.nim
index f8950e354..f7a48ab91 100644
--- a/lib/pure/collections/setimpl.nim
+++ b/lib/pure/collections/setimpl.nim
@@ -38,7 +38,7 @@ proc enlarge[A](s: var HashSet[A]) =
   newSeq(n, len(s.data) * growthFactor)
   swap(s.data, n) # n is now old seq
   for i in countup(0, high(n)):
-    if isFilled(n[i].hcode):
+    if isFilledAndValid(n[i].hcode):
       var j = -1 - rawGetKnownHC(s, n[i].key, n[i].hcode)
       rawInsert(s, s.data, n[i].key, n[i].hcode, j)
 
@@ -112,7 +112,7 @@ proc enlarge[A](s: var OrderedSet[A]) =
   swap(s.data, n)
   while h >= 0:
     var nxt = n[h].next
-    if isFilled(n[h].hcode):
+    if isFilled(n[h].hcode): # should be isFilledAndValid once tombstones are used
       var j = -1 - rawGetKnownHC(s, n[h].key, n[h].hcode)
       rawInsert(s, s.data, n[h].key, n[h].hcode, j)
     h = nxt
@@ -130,7 +130,7 @@ proc exclImpl[A](s: var OrderedSet[A], key: A): bool {.inline.} =
   result = true
   while h >= 0:
     var nxt = n[h].next
-    if isFilled(n[h].hcode):
+    if isFilled(n[h].hcode): # should be isFilledAndValid once tombstones are used
       if n[h].hcode == hc and n[h].key == key:
         dec s.counter
         result = false
diff --git a/lib/pure/collections/sharedtables.nim b/lib/pure/collections/sharedtables.nim
index 0fbbdb3eb..27ac5e84f 100644
--- a/lib/pure/collections/sharedtables.nim
+++ b/lib/pure/collections/sharedtables.nim
@@ -206,6 +206,11 @@ proc del*[A, B](t: var SharedTable[A, B], key: A) =
   withLock t:
     delImpl()
 
+proc len*[A, B](t: var SharedTable[A, B]): int =
+  ## number of elements in `t`
+  withLock t:
+    result = t.counter
+
 proc init*[A, B](t: var SharedTable[A, B], initialSize = 64) =
   ## creates a new hash table that is empty.
   ##
diff --git a/lib/pure/collections/tableimpl.nim b/lib/pure/collections/tableimpl.nim
index aabaeeeb3..d7facda72 100644
--- a/lib/pure/collections/tableimpl.nim
+++ b/lib/pure/collections/tableimpl.nim
@@ -107,13 +107,23 @@ template clearImpl() {.dirty.} =
     t.data[i].val = default(type(t.data[i].val))
   t.counter = 0
 
+template ctAnd(a, b): bool =
+  # pending https://github.com/nim-lang/Nim/issues/13502
+  when a:
+    when b: true
+    else: false
+  else: false
+
 template initImpl(result: typed, size: int) =
-  assert isPowerOfTwo(size)
-  result.counter = 0
-  newSeq(result.data, size)
-  when compiles(result.first):
-    result.first = -1
-    result.last = -1
+  when ctAnd(declared(SharedTable), type(result) is SharedTable):
+    init(result, size)
+  else:
+    assert isPowerOfTwo(size)
+    result.counter = 0
+    newSeq(result.data, size)
+    when compiles(result.first):
+      result.first = -1
+      result.last = -1
 
 template insertImpl() = # for CountTable
   checkIfInitialized()
diff --git a/lib/pure/collections/tables.nim b/lib/pure/collections/tables.nim
index 2e3adc6fb..131804a22 100644
--- a/lib/pure/collections/tables.nim
+++ b/lib/pure/collections/tables.nim
@@ -1118,7 +1118,7 @@ iterator pairs*[A, B](t: TableRef[A, B]): (A, B) =
   ##   # value: [1, 5, 7, 9]
   let L = len(t)
   for h in 0 .. high(t.data):
-    if isFilled(t.data[h].hcode):
+    if isFilledAndValid(t.data[h].hcode):
       yield (t.data[h].key, t.data[h].val)
       assert(len(t) == L, "the length of the table changed while iterating over it")
 
@@ -1140,7 +1140,7 @@ iterator mpairs*[A, B](t: TableRef[A, B]): (A, var B) =
 
   let L = len(t)
   for h in 0 .. high(t.data):
-    if isFilled(t.data[h].hcode):
+    if isFilledAndValid(t.data[h].hcode):
       yield (t.data[h].key, t.data[h].val)
       assert(len(t) == L, "the length of the table changed while iterating over it")
 
@@ -1161,7 +1161,7 @@ iterator keys*[A, B](t: TableRef[A, B]): A =
 
   let L = len(t)
   for h in 0 .. high(t.data):
-    if isFilled(t.data[h].hcode):
+    if isFilledAndValid(t.data[h].hcode):
       yield t.data[h].key
       assert(len(t) == L, "the length of the table changed while iterating over it")
 
@@ -1182,7 +1182,7 @@ iterator values*[A, B](t: TableRef[A, B]): B =
 
   let L = len(t)
   for h in 0 .. high(t.data):
-    if isFilled(t.data[h].hcode):
+    if isFilledAndValid(t.data[h].hcode):
       yield t.data[h].val
       assert(len(t) == L, "the length of the table changed while iterating over it")
 
@@ -1203,7 +1203,7 @@ iterator mvalues*[A, B](t: TableRef[A, B]): var B =
 
   let L = len(t)
   for h in 0 .. high(t.data):
-    if isFilled(t.data[h].hcode):
+    if isFilledAndValid(t.data[h].hcode):
       yield t.data[h].val
       assert(len(t) == L, "the length of the table changed while iterating over it")
 
@@ -1282,6 +1282,10 @@ template forAllOrderedPairs(yieldStmt: untyped) {.dirty.} =
     var h = t.first
     while h >= 0:
       var nxt = t.data[h].next
+       # For OrderedTable/OrderedTableRef, isFilled is ok because `del` is O(n)
+       # and doesn't create tombsones, but if it does start using tombstones,
+       # carefully replace `isFilled` by `isFilledAndValid` as appropriate for these
+       # table types only, ditto with `OrderedSet`.
       if isFilled(t.data[h].hcode):
         yieldStmt
       h = nxt
diff --git a/tests/collections/ttables.nim b/tests/collections/ttables.nim
index bba95c1f1..9b7506d1a 100644
--- a/tests/collections/ttables.nim
+++ b/tests/collections/ttables.nim
@@ -173,9 +173,14 @@ block tableconstr:
 block ttables2:
   proc TestHashIntInt() =
     var tab = initTable[int,int]()
-    for i in 1..1_000_000:
+    when defined(nimTestsTablesDisableSlow):
+      # helps every single time when this test needs to be debugged
+      let n = 1_000
+    else:
+      let n = 1_000_000
+    for i in 1..n:
       tab[i] = i
-    for i in 1..1_000_000:
+    for i in 1..n:
       var x = tab[i]
       if x != i : echo "not found ", i
 
@@ -395,3 +400,39 @@ block tablesref:
   orderedTableSortTest()
   echo "3"
 
+
+block: # https://github.com/nim-lang/Nim/issues/13496
+  template testDel(body) =
+    block:
+      body
+      when t is CountTable|CountTableRef:
+        t.inc(15, 1)
+        t.inc(19, 2)
+        t.inc(17, 3)
+        t.inc(150, 4)
+        t.del(150)
+      else:
+        t[15] = 1
+        t[19] = 2
+        t[17] = 3
+        t[150] = 4
+        t.del(150)
+      doAssert t.len == 3
+      doAssert sortedItems(t.values) == @[1, 2, 3]
+      doAssert sortedItems(t.keys) == @[15, 17, 19]
+      doAssert sortedPairs(t) == @[(15, 1), (17, 3), (19, 2)]
+      var s = newSeq[int]()
+      for v in t.values: s.add(v)
+      assert s.len == 3
+      doAssert sortedItems(s) == @[1, 2, 3]
+      when t is OrderedTable|OrderedTableRef:
+        doAssert toSeq(t.keys) == @[15, 19, 17]
+        doAssert toSeq(t.values) == @[1,2,3]
+        doAssert toSeq(t.pairs) == @[(15, 1), (19, 2), (17, 3)]
+
+  testDel(): (var t: Table[int, int])
+  testDel(): (let t = newTable[int, int]())
+  testDel(): (var t: OrderedTable[int, int])
+  testDel(): (let t = newOrderedTable[int, int]())
+  testDel(): (var t: CountTable[int])
+  testDel(): (let t = newCountTable[int]())
diff --git a/tests/sets/tsets_various.nim b/tests/sets/tsets_various.nim
index 8a63763b4..c27d8e124 100644
--- a/tests/sets/tsets_various.nim
+++ b/tests/sets/tsets_various.nim
@@ -7,9 +7,14 @@ set is empty
 
 import sets, hashes
 
+from sequtils import toSeq
+from algorithm import sorted
+
+proc sortedPairs[T](t: T): auto = toSeq(t.pairs).sorted
+template sortedItems(t: untyped): untyped = sorted(toSeq(t))
 
 block tsetpop:
-  var a = initSet[int]()
+  var a = initHashSet[int]()
   for i in 1..1000:
     a.incl(i)
   doAssert len(a) == 1000
@@ -50,7 +55,7 @@ block tsets2:
       "80"]
 
   block tableTest1:
-    var t = initSet[tuple[x, y: int]]()
+    var t = initHashSet[tuple[x, y: int]]()
     t.incl((0,0))
     t.incl((1,0))
     assert(not t.containsOrIncl((0,1)))
@@ -63,7 +68,7 @@ block tsets2:
     #  "{(x: 0, y: 0), (x: 0, y: 1), (x: 1, y: 0), (x: 1, y: 1)}")
 
   block setTest2:
-    var t = initSet[string]()
+    var t = initHashSet[string]()
     t.incl("test")
     t.incl("111")
     t.incl("123")
@@ -102,9 +107,9 @@ block tsets2:
 
 block tsets3:
   let
-    s1: HashSet[int] = toSet([1, 2, 4, 8, 16])
-    s2: HashSet[int] = toSet([1, 2, 3, 5, 8])
-    s3: HashSet[int] = toSet([3, 5, 7])
+    s1: HashSet[int] = toHashSet([1, 2, 4, 8, 16])
+    s2: HashSet[int] = toHashSet([1, 2, 3, 5, 8])
+    s3: HashSet[int] = toHashSet([3, 5, 7])
 
   block union:
     let
@@ -172,7 +177,7 @@ block tsets3:
       assert i in s1_s3 xor i in s1
       assert i in s2_s3 xor i in s2
 
-    assert((s3 -+- s3) == initSet[int]())
+    assert((s3 -+- s3) == initHashSet[int]())
     assert((s3 -+- s1) == s1_s3)
 
   block difference:
@@ -191,10 +196,61 @@ block tsets3:
     for i in s2:
       assert i in s2_s3 xor i in s3
 
-    assert((s2 - s2) == initSet[int]())
+    assert((s2 - s2) == initHashSet[int]())
 
   block disjoint:
     assert(not disjoint(s1, s2))
     assert disjoint(s1, s3)
     assert(not disjoint(s2, s3))
     assert(not disjoint(s2, s2))
+
+block: # https://github.com/nim-lang/Nim/issues/13496
+  template testDel(body) =
+    block:
+      body
+      t.incl(15)
+      t.incl(19)
+      t.incl(17)
+      t.incl(150)
+      t.excl(150)
+      doAssert t.len == 3
+      doAssert sortedItems(t) == @[15, 17, 19]
+      var s = newSeq[int]()
+      for v in t: s.add(v)
+      assert s.len == 3
+      doAssert sortedItems(s) == @[15, 17, 19]
+      when t is OrderedSet:
+        doAssert sortedPairs(t) == @[(a: 0, b: 15), (a: 1, b: 19), (a: 2, b: 17)]
+        doAssert toSeq(t) == @[15, 19, 17]
+
+  testDel(): (var t: HashSet[int])
+  testDel(): (var t: OrderedSet[int])
+
+block: # test correctness after a number of inserts/deletes
+  template testDel(body) =
+    block:
+      body
+      var expected: seq[int]
+      let n = 100
+      let n2 = n*2
+      for i in 0..<n:
+        t.incl(i)
+      for i in 0..<n:
+        if i mod 3 == 0:
+          t.excl(i)
+      for i in n..<n2:
+        t.incl(i)
+      for i in 0..<n2:
+        if i mod 7 == 0:
+          t.excl(i)
+
+      for i in 0..<n2:
+        if (i>=n or i mod 3 != 0) and i mod 7 != 0:
+          expected.add i
+
+      for i in expected: doAssert i in t
+      doAssert t.len == expected.len
+      doAssert sortedItems(t) == expected
+
+  testDel(): (var t: HashSet[int])
+  testDel(): (var t: OrderedSet[int])
diff --git a/tests/stdlib/tintsets.nim b/tests/stdlib/tintsets.nim
new file mode 100644
index 000000000..f859b87ae
--- /dev/null
+++ b/tests/stdlib/tintsets.nim
@@ -0,0 +1,65 @@
+import intsets
+import std/sets
+
+from sequtils import toSeq
+from algorithm import sorted
+
+proc sortedPairs[T](t: T): auto = toSeq(t.pairs).sorted
+template sortedItems(t: untyped): untyped = sorted(toSeq(t))
+
+block: # we use HashSet as groundtruth, it's well tested elsewhere
+  template testDel(t, t0) =
+
+    template checkEquals() =
+      doAssert t.len == t0.len
+      for k in t0:
+        doAssert k in t
+      for k in t:
+        doAssert k in t0
+
+      doAssert sortedItems(t) == sortedItems(t0)
+
+    template incl2(i) =
+      t.incl i
+      t0.incl i
+
+    template excl2(i) =
+      t.excl i
+      t0.excl i
+
+    block:
+      var expected: seq[int]
+      let n = 100
+      let n2 = n*2
+      for i in 0..<n:
+        incl2(i)
+      checkEquals()
+      for i in 0..<n:
+        if i mod 3 == 0:
+          if i < n div 2:
+            excl2(i)
+          else:
+            t0.excl i
+            doAssert i in t
+            doAssert not t.missingOrExcl(i)
+
+      checkEquals()
+      for i in n..<n2:
+        incl2(i)
+      checkEquals()
+      for i in 0..<n2:
+        if i mod 7 == 0:
+          excl2(i)
+      checkEquals()
+
+      # notin check
+      for i in 0..<t.len:
+        if i mod 7 == 0:
+          doAssert i notin t0
+          doAssert i notin t
+          # issue #13505
+          doAssert t.missingOrExcl(i)
+
+  var t: IntSet
+  var t0: HashSet[int]
+  testDel(t, t0)
diff --git a/tests/stdlib/tsharedtable.nim b/tests/stdlib/tsharedtable.nim
index 99d20e08a..ce6aa96df 100644
--- a/tests/stdlib/tsharedtable.nim
+++ b/tests/stdlib/tsharedtable.nim
@@ -6,10 +6,84 @@ output: '''
 
 import sharedtables
 
-var table: SharedTable[int, int]
+block:
+  var table: SharedTable[int, int]
 
-init(table)
-table[1] = 10
-assert table.mget(1) == 10
-assert table.mgetOrPut(3, 7) == 7
-assert table.mgetOrPut(3, 99) == 7
+  init(table)
+  table[1] = 10
+  assert table.mget(1) == 10
+  assert table.mgetOrPut(3, 7) == 7
+  assert table.mgetOrPut(3, 99) == 7
+  deinitSharedTable(table)
+
+import sequtils, algorithm
+proc sortedPairs[T](t: T): auto = toSeq(t.pairs).sorted
+template sortedItems(t: untyped): untyped = sorted(toSeq(t))
+
+import tables # refs issue #13504
+
+block: # we use Table as groundtruth, it's well tested elsewhere
+  template testDel(t, t0) =
+    template put2(i) =
+      t[i] = i
+      t0[i] = i
+
+    template add2(i, val) =
+      t.add(i, val)
+      t0.add(i, val)
+
+    template del2(i) =
+      t.del(i)
+      t0.del(i)
+
+    template checkEquals() =
+      doAssert t.len == t0.len
+      for k,v in t0:
+        doAssert t.mgetOrPut(k, -1) == v # sanity check
+        doAssert t.mget(k) == v
+
+    let n = 100
+    let n2 = n*2
+    let n3 = n*3
+    let n4 = n*4
+    let n5 = n*5
+
+    for i in 0..<n:
+      put2(i)
+    for i in 0..<n:
+      if i mod 3 == 0:
+        del2(i)
+    for i in n..<n2:
+      put2(i)
+    for i in 0..<n2:
+      if i mod 7 == 0:
+        del2(i)
+
+    checkEquals()
+
+    for i in n2..<n3:
+      t0[i] = -2
+      doAssert t.mgetOrPut(i, -2) == -2
+      doAssert t.mget(i) == -2
+
+    for i in 0..<n4:
+      let ok = i in t0
+      if not ok: t0[i] = -i
+      doAssert t.hasKeyOrPut(i, -i) == ok
+
+    checkEquals()
+
+    for i in n4..<n5:
+      add2(i, i*10)
+      add2(i, i*11)
+      add2(i, i*12)
+      del2(i)
+      del2(i)
+
+    checkEquals()
+
+  var t: SharedTable[int, int]
+  init(t) # ideally should be auto-init
+  var t0: Table[int, int]
+  testDel(t, t0)
+  deinitSharedTable(t)