about summary refs log tree commit diff stats
path: root/src/js/regex.nim
blob: d73e6e2bdcbecdefd668a4eadc7735ed307f18b6 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# Interface for QuickJS libregexp.
import unicode

import bindings/libregexp
import bindings/quickjs
import types/opt
import utils/twtstr

# Re-export the libregexp flag constants so that callers of compileRegex can
# build its `flags` argument without importing bindings/libregexp themselves.
export
  LRE_FLAG_GLOBAL,
  LRE_FLAG_IGNORECASE,
  LRE_FLAG_MULTILINE,
  LRE_FLAG_DOTALL,
  LRE_FLAG_UTF16,
  LRE_FLAG_STICKY

type
  Regex* = object
    ## A compiled libregexp pattern, created by compileRegex.
    bytecode: seq[uint8] # copy of the bytecode produced by lre_compile
    buf: string # the pattern string that was passed to compileRegex

  RegexResult* = object
    ## Outcome of running `exec`.
    success*: bool # set when lre_exec reported at least one match
    # (start, end) byte offsets into the subject string, appended by `exec`
    # for every capture group of each match it records.
    captures*: seq[tuple[s, e: int]] # start, end

  RegexReplace* = object
    ## NOTE(review): not used in this part of the file; presumably a
    ## regex-based substitution (pattern, replacement rule, and whether to
    ## replace every occurrence) — confirm against its users.
    regex: Regex
    rule: string
    global: bool

# Module-level QuickJS runtime and bare context (created with
# JS_NewContextRaw) shared by every regex in this module: the context is
# handed to lre_compile/lre_exec, and compiled bytecode is released with
# js_free_rt on this runtime.
# NOTE(review): neither object is freed anywhere in the code shown here.
var dummyRuntime = JS_NewRuntime()
var dummyContext = JS_NewContextRaw(dummyRuntime)

func `$`*(regex: Regex): string =
  ## Returns the pattern string the regex was compiled from (as passed to
  ## `compileRegex`).
  result = regex.buf

proc compileRegex*(buf: string, flags: int): Result[Regex, string] =
  ## Compile the pattern `buf` with the libregexp flags `flags` (a bitwise
  ## or of the exported `LRE_FLAG_*` constants).
  ##
  ## On failure, returns the error message written by `lre_compile`.
  ## On success, the bytecode is copied into a Nim-owned seq and the buffer
  ## returned by `lre_compile` is released with `js_free_rt`.
  const errorMsgSize = 64
  var errorMsg = newString(errorMsgSize)
  prepareMutation(errorMsg)
  var bytecodeLen: cint
  let compiled = lre_compile(addr bytecodeLen, cstring(errorMsg),
    cint(errorMsgSize), cstring(buf), csize_t(buf.len), cint(flags),
    dummyContext)
  if compiled == nil:
    # Failed to compile; the message is NUL-terminated inside errorMsg.
    return err(errorMsg.until('\0'))
  assert bytecodeLen > 0
  var bytecode = newSeqUninitialized[uint8](bytecodeLen)
  copyMem(addr bytecode[0], compiled, bytecodeLen)
  dummyRuntime.js_free_rt(compiled)
  return ok(Regex(buf: buf, bytecode: bytecode))

func countBackslashes(buf: string, i: int): int =
  ## Returns how many consecutive backslashes end at index `i` (inclusive)
  ## of `buf`, scanning towards the start of the string. Returns 0 when `i`
  ## is negative or `buf[i]` is not a backslash.
  var k = i
  while k >= 0 and buf[k] == '\\':
    inc result
    dec k

proc compileMatchRegex*(buf: string): Result[Regex, string] =
  ## Compile `buf` as a whole-string match pattern.
  ##
  ## Patterns already anchored with a leading `^` or an unescaped trailing
  ## `$` are compiled unchanged. Any other pattern is wrapped in a
  ## non-capturing group so that it must match the entire subject:
  ##
  ##   ^abcd  -> ^abcd
  ##   efgh$  -> efgh$
  ##   ^ijkl$ -> ^ijkl$
  ##   mnop   -> ^(?:mnop)$
  ##   a|b    -> ^(?:a|b)$   (`^a|b$` would also accept "xb" and "ay")
  ##
  ## Returns `err("unexpected end")` when `buf` ends in a backslash that
  ## escapes nothing, or the compiler's message when `buf` is otherwise
  ## invalid.
  if buf.len == 0 or buf[0] == '^':
    return compileRegex(buf, 0)
  if buf[^1] == '$':
    # Check whether the final dollar sign is escaped.
    if buf.len == 1 or buf[^2] != '\\':
      return compileRegex(buf, 0)
    let j = buf.countBackslashes(buf.high - 2)
    if j mod 2 == 1: # odd, because we do not count the last backslash
      return compileRegex(buf, 0)
    # escaped. proceed as if no dollar sign was at the end
  if buf[^1] == '\\':
    # Reject a final backslash that escapes nothing. `j` excludes that last
    # backslash, so an even `j` means the total run length is odd.
    let j = buf.countBackslashes(buf.high - 1)
    if j mod 2 != 1:
      return err("unexpected end")
  # Validate the pattern on its own first: unbalanced parentheses in `buf`
  # could otherwise pair up with the wrapping group and compile silently
  # (e.g. `a)(b` would become the valid `^(?:a)(b)$`).
  let plain = compileRegex(buf, 0)
  if plain.isErr:
    return plain
  # A non-capturing group anchors every alternative on both sides without
  # renumbering the caller's capture groups.
  return compileRegex("^(?:" & buf & ")$", 0)

