diff options
author | Imran Hendley <imran.hendley@gmail.com> | 2018-02-12 15:20:17 -0500 |
---|---|---|
committer | Andreas Rumpf <rumpf_a@web.de> | 2018-02-12 21:20:17 +0100 |
commit | df4f707743879a0ea4363fcef446d89d8b421513 (patch) | |
tree | 2dfb8e9389a36a80c76472fd67d017c54ab0a2db /lib | |
parent | 4c702d5ab22d47f24c6e7a3c6679f2d43136f2d2 (diff) | |
download | Nim-df4f707743879a0ea4363fcef446d89d8b421513.tar.gz |
add more functionality from sets to intsets (#7185)
* add more functionality from sets to intsets * remove -+- * < and == performance * don't hardcode s.a.len * remove shortcuts from < and ==
Diffstat (limited to 'lib')
-rw-r--r-- | lib/pure/collections/intsets.nim | 177 |
1 files changed, 156 insertions, 21 deletions
diff --git a/lib/pure/collections/intsets.nim b/lib/pure/collections/intsets.nim index 085232564..bfecfe447 100644 --- a/lib/pure/collections/intsets.nim +++ b/lib/pure/collections/intsets.nim @@ -108,6 +108,28 @@ proc contains*(s: IntSet, key: int): bool = else: result = false +iterator items*(s: IntSet): int {.inline.} = + ## iterates over any included element of `s`. + if s.elems <= s.a.len: + for i in 0..<s.elems: + yield s.a[i] + else: + var r = s.head + while r != nil: + var i = 0 + while i <= high(r.bits): + var w = r.bits[i] + # taking a copy of r.bits[i] here is correct, because + # modifying operations are not allowed during traversation + var j = 0 + while w != 0: # test all remaining bits for zero + if (w and 1) != 0: # the bit is set! + yield (r.key shl TrunkShift) or (i shl IntShift +% j) + inc(j) + w = w shr 1 + inc(i) + r = r.next + proc bitincl(s: var IntSet, key: int) {.inline.} = var t = intSetPut(s, `shr`(key, TrunkShift)) var u = key and TrunkMask @@ -131,6 +153,10 @@ proc incl*(s: var IntSet, key: int) = # fall through: bitincl(s, key) +proc incl*(s: var IntSet, other: IntSet) = + ## Includes all elements from `other` into `s`. + for item in other: incl(s, item) + proc exclImpl(s: var IntSet, key: int) = if s.elems <= s.a.len: for i in 0..<s.elems: @@ -149,6 +175,10 @@ proc excl*(s: var IntSet, key: int) = ## excludes `key` from the set `s`. exclImpl(s, key) +proc excl*(s: var IntSet, other: IntSet) = + ## Excludes all elements from `other` from `s`. + for item in other: excl(s, item) + proc missingOrExcl*(s: var IntSet, key: int) : bool = ## returns true if `s` does not contain `key`, otherwise ## `key` is removed from `s` and false is returned. @@ -232,27 +262,77 @@ proc assign*(dest: var IntSet, src: IntSet) = it = it.next -iterator items*(s: IntSet): int {.inline.} = - ## iterates over any included element of `s`. - if s.elems <= s.a.len: - for i in 0..<s.elems: - yield s.a[i] +proc union*(s1, s2: IntSet): IntSet = + ## Returns the union of the sets `s1` and `s2`. + result.assign(s1) + incl(result, s2) + +proc intersection*(s1, s2: IntSet): IntSet = + ## Returns the intersection of the sets `s1` and `s2`. + result = initIntSet() + for item in s1: + if contains(s2, item): + incl(result, item) + +proc difference*(s1, s2: IntSet): IntSet = + ## Returns the difference of the sets `s1` and `s2`. + result = initIntSet() + for item in s1: + if not contains(s2, item): + incl(result, item) + +proc symmetricDifference*(s1, s2: IntSet): IntSet = + ## Returns the symmetric difference of the sets `s1` and `s2`. + result.assign(s1) + for item in s2: + if containsOrIncl(result, item): excl(result, item) + +proc `+`*(s1, s2: IntSet): IntSet {.inline.} = + ## Alias for `union(s1, s2) <#union>`_. + result = union(s1, s2) + +proc `*`*(s1, s2: IntSet): IntSet {.inline.} = + ## Alias for `intersection(s1, s2) <#intersection>`_. + result = intersection(s1, s2) + +proc `-`*(s1, s2: IntSet): IntSet {.inline.} = + ## Alias for `difference(s1, s2) <#difference>`_. + result = difference(s1, s2) + +proc disjoint*(s1, s2: IntSet): bool = + ## Returns true iff the sets `s1` and `s2` have no items in common. + for item in s1: + if contains(s2, item): + return false + return true + +proc len*(s: IntSet): int {.inline.} = + ## Returns the number of keys in `s`. + if s.elems < s.a.len: + result = s.elems else: - var r = s.head - while r != nil: - var i = 0 - while i <= high(r.bits): - var w = r.bits[i] - # taking a copy of r.bits[i] here is correct, because - # modifying operations are not allowed during traversation - var j = 0 - while w != 0: # test all remaining bits for zero - if (w and 1) != 0: # the bit is set! - yield (r.key shl TrunkShift) or (i shl IntShift +% j) - inc(j) - w = w shr 1 - inc(i) - r = r.next + result = 0 + for _ in s: + inc(result) + +proc card*(s: IntSet): int {.inline.} = + ## alias for `len() <#len>` _. + result = s.len() + +proc `<=`*(s1, s2: IntSet): bool = + ## Returns true iff `s1` is subset of `s2`. + for item in s1: + if not s2.contains(item): + return false + return true + +proc `<`*(s1, s2: IntSet): bool = + ## Returns true iff `s1` is proper subset of `s2`. + return s1 <= s2 and not (s2 <= s1) + +proc `==`*(s1, s2: IntSet): bool = + ## Returns true if both `s` and `t` have the same members and set size. + return s1 <= s2 and s2 <= s1 template dollarImpl(): untyped = result = "{" @@ -301,9 +381,64 @@ when isMainModule: ys.sort(cmp[int]) assert ys == @[1, 2, 7, 1056] + assert x == y + var z: IntSet for i in 0..1000: incl z, i + assert z.len() == i+1 for i in 0..1000: - assert i in z + assert z.contains(i) + + var w = initIntSet() + w.incl(1) + w.incl(4) + w.incl(50) + w.incl(1001) + w.incl(1056) + + var xuw = x.union(w) + var xuws = toSeq(items(xuw)) + xuws.sort(cmp[int]) + assert xuws == @[1, 2, 4, 7, 50, 1001, 1056] + + var xiw = x.intersection(w) + var xiws = toSeq(items(xiw)) + xiws.sort(cmp[int]) + assert xiws == @[1, 1056] + + var xdw = x.difference(w) + var xdws = toSeq(items(xdw)) + xdws.sort(cmp[int]) + assert xdws == @[2, 7] + + var xsw = x.symmetricDifference(w) + var xsws = toSeq(items(xsw)) + xsws.sort(cmp[int]) + assert xsws == @[2, 4, 7, 50, 1001] + + x.incl(w) + xs = toSeq(items(x)) + xs.sort(cmp[int]) + assert xs == @[1, 2, 4, 7, 50, 1001, 1056] + + assert w <= x + + assert w < x + + assert(not disjoint(w, x)) + var u = initIntSet() + u.incl(3) + u.incl(5) + u.incl(500) + assert disjoint(u, x) + + var v = initIntSet() + v.incl(2) + v.incl(50) + + x.excl(v) + xs = toSeq(items(x)) + xs.sort(cmp[int]) + assert xs == @[1, 4, 7, 1001, 1056] |