summary refs log tree commit diff stats
path: root/rod/nimsets.nim
diff options
context:
space:
mode:
Diffstat (limited to 'rod/nimsets.nim')
-rwxr-xr-xrod/nimsets.nim175
1 files changed, 175 insertions, 0 deletions
diff --git a/rod/nimsets.nim b/rod/nimsets.nim
new file mode 100755
index 000000000..337aedda9
--- /dev/null
+++ b/rod/nimsets.nim
@@ -0,0 +1,175 @@
+#
+#
+#           The Nimrod Compiler
+#        (c) Copyright 2009 Andreas Rumpf
+#
+#    See the file "copying.txt", included in this
+#    distribution, for details about the copyright.
+#
+
+# this unit handles Nimrod sets; it implements symbolic sets
+
+import 
+  ast, astalgo, trees, nversion, msgs, platform, bitsets, types, rnimsyn
+
+proc toBitSet*(s: PNode, b: var TBitSet)
+  # this function is used for case statement checking:
+proc overlap*(a, b: PNode): bool
+proc inSet*(s: PNode, elem: PNode): bool
+proc someInSet*(s: PNode, a, b: PNode): bool
+proc emptyRange*(a, b: PNode): bool
+proc SetHasRange*(s: PNode): bool
+  # returns true if set contains a range (needed by the code generator)
+  # these are used for constant folding:
+proc unionSets*(a, b: PNode): PNode
+proc diffSets*(a, b: PNode): PNode
+proc intersectSets*(a, b: PNode): PNode
+proc symdiffSets*(a, b: PNode): PNode
+proc containsSets*(a, b: PNode): bool
+proc equalSets*(a, b: PNode): bool
+proc cardSet*(s: PNode): BiggestInt
+# implementation
+
+proc inSet(s: PNode, elem: PNode): bool = 
+  if s.kind != nkCurly: InternalError(s.info, "inSet")
+  for i in countup(0, sonsLen(s) - 1): 
+    if s.sons[i].kind == nkRange: 
+      if leValue(s.sons[i].sons[0], elem) and
+          leValue(elem, s.sons[i].sons[1]): 
+        return true
+    else: 
+      if sameValue(s.sons[i], elem): 
+        return true
+  result = false
+
+proc overlap(a, b: PNode): bool = 
+  if a.kind == nkRange: 
+    if b.kind == nkRange: 
+      result = leValue(a.sons[0], b.sons[1]) and
+          leValue(b.sons[1], a.sons[1]) or
+          leValue(a.sons[0], b.sons[0]) and leValue(b.sons[0], a.sons[1])
+    else: 
+      result = leValue(a.sons[0], b) and leValue(b, a.sons[1])
+  else: 
+    if b.kind == nkRange: 
+      result = leValue(b.sons[0], a) and leValue(a, b.sons[1])
+    else: 
+      result = sameValue(a, b)
+
+proc SomeInSet(s: PNode, a, b: PNode): bool = 
+  # checks if some element of a..b is in the set s
+  if s.kind != nkCurly: InternalError(s.info, "SomeInSet")
+  for i in countup(0, sonsLen(s) - 1): 
+    if s.sons[i].kind == nkRange: 
+      if leValue(s.sons[i].sons[0], b) and leValue(b, s.sons[i].sons[1]) or
+          leValue(s.sons[i].sons[0], a) and leValue(a, s.sons[i].sons[1]): 
+        return true
+    else: 
+      # a <= elem <= b
+      if leValue(a, s.sons[i]) and leValue(s.sons[i], b): 
+        return true
+  result = false
+
+proc toBitSet(s: PNode, b: var TBitSet) = 
+  var first, j: BiggestInt
+  first = firstOrd(s.typ.sons[0])
+  bitSetInit(b, int(getSize(s.typ)))
+  for i in countup(0, sonsLen(s) - 1): 
+    if s.sons[i].kind == nkRange: 
+      j = getOrdValue(s.sons[i].sons[0])
+      while j <= getOrdValue(s.sons[i].sons[1]): 
+        BitSetIncl(b, j - first)
+        inc(j)
+    else: 
+      BitSetIncl(b, getOrdValue(s.sons[i]) - first)
+  
+proc ToTreeSet(s: TBitSet, settype: PType, info: TLineInfo): PNode = 
+  var 
+    a, b, e, first: BiggestInt # a, b are interval borders
+    elemType: PType
+    n: PNode
+  elemType = settype.sons[0]
+  first = firstOrd(elemType)
+  result = newNodeI(nkCurly, info)
+  result.typ = settype
+  result.info = info
+  e = 0
+  while e < high(s) * elemSize: 
+    if bitSetIn(s, e): 
+      a = e
+      b = e
+      while true: 
+        Inc(b)
+        if (b > high(s) * elemSize) or not bitSetIn(s, b): break 
+      Dec(b)
+      if a == b: 
+        addSon(result, newIntTypeNode(nkIntLit, a + first, elemType))
+      else: 
+        n = newNodeI(nkRange, info)
+        n.typ = elemType
+        addSon(n, newIntTypeNode(nkIntLit, a + first, elemType))
+        addSon(n, newIntTypeNode(nkIntLit, b + first, elemType))
+        addSon(result, n)
+      e = b
+    Inc(e)
+
+type 
+  TSetOP = enum 
+    soUnion, soDiff, soSymDiff, soIntersect
+
+proc nodeSetOp(a, b: PNode, op: TSetOp): PNode = 
+  var x, y: TBitSet
+  toBitSet(a, x)
+  toBitSet(b, y)
+  case op
+  of soUnion: BitSetUnion(x, y)
+  of soDiff: BitSetDiff(x, y)
+  of soSymDiff: BitSetSymDiff(x, y)
+  of soIntersect: BitSetIntersect(x, y)
+  result = toTreeSet(x, a.typ, a.info)
+
+proc unionSets(a, b: PNode): PNode = 
+  result = nodeSetOp(a, b, soUnion)
+
+proc diffSets(a, b: PNode): PNode = 
+  result = nodeSetOp(a, b, soDiff)
+
+proc intersectSets(a, b: PNode): PNode = 
+  result = nodeSetOp(a, b, soIntersect)
+
+proc symdiffSets(a, b: PNode): PNode = 
+  result = nodeSetOp(a, b, soSymDiff)
+
+proc containsSets(a, b: PNode): bool = 
+  var x, y: TBitSet
+  toBitSet(a, x)
+  toBitSet(b, y)
+  result = bitSetContains(x, y)
+
+proc equalSets(a, b: PNode): bool = 
+  var x, y: TBitSet
+  toBitSet(a, x)
+  toBitSet(b, y)
+  result = bitSetEquals(x, y)
+
+proc cardSet(s: PNode): BiggestInt = 
+  # here we can do better than converting it into a compact set
+  # we just count the elements directly
+  result = 0
+  for i in countup(0, sonsLen(s) - 1): 
+    if s.sons[i].kind == nkRange: 
+      result = result + getOrdValue(s.sons[i].sons[1]) -
+          getOrdValue(s.sons[i].sons[0]) + 1
+    else: 
+      Inc(result)
+  
+proc SetHasRange(s: PNode): bool = 
+  if s.kind != nkCurly: InternalError(s.info, "SetHasRange")
+  for i in countup(0, sonsLen(s) - 1): 
+    if s.sons[i].kind == nkRange: 
+      return true
+  result = false
+
+proc emptyRange(a, b: PNode): bool = 
+  result = not leValue(a, b)  # a > b iff not (a <= b)
+  
\ No newline at end of file