proc compileSearchRegex*(str: string): Result[Regex, string] =
  ## Compile a search pattern of the form `regex` or `regex/<flags>`.
  ##
  ## <flags> may contain any of `i` (LRE_FLAG_IGNORECASE), `m`
  ## (LRE_FLAG_MULTILINE), `s` (LRE_FLAG_DOTALL) and `u` (LRE_FLAG_UTF16).
  ## The last forward slash is dropped when <flags> is empty. It is treated
  ## as part of the pattern when the text after it is not a valid flag list,
  ## or when the slash itself is escaped with a backslash (so `a\/i` searches
  ## for the literal text `a/i`).
  ## LRE_FLAG_GLOBAL is always set.
  var flagsi = -1
  for i in countdown(str.high, 0):
    case str[i]
    of '/':
      # A slash preceded by an odd number of backslashes is escaped, so it
      # belongs to the pattern rather than introducing flags.
      if str.countBackslashes(i - 1) mod 2 == 0:
        flagsi = i
      break
    of 'i', 'm', 's', 'u': discard
    else: break # invalid flag
  var flags = LRE_FLAG_GLOBAL # for easy backwards matching
  if flagsi == -1:
    return compileRegex(str, flags)
  for i in flagsi + 1 .. str.high:
    case str[i]
    of 'i': flags = flags or LRE_FLAG_IGNORECASE
    of 'm': flags = flags or LRE_FLAG_MULTILINE
    of 's': flags = flags or LRE_FLAG_DOTALL
    of 'u': flags = flags or LRE_FLAG_UTF16
    else: assert false # the scan above only accepts these characters
  return compileRegex(str.substr(0, flagsi - 1), flags)

proc exec*(regex: Regex, str: string, start = 0, length = -1,
    nocaps = false): RegexResult =
  ## Run `regex` on `str`, starting at byte `start` and not reading past
  ## byte `length` (which defaults to `str.len`).
  ##
  ## `result.success` is set when a match is found. Unless `nocaps` is set,
  ## `result.captures` receives one `(s, e)` byte-offset pair (relative to
  ## the start of `str`) per capture group of every recorded match. Groups
  ## that did not take part in a match are reported as `(-1, -1)`. When the
  ## regex has LRE_FLAG_GLOBAL set, matching resumes after each match until
  ## the end of the searched range.
  let length = if length == -1: str.len else: length
  assert 0 <= start and start <= length and length <= str.len
  let bytecode = unsafeAddr regex.bytecode[0]
  let captureCount = lre_get_capture_count(bytecode)
  var capture: ptr UncheckedArray[int] = nil
  if captureCount > 0:
    # One start and one end pointer per capture group.
    let size = sizeof(ptr uint8) * captureCount * 2
    capture = cast[ptr UncheckedArray[int]](alloc0(size))
  defer:
    # Release the capture buffer even if an exception escapes the loop.
    if capture != nil:
      dealloc(capture)
  let cstr = cstring(str)
  let cstrAddress = cast[int](cstr)
  let flags = lre_get_flags(bytecode)
  var start = start
  while true:
    let ret = lre_exec(cast[ptr ptr uint8](capture), bytecode,
      cast[ptr uint8](cstr), cint(start), cint(length), cint(0), dummyContext)
    if ret != 1: #TODO error handling? (-1)
      break
    result.success = true
    if captureCount == 0 or nocaps:
      break
    let ps = start
    start = capture[1] - cstrAddress
    for i in 0 ..< captureCount:
      let sp = capture[i * 2]
      let ep = capture[i * 2 + 1]
      if sp == 0 or ep == 0:
        # libregexp leaves the pointers of non-participating groups NULL;
        # subtracting the base address would yield garbage offsets.
        result.captures.add((-1, -1))
      else:
        result.captures.add((sp - cstrAddress, ep - cstrAddress))
    if (flags and LRE_FLAG_GLOBAL) == 0:
      break
    if start >= length:
      break
    if ps == start:
      # Empty match: step over one UTF-8 character so the loop makes
      # progress, but never move past the end of the searched range.
      start = min(start + runeLenAt(str, start), length)

proc match*(regex: Regex, str: string, start = 0, length = str.len): bool =
  ## Returns whether `exec` reports a match for `regex` in `str` between
  ## byte `start` and byte `length`, without collecting capture positions.
  let outcome = regex.exec(str, start = start, length = length, nocaps = true)
  result = outcome.success