summary refs log tree commit diff stats
path: root/compiler/nir/stringcases.nim
diff options
context:
space:
mode:
authorAndreas Rumpf <rumpf_a@web.de>2023-10-16 00:01:33 +0200
committerGitHub <noreply@github.com>2023-10-16 00:01:33 +0200
commit10c3ab626941d9c1ec69a2921696f4b7d0c2a6ac (patch)
treec1b8b2117bbb86101f54269ff2e66816ce4b8296 /compiler/nir/stringcases.nim
parent5f400983d588ec91a688453ce69cec94d310cc70 (diff)
downloadNim-10c3ab626941d9c1ec69a2921696f4b7d0c2a6ac.tar.gz
NIR: store sizes, alignments and offsets in the type graph; beginning… (#22822)
…s of a patent-pending new VM
Diffstat (limited to 'compiler/nir/stringcases.nim')
-rw-r--r--compiler/nir/stringcases.nim200
1 files changed, 200 insertions, 0 deletions
diff --git a/compiler/nir/stringcases.nim b/compiler/nir/stringcases.nim
new file mode 100644
index 000000000..9417d613d
--- /dev/null
+++ b/compiler/nir/stringcases.nim
@@ -0,0 +1,200 @@
+#
+#
+#           The Nim Compiler
+#        (c) Copyright 2023 Andreas Rumpf
+#
+#    See the file "copying.txt", included in this
+#    distribution, for details about the copyright.
+#
+
+## included from ast2ir.nim
+
+#[
+
+case s
+of "abc", "abbd":
+  echo 1
+of "hah":
+  echo 2
+of "gah":
+  echo 3
+
+# we produce code like this:
+
+if s[0] <= 'a':
+  if s == "abc: goto L1
+  elif s == "abbd": goto L1
+else:
+  if s[2] <= 'h':
+    if s == "hah": goto L2
+    elif s == "gah": goto L3
+goto afterCase
+
+L1:
+  echo 1
+  goto afterCase
+L2:
+  echo 2
+  goto afterCase
+L3:
+  echo 3
+  goto afterCase
+
+afterCase: ...
+
+]#
+
+# We split the set of strings into 2 sets of roughly the same size.
+# The condition used for splitting is a (position, char) tuple.
+# Every string of length > position for which s[position] <= char is in one
+# set else it is in the other set.
+
+from sequtils import addUnique
+
+type
+  Key = (LitId, LabelId)
+
+proc splitValue(strings: BiTable[string]; a: openArray[Key]; position: int): (char, float) =
+  var cand: seq[char] = @[]
+  for t in items a:
+    let s = strings[t[0]]
+    if s.len > position: cand.addUnique s[position]
+
+  result = ('\0', -1.0)
+  for disc in items cand:
+    var hits = 0
+    for t in items a:
+      let s = strings[t[0]]
+      if s.len > position and s[position] <= disc:
+        inc hits
+    # the split is the better, the more `hits` is close to `a.len / 2`:
+    let grade = 100000.0 - abs(hits.float - a.len.float / 2.0)
+    if grade > result[1]:
+      result = (disc, grade)
+
+proc tryAllPositions(strings: BiTable[string]; a: openArray[Key]): (char, int) =
+  var m = 0
+  for t in items a:
+    m = max(m, strings[t[0]].len)
+
+  result = ('\0', -1)
+  var best = -1.0
+  for i in 0 ..< m:
+    let current = splitValue(strings, a, i)
+    if current[1] > best:
+      best = current[1]
+      result = (current[0], i)
+
+type
+  SearchKind = enum
+    LinearSearch, SplitSearch
+  SearchResult* = object
+    case kind: SearchKind
+    of LinearSearch:
+      a: seq[Key]
+    of SplitSearch:
+      span: int
+      best: (char, int)
+
+proc emitLinearSearch(strings: BiTable[string]; a: openArray[Key]; dest: var seq[SearchResult]) =
+  var d = SearchResult(kind: LinearSearch, a: @[])
+  for x in a: d.a.add x
+  dest.add d
+
+proc split(strings: BiTable[string]; a: openArray[Key]; dest: var seq[SearchResult]) =
+  if a.len <= 4:
+    emitLinearSearch strings, a, dest
+  else:
+    let best = tryAllPositions(strings, a)
+    var groupA: seq[Key] = @[]
+    var groupB: seq[Key] = @[]
+    for t in items a:
+      let s = strings[t[0]]
+      if s.len > best[1] and s[best[1]] <= best[0]:
+        groupA.add t
+      else:
+        groupB.add t
+    if groupA.len == 0 or groupB.len == 0:
+      emitLinearSearch strings, a, dest
+    else:
+      let toPatch = dest.len
+      dest.add SearchResult(kind: SplitSearch, span: 1, best: best)
+      split strings, groupA, dest
+      split strings, groupB, dest
+      let dist = dest.len - toPatch
+      assert dist > 0
+      dest[toPatch].span = dist
+
+proc toProblemDescription(c: var ProcCon; n: PNode): (seq[Key], LabelId) =
+  result = (@[], newLabels(c.labelGen, n.len))
+  assert n.kind == nkCaseStmt
+  for i in 1..<n.len:
+    let it = n[i]
+    let thisBranch = LabelId(result[1].int + i - 1)
+    if it.kind == nkOfBranch:
+      for j in 0..<it.len-1:
+        assert it[j].kind in {nkStrLit..nkTripleStrLit}
+        result[0].add (c.lit.strings.getOrIncl(it[j].strVal), thisBranch)
+
+proc decodeSolution(c: var ProcCon; dest: var Tree; s: seq[SearchResult]; i: int;
+                    selector: Value; info: PackedLineInfo) =
+  case s[i].kind
+  of SplitSearch:
+    let thenA = i+1
+    let elseA = thenA + (if s[thenA].kind == LinearSearch: 1 else: s[thenA].span)
+    let best = s[i].best
+
+    let tmp = getTemp(c, Bool8Id, info)
+    buildTyped dest, info, Asgn, Bool8Id:
+      dest.copyTree tmp
+      buildTyped dest, info, Call, Bool8Id:
+        c.addUseCodegenProc dest, "nimStrAtLe", info
+        dest.copyTree selector
+        dest.addIntVal c.lit.numbers, info, c.m.nativeIntId, best[1]
+        dest.addIntVal c.lit.numbers, info, Char8Id, best[0].int
+
+    template then() =
+      c.decodeSolution dest, s, thenA, selector, info
+    template otherwise() =
+      c.decodeSolution dest, s, elseA, selector, info
+    buildIfThenElse tmp, then, otherwise
+    freeTemp c, tmp
+
+  of LinearSearch:
+    let tmp = getTemp(c, Bool8Id, info)
+    for x in s[i].a:
+      buildTyped dest, info, Asgn, Bool8Id:
+        dest.copyTree tmp
+        buildTyped dest, info, Call, Bool8Id:
+          c.addUseCodegenProc dest, "eqStrings", info
+          dest.copyTree selector
+          dest.addStrLit info, x[0]
+      buildIf tmp:
+        c.code.gotoLabel info, Goto, x[1]
+    freeTemp c, tmp
+
+proc genStringCase(c: var ProcCon; n: PNode; d: var Value) =
+  let (problem, firstBranch) = toProblemDescription(c, n)
+  var solution: seq[SearchResult] = @[]
+  split c.lit.strings, problem, solution
+
+  # XXX Todo move complex case selector into a temporary.
+  let selector = c.genx(n[0])
+
+  let info = toLineInfo(c, n.info)
+  decodeSolution c, c.code, solution, 0, selector, info
+
+  let lend = newLabel(c.labelGen)
+  c.code.addLabel info, Goto, lend
+  for i in 1..<n.len:
+    let it = n[i]
+    let thisBranch = LabelId(firstBranch.int + i - 1)
+    c.code.addLabel info, Label, thisBranch
+    if it.kind == nkOfBranch:
+      gen(c, it.lastSon, d)
+      c.code.addLabel info, Goto, lend
+    else:
+      gen(c, it.lastSon, d)
+
+  c.code.addLabel info, Label, lend
+  freeTemp c, selector