summary refs log tree commit diff stats
path: root/lib/pure/collections/intsets.nim
blob: 3fd81acf6e67012626fd7291979e228151585fd5 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#
#
#            Nimrod's Runtime Library
#        (c) Copyright 2011 Andreas Rumpf
#
#    See the file "copying.txt", included in this
#    distribution, for details about the copyright.
#

## The ``intsets`` module provides an efficient set of ``int`` values that is
## represented as a sparse bit set.
## **Note**: Since Nimrod currently does not allow the assignment operator to
## be overloaded, ``=`` for int sets only performs a shallow copy, which is
## rarely meaningful.

import
  os, hashes, math

type
  # Integer type of one cell of a trunk's bit vector; the constants
  # IntsPerTrunk and IntShift are derived from its size.
  TBitScalar = int

const 
  # Number of hash table slots a fresh set starts with. The table size minus
  # one is used as a bit mask for slot indices, hence the power-of-two rule.
  InitIntSetSize = 8         # must be a power of two!
  # An element's trunk key is the element shifted right by this many bits.
  TrunkShift = 9
  BitsPerTrunk = 1 shl TrunkShift # needs to be a power of 2 and
                                  # divisible by 64
  # Mask that extracts an element's bit position inside its trunk.
  TrunkMask = BitsPerTrunk - 1
  # Number of TBitScalar cells needed to hold BitsPerTrunk bits.
  IntsPerTrunk = BitsPerTrunk div (sizeof(TBitScalar) * 8)
  # Shift that turns a bit position inside a trunk into a cell index.
  IntShift = 5 + ord(sizeof(TBitScalar) == 8) # 5 or 6, depending on int width
  # Mask that extracts the bit position inside a single cell;
  # evaluates as (1 shl IntShift) - 1.
  IntMask = 1 shl IntShift - 1

type
  PTrunk = ref TTrunk
  # A trunk stores the membership bits for one block of BitsPerTrunk
  # consecutive integers.
  TTrunk {.final.} = object 
    # Link to the trunk that was created before this one; `items` follows
    # these links starting at the set's `head`.
    next: PTrunk             # all nodes are connected with this pointer
    # The trunk key, i.e. an element value shifted right by TrunkShift.
    key: int                 # start address at bit 0
    bits: array[0..IntsPerTrunk - 1, TBitScalar] # a bit vector
  
  # Slot array of the hash table that maps trunk keys to trunks.
  TTrunkSeq = seq[PTrunk]
  TIntSet* {.final.} = object ## an efficient set of 'int' implemented as a
                              ## sparse bit set
    # counter: number of trunks stored in `data`.
    # max: number of slots in `data` minus one; used as the slot index mask.
    counter, max: int
    # Most recently created trunk; start of the `next` chain.
    head: PTrunk
    # Hash table of trunks; collisions are resolved by probing with nextTry.
    data: TTrunkSeq

proc mustRehash(length, counter: int): bool {.inline.} =
  ## Decides whether a table with `length` slots of which `counter` are in
  ## use has to grow: it must when more than two thirds of the slots are
  ## taken or when fewer than four slots are still free.
  assert(length > counter)
  if length * 2 < counter * 3:
    result = true
  else:
    result = length - counter < 4

proc nextTry(h, maxHash: THash): THash {.inline.} =
  ## Computes the slot to probe after slot `h` collided. `maxHash` is the
  ## table size minus one and is applied as a mask to keep the result in
  ## range.
  result = (h * 5 + 1) and maxHash

proc IntSetGet(t: TIntSet, key: int): PTrunk =
  ## Looks up the trunk stored under the trunk key `key`. Returns nil when
  ## the probe sequence hits an empty slot before finding a match.
  result = nil
  var slot = key and t.max
  while true:
    var candidate = t.data[slot]
    if candidate == nil: break
    if candidate.key == key:
      result = candidate
      break
    slot = nextTry(slot, t.max)

proc IntSetRawInsert(t: TIntSet, data: var TTrunkSeq, desc: PTrunk) =
  ## Stores `desc` in the first free slot of `data` along the probe sequence
  ## for `desc.key`. The slot mask is taken from `t.max`; `t.data` itself is
  ## not touched. `desc` must not already be present in `data`.
  var slot = desc.key and t.max
  while true:
    if data[slot] == nil: break
    assert(data[slot] != desc)
    slot = nextTry(slot, t.max)
  assert(data[slot] == nil)
  data[slot] = desc

proc IntSetEnlarge(t: var TIntSet) =
  ## Doubles the number of slots of `t`'s hash table and re-inserts every
  ## stored trunk into the new slot array.
  var previousMax = t.max
  var grown: TTrunkSeq
  # The mask has to be updated first: IntSetRawInsert reads it from `t`.
  t.max = 2 * (previousMax + 1) - 1
  newSeq(grown, t.max + 1)
  var i = 0
  while i <= previousMax:
    var trunk = t.data[i]
    if trunk != nil:
      IntSetRawInsert(t, grown, trunk)
    inc(i)
  swap(t.data, grown)

