summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorDmitry Atamanov <data-man@users.noreply.github.com>2018-05-20 21:11:25 +0300
committerAndreas Rumpf <rumpf_a@web.de>2018-05-20 20:11:25 +0200
commit90afb1baa7cd4c047093f6cb108c954cb3cb6ca9 (patch)
treeff93618120f00eb20208dd99037fe46fd31b40f2
parentf890f607360c18e96c5b1750fed5d682e04ed163 (diff)
downloadNim-90afb1baa7cd4c047093f6cb108c954cb3cb6ca9.tar.gz
binarySearch improvements (#7850)
-rw-r--r--changelog.md1
-rw-r--r--lib/pure/algorithm.nim50
2 files changed, 37 insertions, 14 deletions
diff --git a/changelog.md b/changelog.md
index f9daf55d1..57b54c71e 100644
--- a/changelog.md
+++ b/changelog.md
@@ -56,6 +56,7 @@
 - Added the type ``times.Duration`` for representing fixed durations of time.
 - Added the proc ``times.convert`` for converting between different time units,
   e.g days to seconds.
+- Added the proc ``algorithm.binarySearch[T, K]`` with the ```cmp``` parameter.
 
 ### Library changes
 
diff --git a/lib/pure/algorithm.nim b/lib/pure/algorithm.nim
index 2b668d1ca..169dcd602 100644
--- a/lib/pure/algorithm.nim
+++ b/lib/pure/algorithm.nim
@@ -64,27 +64,49 @@ proc reversed*[T](a: openArray[T]): seq[T] =
   ## returns the reverse of the array `a`.
   reversed(a, 0, a.high)
 
-proc binarySearch*[T](a: openArray[T], key: T): int =
+proc binarySearch*[T, K](a: openArray[T], key: K,
+              cmp: proc (x: T, y: K): int {.closure.}): int =
   ## binary search for `key` in `a`. Returns -1 if not found.
-  if ((a.len - 1) and a.len) == 0 and a.len > 0:
-    # when `a.len` is a power of 2, a faster div can be used.
-    var step = a.len div 2
+  ##
+  ## `cmp` is the comparator function to use, the expected return values are
+  ## the same as that of system.cmp.
+  if a.len == 0:
+    return -1
+
+  let len = a.len
+
+  if len == 1:
+    if cmp(a[0], key) == 0:
+      return 0
+    else:
+      return -1
+
+  if (len and (len - 1)) == 0:
+    # when `len` is a power of 2, a faster shr can be used.
+    var step = len shr 1
     while step > 0:
-      if a[result or step] <= key:
-        result = result or step
+      let i = result or step
+      if cmp(a[i], key) < 1:
+        result = i
       step = step shr 1
-    if a[result] != key: result = -1
+    if cmp(a[result], key) != 0: result = -1
   else:
-    var b = len(a)
+    var b = len
     while result < b:
-      var mid = (result + b) div 2
-      if a[mid] < key: result = mid + 1
-      else: b = mid
-    if result >= len(a) or a[result] != key: result = -1
+      var mid = (result + b) shr 1
+      if cmp(a[mid], key) < 0:
+        result = mid + 1
+      else:
+        b = mid
+    if result >= len or cmp(a[result], key) != 0: result = -1
+
+proc binarySearch*[T](a: openArray[T], key: T): int =
+  ## binary search for `key` in `a`. Returns -1 if not found.
+  binarySearch(a, key, cmp[T])
 
 proc smartBinarySearch*[T](a: openArray[T], key: T): int {.deprecated.} =
   ## **Deprecated since version 0.18.1**; Use ``binarySearch`` instead.
-  binarySearch(a,key)
+  binarySearch(a, key, cmp[T])
 
 const
   onlySafeCode = true
@@ -108,7 +130,7 @@ proc lowerBound*[T, K](a: openArray[T], key: K, cmp: proc(x: T, k: K): int {.clo
   var count = a.high - a.low + 1
   var step, pos: int
   while count != 0:
-    step = count div 2
+    step = count shr 1
     pos = result + step
     if cmp(a[pos], key) < 0:
       result = pos + 1