summary refs log tree commit diff stats
path: root/lib
diff options
context:
space:
mode:
authorImran Hendley <imran.hendley@gmail.com>2018-02-12 15:20:17 -0500
committerAndreas Rumpf <rumpf_a@web.de>2018-02-12 21:20:17 +0100
commitdf4f707743879a0ea4363fcef446d89d8b421513 (patch)
tree2dfb8e9389a36a80c76472fd67d017c54ab0a2db /lib
parent4c702d5ab22d47f24c6e7a3c6679f2d43136f2d2 (diff)
downloadNim-df4f707743879a0ea4363fcef446d89d8b421513.tar.gz
add more functionality from sets to intsets (#7185)
* add more functionality from sets to intsets

* remove -+-

* < and == performance

* don't hardcode s.a.len

* remove shortcuts from < and ==
Diffstat (limited to 'lib')
-rw-r--r--lib/pure/collections/intsets.nim177
1 files changed, 156 insertions, 21 deletions
diff --git a/lib/pure/collections/intsets.nim b/lib/pure/collections/intsets.nim
index 085232564..bfecfe447 100644
--- a/lib/pure/collections/intsets.nim
+++ b/lib/pure/collections/intsets.nim
@@ -108,6 +108,28 @@ proc contains*(s: IntSet, key: int): bool =
     else:
       result = false
 
+iterator items*(s: IntSet): int {.inline.} =
+  ## iterates over any included element of `s`.
+  if s.elems <= s.a.len:
+    for i in 0..<s.elems:
+      yield s.a[i]
+  else:
+    var r = s.head
+    while r != nil:
+      var i = 0
+      while i <= high(r.bits):
+        var w = r.bits[i]
+        # taking a copy of r.bits[i] here is correct, because
+        # modifying operations are not allowed during traversation
+        var j = 0
+        while w != 0:         # test all remaining bits for zero
+          if (w and 1) != 0:  # the bit is set!
+            yield (r.key shl TrunkShift) or (i shl IntShift +% j)
+          inc(j)
+          w = w shr 1
+        inc(i)
+      r = r.next
+
 proc bitincl(s: var IntSet, key: int) {.inline.} =
   var t = intSetPut(s, `shr`(key, TrunkShift))
   var u = key and TrunkMask
@@ -131,6 +153,10 @@ proc incl*(s: var IntSet, key: int) =
     # fall through:
   bitincl(s, key)
 
+proc incl*(s: var IntSet, other: IntSet) =
+  ## Includes all elements from `other` into `s`.
+  for item in other: incl(s, item)
+
 proc exclImpl(s: var IntSet, key: int) =
   if s.elems <= s.a.len:
     for i in 0..<s.elems:
@@ -149,6 +175,10 @@ proc excl*(s: var IntSet, key: int) =
   ## excludes `key` from the set `s`.
   exclImpl(s, key)
 
+proc excl*(s: var IntSet, other: IntSet) =
+  ## Excludes all elements from `other` from `s`.
+  for item in other: excl(s, item)
+
 proc missingOrExcl*(s: var IntSet, key: int) : bool =
   ## returns true if `s` does not contain `key`, otherwise
   ## `key` is removed from `s` and false is returned.
@@ -232,27 +262,77 @@ proc assign*(dest: var IntSet, src: IntSet) =
 
       it = it.next
 
-iterator items*(s: IntSet): int {.inline.} =
-  ## iterates over any included element of `s`.
-  if s.elems <= s.a.len:
-    for i in 0..<s.elems:
-      yield s.a[i]
+proc union*(s1, s2: IntSet): IntSet =
+  ## Returns the union of the sets `s1` and `s2`.
+  result.assign(s1)
+  incl(result, s2)
+
+proc intersection*(s1, s2: IntSet): IntSet =
+  ## Returns the intersection of the sets `s1` and `s2`.
+  result = initIntSet()
+  for item in s1:
+    if contains(s2, item):
+      incl(result, item)
+
+proc difference*(s1, s2: IntSet): IntSet =
+  ## Returns the difference of the sets `s1` and `s2`.
+  result = initIntSet()
+  for item in s1:
+    if not contains(s2, item):
+      incl(result, item)
+
+proc symmetricDifference*(s1, s2: IntSet): IntSet =
+  ## Returns the symmetric difference of the sets `s1` and `s2`.
+  result.assign(s1)
+  for item in s2:
+    if containsOrIncl(result, item): excl(result, item)
+
+proc `+`*(s1, s2: IntSet): IntSet {.inline.} =
+  ## Alias for `union(s1, s2) <#union>`_.
+  result = union(s1, s2)
+
+proc `*`*(s1, s2: IntSet): IntSet {.inline.} =
+  ## Alias for `intersection(s1, s2) <#intersection>`_.
+  result = intersection(s1, s2)
+
+proc `-`*(s1, s2: IntSet): IntSet {.inline.} =
+  ## Alias for `difference(s1, s2) <#difference>`_.
+  result = difference(s1, s2)
+
+proc disjoint*(s1, s2: IntSet): bool =
+  ## Returns true iff the sets `s1` and `s2` have no items in common.
+  for item in s1:
+    if contains(s2, item):
+      return false
+  return true
+
+proc len*(s: IntSet): int {.inline.} =
+  ## Returns the number of keys in `s`.
+  if s.elems < s.a.len:
+    result = s.elems
   else:
-    var r = s.head
-    while r != nil:
-      var i = 0
-      while i <= high(r.bits):
-        var w = r.bits[i]
-        # taking a copy of r.bits[i] here is correct, because
-        # modifying operations are not allowed during traversation
-        var j = 0
-        while w != 0:         # test all remaining bits for zero
-          if (w and 1) != 0:  # the bit is set!
-            yield (r.key shl TrunkShift) or (i shl IntShift +% j)
-          inc(j)
-          w = w shr 1
-        inc(i)
-      r = r.next
+    result = 0
+    for _ in s:
+      inc(result)
+
+proc card*(s: IntSet): int {.inline.} = 
+  ## alias for `len() <#len>` _.
+  result = s.len()
+
+proc `<=`*(s1, s2: IntSet): bool =
+  ## Returns true iff `s1` is subset of `s2`.
+  for item in s1:
+    if not s2.contains(item):
+      return false
+  return true
+
+proc `<`*(s1, s2: IntSet): bool =
+  ## Returns true iff `s1` is proper subset of `s2`.
+  return s1 <= s2 and not (s2 <= s1)
+
+proc `==`*(s1, s2: IntSet): bool =
+  ## Returns true if both `s` and `t` have the same members and set size.
+  return s1 <= s2 and s2 <= s1
 
 template dollarImpl(): untyped =
   result = "{"
@@ -301,9 +381,64 @@ when isMainModule:
   ys.sort(cmp[int])
   assert ys == @[1, 2, 7, 1056]
 
+  assert x == y
+
   var z: IntSet
   for i in 0..1000:
     incl z, i
+    assert z.len() == i+1
   for i in 0..1000:
-    assert i in z
+    assert z.contains(i)
+
+  var w = initIntSet()
+  w.incl(1)
+  w.incl(4)
+  w.incl(50)
+  w.incl(1001)
+  w.incl(1056)
+
+  var xuw = x.union(w)
+  var xuws = toSeq(items(xuw))
+  xuws.sort(cmp[int])
+  assert xuws == @[1, 2, 4, 7, 50, 1001, 1056]
+
+  var xiw = x.intersection(w)
+  var xiws = toSeq(items(xiw))
+  xiws.sort(cmp[int])
+  assert xiws == @[1, 1056]
+
+  var xdw = x.difference(w)
+  var xdws = toSeq(items(xdw))
+  xdws.sort(cmp[int])
+  assert xdws == @[2, 7]
+
+  var xsw = x.symmetricDifference(w)
+  var xsws = toSeq(items(xsw))
+  xsws.sort(cmp[int])
+  assert xsws == @[2, 4, 7, 50, 1001]
+
+  x.incl(w)
+  xs = toSeq(items(x))
+  xs.sort(cmp[int])
+  assert xs == @[1, 2, 4, 7, 50, 1001, 1056]
+
+  assert w <= x
+
+  assert w < x
+
+  assert(not disjoint(w, x))
 
+  var u = initIntSet()
+  u.incl(3)
+  u.incl(5)
+  u.incl(500)
+  assert disjoint(u, x)
+
+  var v = initIntSet()
+  v.incl(2)
+  v.incl(50)
+
+  x.excl(v)
+  xs = toSeq(items(x))
+  xs.sort(cmp[int])
+  assert xs == @[1, 4, 7, 1001, 1056]