proc IntSetPut(t: var TIntSet, key: int): PTrunk =
  ## Returns the trunk stored under the trunk key `key`. If there is none
  ## yet, a new trunk is created, linked in front of `t.head`, stored in the
  ## hash table (which is enlarged first when necessary) and returned.
  result = IntSetGet(t, key)
  if result != nil: return
  if mustRehash(t.max + 1, t.counter):
    IntSetEnlarge(t)
  inc(t.counter)
  # Probe for a free slot; the table may just have been resized.
  var slot = key and t.max
  while t.data[slot] != nil:
    slot = nextTry(slot, t.max)
  assert(t.data[slot] == nil)
  new(result)
  result.key = key
  result.next = t.head
  t.head = result
  t.data[slot] = result

proc contains*(s: TIntSet, key: int): bool =
  ## returns true iff `key` is in `s`.
  result = false
  var trunk = IntSetGet(s, key shr TrunkShift)
  if trunk != nil:
    # position of the key's bit inside its trunk
    var bitPos = key and TrunkMask
    var cell = trunk.bits[bitPos shr IntShift]
    result = (cell and (1 shl (bitPos and IntMask))) != 0
  
proc incl*(s: var TIntSet, key: int) =
  ## includes an element `key` in `s`.
  var trunk = IntSetPut(s, key shr TrunkShift)
  # position of the key's bit inside its trunk
  var bitPos = key and TrunkMask
  var cellIdx = bitPos shr IntShift
  var mask = 1 shl (bitPos and IntMask)
  trunk.bits[cellIdx] = trunk.bits[cellIdx] or mask

proc excl*(s: var TIntSet, key: int) =
  ## excludes `key` from the set `s`. Nothing happens if no trunk exists
  ## for `key`.
  var trunk = IntSetGet(s, key shr TrunkShift)
  if trunk == nil: return
  # position of the key's bit inside its trunk
  var bitPos = key and TrunkMask
  var cellIdx = bitPos shr IntShift
  var mask = 1 shl (bitPos and IntMask)
  trunk.bits[cellIdx] = trunk.bits[cellIdx] and not mask

proc containsOrIncl*(s: var TIntSet, key: int): bool =
  ## returns true if `s` contains `key`, otherwise `key` is included in `s`
  ## and false is returned.
  var trunk = IntSetGet(s, key shr TrunkShift)
  if trunk == nil:
    # no trunk for this key yet, so the key cannot be in the set
    incl(s, key)
    return false
  var bitPos = key and TrunkMask
  var cellIdx = bitPos shr IntShift
  var mask = 1 shl (bitPos and IntMask)
  result = (trunk.bits[cellIdx] and mask) != 0
  if not result:
    trunk.bits[cellIdx] = trunk.bits[cellIdx] or mask
    
proc initIntSet*: TIntSet =
  ## creates a new int set that is empty.
  result.head = nil
  result.counter = 0
  # `max` is the slot count minus one
  result.max = InitIntSetSize - 1
  newSeq(result.data, InitIntSetSize)

iterator items*(s: TIntSet): int {.inline.} =
  ## iterates over any included element of `s`. Trunks are visited by
  ## following the `next` links starting at `s.head`; inside a trunk the set
  ## bits are reported from the lowest to the highest position. The set must
  ## not be modified during the traversal.
  var r = s.head
  while r != nil:
    var i = 0
    while i <= high(r.bits):
      var w = r.bits[i]
      # working on a copy of r.bits[i] is correct, because modifying
      # operations are not allowed during traversal
      var j = 0
      while w != 0:         # stop as soon as no set bits remain
        if (w and 1) != 0:  # the bit is set!
          yield (r.key shl TrunkShift) or (i shl IntShift +% j)
        inc(j)
        # Clear the top bit after shifting: this is a no-op for a logical
        # shift, but guarantees that `w` reaches 0 even if `shr` keeps the
        # sign of a cell whose highest bit is set.
        w = (w shr 1) and high(TBitScalar)
      inc(i)
    r = r.next

# Body of `$`; expects a parameter `s` and the implicit `result: string` to
# be in scope where it is expanded. It has to come after `items`, which it
# calls.
template dollarImpl(): stmt =
  result = "{"
  for key in items(s):
    if result.len > 1: result.add(", ")
    result.add($key)
  result.add("}")

proc `$`*(s: TIntSet): string =
  ## The `$` operator for int sets. The elements are listed in the order
  ## `items` yields them, separated by ", " and enclosed in curly braces.
  dollarImpl()

when isMainModule:
  # Small smoke test: fill a set and print what the iterator yields.
  var x = initIntSet()
  for v in [1, 2, 7, 1056]:
    x.incl(v)
  for e in items(x):
    echo e