summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--lib/pure/strutils.nim91
-rw-r--r--tests/stdlib/tstrutils.nim38
2 files changed, 71 insertions, 58 deletions
diff --git a/lib/pure/strutils.nim b/lib/pure/strutils.nim
index b98bd2bec..ffbefd0fb 100644
--- a/lib/pure/strutils.nim
+++ b/lib/pure/strutils.nim
@@ -1815,6 +1815,7 @@ func initSkipTable*(a: var SkipTable, sub: string) {.rtl,
   ## Initializes table `a` for efficient search of substring `sub`.
   ##
   ## See also:
+  ## * `initSkipTable func<#initSkipTable,string>`_
   ## * `find func<#find,SkipTable,string,string,Natural,int>`_
   # TODO: this should be the `default()` initializer for the type.
   let m = len(sub)
@@ -1823,7 +1824,16 @@ func initSkipTable*(a: var SkipTable, sub: string) {.rtl,
   for i in 0 ..< m - 1:
     a[sub[i]] = m - 1 - i
 
-func find*(a: SkipTable, s, sub: string, start: Natural = 0, last = 0): int {.
+func initSkipTable*(sub: string): SkipTable {.noinit, rtl,
+    extern: "nsuInitNewSkipTable".} =
+  ## Returns a new table initialized for `sub`.
+  ##
+  ## See also:
+  ## * `initSkipTable func<#initSkipTable,SkipTable,string>`_
+  ## * `find func<#find,SkipTable,string,string,Natural,int>`_
+  initSkipTable(result, sub)
+
+func find*(a: SkipTable, s, sub: string, start: Natural = 0, last = -1): int {.
     rtl, extern: "nsuFindStrA".} =
   ## Searches for `sub` in `s` inside range `start..last` using preprocessed
   ## table `a`. If `last` is unspecified, it defaults to `s.high` (the last
@@ -1832,9 +1842,10 @@ func find*(a: SkipTable, s, sub: string, start: Natural = 0, last = 0): int {.
   ## Searching is case-sensitive. If `sub` is not in `s`, -1 is returned.
   ##
   ## See also:
+  ## * `initSkipTable func<#initSkipTable,string>`_
   ## * `initSkipTable func<#initSkipTable,SkipTable,string>`_
   let
-    last = if last == 0: s.high else: last
+    last = if last < 0: s.high else: last
     subLast = sub.len - 1
 
   if subLast == -1:
@@ -1844,6 +1855,7 @@ func find*(a: SkipTable, s, sub: string, start: Natural = 0, last = 0): int {.
 
   # This is an implementation of the Boyer-Moore Horspool algorithms
   # https://en.wikipedia.org/wiki/Boyer%E2%80%93Moore%E2%80%93Horspool_algorithm
+  result = -1
   var skip = start
 
   while last - skip >= subLast:
@@ -1853,7 +1865,6 @@ func find*(a: SkipTable, s, sub: string, start: Natural = 0, last = 0): int {.
         return skip
       dec i
     inc skip, a[s[skip + subLast]]
-  return -1
 
 when not (defined(js) or defined(nimdoc) or defined(nimscript)):
   func c_memchr(cstr: pointer, c: char, n: csize_t): pointer {.
@@ -1865,10 +1876,10 @@ when not (defined(js) or defined(nimdoc) or defined(nimscript)):
 else:
   const hasCStringBuiltin = false
 
-func find*(s: string, sub: char, start: Natural = 0, last = 0): int {.rtl,
+func find*(s: string, sub: char, start: Natural = 0, last = -1): int {.rtl,
     extern: "nsuFindChar".} =
   ## Searches for `sub` in `s` inside range `start..last` (both ends included).
-  ## If `last` is unspecified, it defaults to `s.high` (the last element).
+  ## If `last` is unspecified or negative, it defaults to `s.high` (the last element).
   ##
   ## Searching is case-sensitive. If `sub` is not in `s`, -1 is returned.
   ## Otherwise the index returned is relative to `s[0]`, not `start`.
@@ -1877,26 +1888,30 @@ func find*(s: string, sub: char, start: Natural = 0, last = 0): int {.rtl,
   ## See also:
   ## * `rfind func<#rfind,string,char,Natural,int>`_
   ## * `replace func<#replace,string,char,char>`_
-  let last = if last == 0: s.high else: last
-  when nimvm:
+  result = -1
+  let last = if last < 0: s.high else: last
+
+  template findImpl =
     for i in int(start)..last:
-      if sub == s[i]: return i
+      if s[i] == sub:
+        return i
+
+  when nimvm:
+    findImpl()
   else:
     when hasCStringBuiltin:
-      let L = last-start+1
-      if L > 0:
-        let found = c_memchr(s[start].unsafeAddr, sub, cast[csize_t](L))
+      let length = last-start+1
+      if length > 0:
+        let found = c_memchr(s[start].unsafeAddr, sub, cast[csize_t](length))
         if not found.isNil:
           return cast[ByteAddress](found) -% cast[ByteAddress](s.cstring)
     else:
-      for i in int(start)..last:
-        if sub == s[i]: return i
-  return -1
+      findImpl()
 
-func find*(s: string, chars: set[char], start: Natural = 0, last = 0): int {.
+func find*(s: string, chars: set[char], start: Natural = 0, last = -1): int {.
     rtl, extern: "nsuFindCharSet".} =
   ## Searches for `chars` in `s` inside range `start..last` (both ends included).
-  ## If `last` is unspecified, it defaults to `s.high` (the last element).
+  ## If `last` is unspecified or negative, it defaults to `s.high` (the last element).
   ##
   ## If `s` contains none of the characters in `chars`, -1 is returned.
   ## Otherwise the index returned is relative to `s[0]`, not `start`.
@@ -1905,15 +1920,16 @@ func find*(s: string, chars: set[char], start: Natural = 0, last = 0): int {.
   ## See also:
   ## * `rfind func<#rfind,string,set[char],Natural,int>`_
   ## * `multiReplace func<#multiReplace,string,varargs[]>`_
-  let last = if last == 0: s.high else: last
+  result = -1
+  let last = if last < 0: s.high else: last
   for i in int(start)..last:
-    if s[i] in chars: return i
-  return -1
+    if s[i] in chars:
+      return i
 
-func find*(s, sub: string, start: Natural = 0, last = 0): int {.rtl,
+func find*(s, sub: string, start: Natural = 0, last = -1): int {.rtl,
     extern: "nsuFindStr".} =
   ## Searches for `sub` in `s` inside range `start..last` (both ends included).
-  ## If `last` is unspecified, it defaults to `s.high` (the last element).
+  ## If `last` is unspecified or negative, it defaults to `s.high` (the last element).
   ##
   ## Searching is case-sensitive. If `sub` is not in `s`, -1 is returned.
   ## Otherwise the index returned is relative to `s[0]`, not `start`.
@@ -1925,28 +1941,23 @@ func find*(s, sub: string, start: Natural = 0, last = 0): int {.rtl,
   if sub.len > s.len - start: return -1
   if sub.len == 1: return find(s, sub[0], start, last)
 
-  template useSkipTable {.dirty.} =
-    var a {.noinit.}: SkipTable
-    initSkipTable(a, sub)
-    result = find(a, s, sub, start, last)
+  template useSkipTable =
+    result = find(initSkipTable(sub), s, sub, start, last)
 
-  when not hasCStringBuiltin:
+  when nimvm:
     useSkipTable()
   else:
-    when nimvm:
-      useSkipTable()
-    else:
-      when hasCStringBuiltin:
-        if last == 0 and s.len > start:
-          let found = c_strstr(s[start].unsafeAddr, sub)
-          if not found.isNil:
-            result = cast[ByteAddress](found) -% cast[ByteAddress](s.cstring)
+    when hasCStringBuiltin:
+      if last < 0 and start < s.len:
+        let found = c_strstr(s[start].unsafeAddr, sub)
+        result = if not found.isNil:
+            cast[ByteAddress](found) -% cast[ByteAddress](s.cstring)
           else:
-            result = -1
-        else:
-          useSkipTable()
+            -1
       else:
         useSkipTable()
+    else:
+      useSkipTable()
 
 func rfind*(s: string, sub: char, start: Natural = 0, last = -1): int {.rtl,
     extern: "nsuRFindChar".} =
@@ -2121,8 +2132,7 @@ func replace*(s, sub: string, by = ""): string {.rtl,
     # copy the rest:
     add result, substr(s, i)
   else:
-    var a {.noinit.}: SkipTable
-    initSkipTable(a, sub)
+    var a = initSkipTable(sub)
     let last = s.high
     var i = 0
     while true:
@@ -2162,9 +2172,8 @@ func replaceWord*(s, sub: string, by = ""): string {.rtl,
   ## replaced.
   if sub.len == 0: return s
   const wordChars = {'a'..'z', 'A'..'Z', '0'..'9', '_', '\128'..'\255'}
-  var a {.noinit.}: SkipTable
   result = ""
-  initSkipTable(a, sub)
+  var a = initSkipTable(sub)
   var i = 0
   let last = s.high
   let sublen = sub.len
diff --git a/tests/stdlib/tstrutils.nim b/tests/stdlib/tstrutils.nim
index 26fc0c7b0..e17277ef2 100644
--- a/tests/stdlib/tstrutils.nim
+++ b/tests/stdlib/tstrutils.nim
@@ -232,18 +232,22 @@ template main() =
     {.pop.}
 
   block: # find
-    doAssert "0123456789ABCDEFGH".find('A') == 10
-    doAssert "0123456789ABCDEFGH".find('A', 5) == 10
-    doAssert "0123456789ABCDEFGH".find('A', 5, 10) == 10
-    doAssert "0123456789ABCDEFGH".find('A', 5, 9) == -1
-    doAssert "0123456789ABCDEFGH".find("A") == 10
-    doAssert "0123456789ABCDEFGH".find("A", 5) == 10
-    doAssert "0123456789ABCDEFGH".find("A", 5, 10) == 10
-    doAssert "0123456789ABCDEFGH".find("A", 5, 9) == -1
-    doAssert "0123456789ABCDEFGH".find({'A'..'C'}) == 10
-    doAssert "0123456789ABCDEFGH".find({'A'..'C'}, 5) == 10
-    doAssert "0123456789ABCDEFGH".find({'A'..'C'}, 5, 10) == 10
-    doAssert "0123456789ABCDEFGH".find({'A'..'C'}, 5, 9) == -1
+    const haystack: string = "0123456789ABCDEFGH"
+    doAssert haystack.find('A') == 10
+    doAssert haystack.find('A', 5) == 10
+    doAssert haystack.find('A', 5, 10) == 10
+    doAssert haystack.find('A', 5, 9) == -1
+    doAssert haystack.find("A") == 10
+    doAssert haystack.find("A", 5) == 10
+    doAssert haystack.find("A", 5, 10) == 10
+    doAssert haystack.find("A", 5, 9) == -1
+    doAssert haystack.find({'A'..'C'}) == 10
+    doAssert haystack.find({'A'..'C'}, 5) == 10
+    doAssert haystack.find({'A'..'C'}, 5, 10) == 10
+    doAssert haystack.find({'A'..'C'}, 5, 9) == -1
+    doAssert haystack.find('A', 0, 0) == -1 # search limited to the first char
+    doAssert haystack.find('A', 5, 0) == -1 # last < start
+    doAssert haystack.find('A', 5, 4) == -1 # last < start
 
     block:
       const haystack: string = "ABCABABABABCAB"
@@ -290,16 +294,16 @@ template main() =
 
     # when last <= start, searching for non-empty string
     block:
-      let last: int = -1
-      doAssert "abcd".find("ab", start=0, last=last) == -1
+      let last: int = -1 # searching through whole line
+      doAssert "abcd".find("ab", start=0, last=last) == 0
       doAssert "abcd".find("ab", start=1, last=last) == -1
-      doAssert "abcd".find("bc", start=1, last=last) == -1
+      doAssert "abcd".find("bc", start=1, last=last) == 1
       doAssert "abcd".find("bc", start=2, last=last) == -1
     block:
       let last: int = 0
-      doAssert "abcd".find("ab", start=0, last=last) == 0
+      doAssert "abcd".find("ab", start=0, last=last) == -1
       doAssert "abcd".find("ab", start=1, last=last) == -1
-      doAssert "abcd".find("bc", start=1, last=last) == 1
+      doAssert "abcd".find("bc", start=1, last=last) == -1
       doAssert "abcd".find("bc", start=2, last=last) == -1
     block:
       let last: int = 1