summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--lib/pure/strutils.nim64
-rw-r--r--tests/stdlib/tstrutil.nim27
2 files changed, 76 insertions, 15 deletions
diff --git a/lib/pure/strutils.nim b/lib/pure/strutils.nim
index 8b5db49ed..235e66f6a 100644
--- a/lib/pure/strutils.nim
+++ b/lib/pure/strutils.nim
@@ -809,7 +809,7 @@ proc split*(s: string, sep: string, maxsplit: int = -1): seq[string] {.noSideEff
   ## Substrings are separated by the string `sep`. This is a wrapper around the
   ## `split iterator <#split.i,string,string>`_.
   doAssert(sep.len > 0)
-  
+
   accumulateResult(split(s, sep, maxsplit))
 
 proc rsplit*(s: string, seps: set[char] = Whitespace,
@@ -1318,11 +1318,11 @@ proc preprocessSub(sub: string, a: var SkipTable) =
   for i in 0..m-1: a[sub[i]] = m-i
 {.pop.}
 
-proc findAux(s, sub: string, start: int, a: SkipTable): int =
+proc findAux(s, sub: string, start, last: int, a: SkipTable): int =
   # Fast "quick search" algorithm:
   var
     m = len(sub)
-    n = len(s)
+    n = last + 1
   # search:
   var j = start
   while j <= n - m:
@@ -1333,30 +1333,53 @@ proc findAux(s, sub: string, start: int, a: SkipTable): int =
     inc(j, a[s[j+m]])
   return -1
 
-proc find*(s, sub: string, start: Natural = 0): int {.noSideEffect,
+when not (defined(js) or defined(nimdoc) or defined(nimscript)):
+  proc c_memchr(cstr: pointer, c: char, n: csize): pointer {.
+                importc: "memchr", header: "<string.h>" .}
+  const hasCStringBuiltin = true
+else:
+  const hasCStringBuiltin = false
+
+proc find*(s, sub: string, start: Natural = 0, last: Natural = 0): int {.noSideEffect,
   rtl, extern: "nsuFindStr".} =
-  ## Searches for `sub` in `s` starting at position `start`.
+  ## Searches for `sub` in `s` inside range `start`..`last`.
+  ## If `last` is unspecified, it defaults to `s.high`.
   ##
   ## Searching is case-sensitive. If `sub` is not in `s`, -1 is returned.
   var a {.noinit.}: SkipTable
+  let last = if last==0: s.high else: last
   preprocessSub(sub, a)
-  result = findAux(s, sub, start, a)
+  result = findAux(s, sub, start, last, a)
 
-proc find*(s: string, sub: char, start: Natural = 0): int {.noSideEffect,
+proc find*(s: string, sub: char, start: Natural = 0, last: Natural = 0): int {.noSideEffect,
   rtl, extern: "nsuFindChar".} =
-  ## Searches for `sub` in `s` starting at position `start`.
+  ## Searches for `sub` in `s` inside range `start`..`last`.
+  ## If `last` is unspecified, it defaults to `s.high`.
   ##
   ## Searching is case-sensitive. If `sub` is not in `s`, -1 is returned.
-  for i in start..len(s)-1:
-    if sub == s[i]: return i
+  let last = if last==0: s.high else: last
+  when nimvm:
+    for i in start..last:
+      if sub == s[i]: return i
+  else:
+    when hasCStringBuiltin:
+      let found = c_memchr(s[start].unsafeAddr, sub, last-start+1)
+      if not found.isNil:
+        return cast[ByteAddress](found) -% cast[ByteAddress](s.cstring)
+    else:
+      for i in start..last:
+        if sub == s[i]: return i
+
   return -1
 
-proc find*(s: string, chars: set[char], start: Natural = 0): int {.noSideEffect,
+proc find*(s: string, chars: set[char], start: Natural = 0, last: Natural = 0): int {.noSideEffect,
   rtl, extern: "nsuFindCharSet".} =
-  ## Searches for `chars` in `s` starting at position `start`.
+  ## Searches for `chars` in `s` inside range `start`..`last`.
+  ## If `last` is unspecified, it defaults to `s.high`.
   ##
   ## If `s` contains none of the characters in `chars`, -1 is returned.
-  for i in start..s.len-1:
+  let last = if last==0: s.high else: last
+  for i in start..last:
     if s[i] in chars: return i
   return -1
 
@@ -1385,6 +1408,15 @@ proc rfind*(s: string, sub: char, start: int = -1): int {.noSideEffect,
     if sub == s[i]: return i
   return -1
 
+proc rfind*(s: string, chars: set[char], start: int = -1): int {.noSideEffect.} =
+  ## Searches for `chars` in `s` in reverse starting at position `start`.
+  ##
+  ## Searching is case-sensitive. If `sub` is not in `s`, -1 is returned.
+  let realStart = if start == -1: s.len-1 else: start
+  for i in countdown(realStart, 0):
+    if s[i] in chars: return i
+  return -1
+
 proc center*(s: string, width: int, fillChar: char = ' '): string {.
   noSideEffect, rtl, extern: "nsuCenterString".} =
   ## Return the contents of `s` centered in a string `width` long using
@@ -1472,9 +1504,10 @@ proc replace*(s, sub: string, by = ""): string {.noSideEffect,
   var a {.noinit.}: SkipTable
   result = ""
   preprocessSub(sub, a)
+  let last = s.high
   var i = 0
   while true:
-    var j = findAux(s, sub, i, a)
+    var j = findAux(s, sub, i, last, a)
     if j < 0: break
     add result, substr(s, i, j - 1)
     add result, by
@@ -1506,8 +1539,9 @@ proc replaceWord*(s, sub: string, by = ""): string {.noSideEffect,
   result = ""
   preprocessSub(sub, a)
   var i = 0
+  let last = s.high
   while true:
-    var j = findAux(s, sub, i, a)
+    var j = findAux(s, sub, i, last, a)
     if j < 0: break
     # word boundary?
     if (j == 0 or s[j-1] notin wordChars) and
diff --git a/tests/stdlib/tstrutil.nim b/tests/stdlib/tstrutil.nim
index b97f2b1e9..b5e3db4e2 100644
--- a/tests/stdlib/tstrutil.nim
+++ b/tests/stdlib/tstrutil.nim
@@ -64,7 +64,34 @@ proc testDelete =
   delete(s, 0, 0)
   assert s == "1236789ABCDEFG"
 
+proc testFind =
+  assert "0123456789ABCDEFGH".find('A') == 10
+  assert "0123456789ABCDEFGH".find('A', 5) == 10
+  assert "0123456789ABCDEFGH".find('A', 5, 10) == 10
+  assert "0123456789ABCDEFGH".find('A', 5, 9) == -1
+  assert "0123456789ABCDEFGH".find("A") == 10
+  assert "0123456789ABCDEFGH".find("A", 5) == 10
+  assert "0123456789ABCDEFGH".find("A", 5, 10) == 10
+  assert "0123456789ABCDEFGH".find("A", 5, 9) == -1
+  assert "0123456789ABCDEFGH".find({'A'..'C'}) == 10
+  assert "0123456789ABCDEFGH".find({'A'..'C'}, 5) == 10
+  assert "0123456789ABCDEFGH".find({'A'..'C'}, 5, 10) == 10
+  assert "0123456789ABCDEFGH".find({'A'..'C'}, 5, 9) == -1
+
+proc testRFind =
+  assert "0123456789ABCDEFGAH".rfind('A') == 17
+  assert "0123456789ABCDEFGAH".rfind('A', 13) == 10
+  assert "0123456789ABCDEFGAH".rfind('H', 13) == -1
+  assert "0123456789ABCDEFGAH".rfind("A") == 17
+  assert "0123456789ABCDEFGAH".rfind("A", 13) == 10
+  assert "0123456789ABCDEFGAH".rfind("H", 13) == -1
+  assert "0123456789ABCDEFGAH".rfind({'A'..'C'}) == 17
+  assert "0123456789ABCDEFGAH".rfind({'A'..'C'}, 13) == 12
+  assert "0123456789ABCDEFGAH".rfind({'G'..'H'}, 13) == -1
+
 testDelete()
+testFind()
+testRFind()
 
 assert(insertSep($1000_000) == "1_000_000")
 assert(insertSep($232) == "232")