summary refs log tree commit diff stats
path: root/compiler/nir/stringcases.nim
diff options
context:
space:
mode:
authorAndreas Rumpf <rumpf_a@web.de>2023-10-16 00:01:33 +0200
committerGitHub <noreply@github.com>2023-10-16 00:01:33 +0200
commit10c3ab626941d9c1ec69a2921696f4b7d0c2a6ac (patch)
treec1b8b2117bbb86101f54269ff2e66816ce4b8296 /compiler/nir/stringcases.nim
parent5f400983d588ec91a688453ce69cec94d310cc70 (diff)
downloadNim-10c3ab626941d9c1ec69a2921696f4b7d0c2a6ac.tar.gz
NIR: store sizes, alignments and offsets in the type graph; beginning… (#22822)
…s of a patent-pending new VM
Diffstat (limited to 'compiler/nir/stringcases.nim')
-rw-r--r--compiler/nir/stringcases.nim200
1 files changed, 200 insertions, 0 deletions
diff --git a/compiler/nir/stringcases.nim b/compiler/nir/stringcases.nim
new file mode 100644
index 000000000..9417d613d
--- /dev/null
+++ b/compiler/nir/stringcases.nim
@@ -0,0 +1,200 @@
+#
+#
+#           The Nim Compiler
+#        (c) Copyright 2023 Andreas Rumpf
+#
+#    See the file "copying.txt", included in this
+#    distribution, for details about the copyright.
+#
+
+## included from ast2ir.nim
+
+#[
+
+case s
+of "abc", "abbd":
+  echo 1
+of "hah":
+  echo 2
+of "gah":
+  echo 3
+
+# we produce code like this:
+
+if s[0] <= 'a':
+  if s == "abc: goto L1
+  elif s == "abbd": goto L1
+else:
+  if s[2] <= 'h':
+    if s == "hah": goto L2
+    elif s == "gah": goto L3
+goto afterCase
+
+L1:
+  echo 1
+  goto afterCase
+L2:
+  echo 2
+  goto afterCase
+L3:
+  echo 3
+  goto afterCase
+
+afterCase: ...
+
+]#
+
+# We split the set of strings into 2 sets of roughly the same size.
+# The condition used for splitting is a (position, char) tuple.
+# Every string of length > position for which s[position] <= char is in one
+# set else it is in the other set.
+
+from sequtils import addUnique
+
+type
+  Key = (LitId, LabelId)
+
+proc splitValue(strings: BiTable[string]; a: openArray[Key]; position: int): (char, float) =
+  var cand: seq[char] = @[]
+  for t in items a:
+    let s = strings[t[0]]
+    if s.len > position: cand.addUnique s[position]
+
+  result = ('\0', -1.0)
+  for disc in items cand:
+    var hits = 0
+    for t in items a:
+      let s = strings[t[0]]
+      if s.len > position and s[position] <= disc:
+        inc hits
+    # the split is the better, the more `hits` is close to `a.len / 2`:
+    let grade = 100000.0 - abs(hits.float - a.len.float / 2.0)
+    if grade > result[1]:
+      result = (disc, grade)
+
+proc tryAllPositions(strings: BiTable[string]; a: openArray[Key]): (char, int) =
+  var m = 0
+  for t in items a:
+    m = max(m, strings[t[0]].len)
+
+  result = ('\0', -1)
+  var best = -1.0
+  for i in 0 ..< m:
+    let current = splitValue(strings, a, i)
+    if current[1] > best:
+      best = current[1]
+      result = (current[0], i)
+
+type
+  SearchKind = enum
+    LinearSearch, SplitSearch
+  SearchResult* = object
+    case kind: SearchKind
+    of LinearSearch:
+      a: seq[Key]
+    of SplitSearch:
+      span: int
+      best: (char, int)
+
+proc emitLinearSearch(strings: BiTable[string]; a: openArray[Key]; dest: var seq[SearchResult]) =
+  var d = SearchResult(kind: LinearSearch, a: @[])
+  for x in a: d.a.add x
+  dest.add d
+
+proc split(strings: BiTable[string]; a: openArray[Key]; dest: var seq[SearchResult]) =
+  if a.len <= 4:
+    emitLinearSearch strings, a, dest
+  else:
+    let best = tryAllPositions(strings, a)
+    var groupA: seq[Key] = @[]
+    var groupB: seq[Key] = @[]
+    for t in items a:
+      let s = strings[t[0]]
+      if s.len > best[1] and s[best[1]] <= best[0]:
+        groupA.add t
+      else:
+        groupB.add t
+    if groupA.len == 0 or groupB.len == 0:
+      emitLinearSearch strings, a, dest
+    else:
+      let toPatch = dest.len
+      dest.add SearchResult(kind: SplitSearch, span: 1, best: best)
+      split strings, groupA, dest
+      split strings, groupB, dest
+      let dist = dest.len - toPatch
+      assert dist > 0
+      dest[toPatch].span = dist
+
+proc toProblemDescription(c: var ProcCon; n: PNode): (seq[Key], LabelId) =
+  result = (@[], newLabels(c.labelGen, n.len))
+  assert n.kind == nkCaseStmt
+  for i in 1..<n.len:
+    let it = n[i]
+    let thisBranch = LabelId(result[1].int + i - 1)
+    if it.kind == nkOfBranch:
+      for j in 0..<it.len-1:
+        assert it[j].kind in {nkStrLit..nkTripleStrLit}
+        result[0].add (c.lit.strings.getOrIncl(it[j].strVal), thisBranch)
+
+proc decodeSolution(c: var ProcCon; dest: var Tree; s: seq[SearchResult]; i: int;
+                    selector: Value; info: PackedLineInfo) =
+  case s[i].kind
+  of SplitSearch:
+    let thenA = i+1
+    let elseA = thenA + (if s[thenA].kind == LinearSearch: 1 else: s[thenA].span)
+    let best = s[i].best
+
+    let tmp = getTemp(c, Bool8Id, info)
+    buildTyped dest, info, Asgn, Bool8Id:
+      dest.copyTree tmp
+      buildTyped dest, info, Call, Bool8Id:
+        c.addUseCodegenProc dest, "nimStrAtLe", info
+        dest.copyTree selector
+        dest.addIntVal c.lit.numbers, info, c.m.nativeIntId, best[1]
+        dest.addIntVal c.lit.numbers, info, Char8Id, best[0].int
+
+    template then() =
+      c.decodeSolution dest, s, thenA, selector, info
+    template otherwise() =
+      c.decodeSolution dest, s, elseA, selector, info
+    buildIfThenElse tmp, then, otherwise
+    freeTemp c, tmp
+
+  of LinearSearch:
+    let tmp = getTemp(c, Bool8Id, info)
+    for x in s[i].a:
+      buildTyped dest, info, Asgn, Bool8Id:
+        dest.copyTree tmp
+        buildTyped dest, info, Call, Bool8Id:
+          c.addUseCodegenProc dest, "eqStrings", info
+          dest.copyTree selector
+          dest.addStrLit info, x[0]
+      buildIf tmp:
+        c.code.gotoLabel info, Goto, x[1]
+    freeTemp c, tmp
+
+proc genStringCase(c: var ProcCon; n: PNode; d: var Value) =
+  let (problem, firstBranch) = toProblemDescription(c, n)
+  var solution: seq[SearchResult] = @[]
+  split c.lit.strings, problem, solution
+
+  # XXX Todo move complex case selector into a temporary.
+  let selector = c.genx(n[0])
+
+  let info = toLineInfo(c, n.info)
+  decodeSolution c, c.code, solution, 0, selector, info
+
+  let lend = newLabel(c.labelGen)
+  c.code.addLabel info, Goto, lend
+  for i in 1..<n.len:
+    let it = n[i]
+    let thisBranch = LabelId(firstBranch.int + i - 1)
+    c.code.addLabel info, Label, thisBranch
+    if it.kind == nkOfBranch:
+      gen(c, it.lastSon, d)
+      c.code.addLabel info, Goto, lend
+    else:
+      gen(c, it.lastSon, d)
+
+  c.code.addLabel info, Label, lend
+  freeTemp c, selector
'#n429'>429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488