summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--changelog.md4
-rw-r--r--compiler/ccgstmts.nim37
-rw-r--r--compiler/cgen.nim2
-rw-r--r--compiler/condsyms.nim1
-rw-r--r--compiler/dfa.nim2
-rw-r--r--compiler/jsgen.nim8
-rw-r--r--compiler/optimizer.nim2
-rw-r--r--compiler/sempass2.nim2
-rw-r--r--compiler/semstmts.nim2
-rw-r--r--compiler/semtypes.nim4
-rw-r--r--doc/manual.md5
-rw-r--r--lib/system/strmantle.nim23
-rw-r--r--tests/casestmt/tcstring.nim52
13 files changed, 126 insertions, 18 deletions
diff --git a/changelog.md b/changelog.md
index 61e67161e..d2a3544a4 100644
--- a/changelog.md
+++ b/changelog.md
@@ -136,6 +136,10 @@
   let foo: seq[(float, byte, cstring)] = @[(1, 2, "abc")]
   ```
 
+- `cstring` is now accepted as a selector in `case` statements, removing the
+  need to convert to `string`. On the JS backend, this is translated directly
+  to a `switch` statement.
+
 ## Compiler changes
 
 - The `gc` switch has been renamed to `mm` ("memory management") in order to reflect the
diff --git a/compiler/ccgstmts.nim b/compiler/ccgstmts.nim
index 0d17031df..0980de98f 100644
--- a/compiler/ccgstmts.nim
+++ b/compiler/ccgstmts.nim
@@ -837,17 +837,27 @@ template genCaseGeneric(p: BProc, t: PNode, d: var TLoc,
   fixLabel(p, lend)
 
 proc genCaseStringBranch(p: BProc, b: PNode, e: TLoc, labl: TLabel,
+                         stringKind: TTypeKind,
                          branches: var openArray[Rope]) =
   var x: TLoc
   for i in 0..<b.len - 1:
     assert(b[i].kind != nkRange)
     initLocExpr(p, b[i], x)
-    assert(b[i].kind in {nkStrLit..nkTripleStrLit})
-    var j = int(hashString(p.config, b[i].strVal) and high(branches))
-    appcg(p.module, branches[j], "if (#eqStrings($1, $2)) goto $3;$n",
+    var j: int
+    case b[i].kind
+    of nkStrLit..nkTripleStrLit:
+      j = int(hashString(p.config, b[i].strVal) and high(branches))
+    of nkNilLit: j = 0
+    else:
+      assert false, "invalid string case branch node kind"
+    if stringKind == tyCstring:
+      appcg(p.module, branches[j], "if (#eqCstrings($1, $2)) goto $3;$n",
+         [rdLoc(e), rdLoc(x), labl])
+    else:
+      appcg(p.module, branches[j], "if (#eqStrings($1, $2)) goto $3;$n",
          [rdLoc(e), rdLoc(x), labl])
 
-proc genStringCase(p: BProc, t: PNode, d: var TLoc) =
+proc genStringCase(p: BProc, t: PNode, stringKind: TTypeKind, d: var TLoc) =
   # count how many constant strings there are in the case:
   var strings = 0
   for i in 1..<t.len:
@@ -863,13 +873,17 @@ proc genStringCase(p: BProc, t: PNode, d: var TLoc) =
       inc(p.labels)
       if t[i].kind == nkOfBranch:
         genCaseStringBranch(p, t[i], a, "LA" & rope(p.labels) & "_",
-                            branches)
+                            stringKind, branches)
       else:
         # else statement: nothing to do yet
         # but we reserved a label, which we use later
         discard
-    linefmt(p, cpsStmts, "switch (#hashString($1) & $2) {$n",
-            [rdLoc(a), bitMask])
+    if stringKind == tyCstring:
+      linefmt(p, cpsStmts, "switch (#hashCstring($1) & $2) {$n",
+              [rdLoc(a), bitMask])
+    else:
+      linefmt(p, cpsStmts, "switch (#hashString($1) & $2) {$n",
+              [rdLoc(a), bitMask])
     for j in 0..high(branches):
       if branches[j] != nil:
         lineF(p, cpsStmts, "case $1: $n$2break;$n",
@@ -881,7 +895,10 @@ proc genStringCase(p: BProc, t: PNode, d: var TLoc) =
     var lend = genCaseSecondPass(p, t, d, labId, t.len-1)
     fixLabel(p, lend)
   else:
-    genCaseGeneric(p, t, d, "", "if (#eqStrings($1, $2)) goto $3;$n")
+    if stringKind == tyCstring:
+      genCaseGeneric(p, t, d, "", "if (#eqCstrings($1, $2)) goto $3;$n")
+    else:
+      genCaseGeneric(p, t, d, "", "if (#eqStrings($1, $2)) goto $3;$n")
 
 proc branchHasTooBigRange(b: PNode): bool =
   for it in b:
@@ -954,7 +971,9 @@ proc genCase(p: BProc, t: PNode, d: var TLoc) =
     getTemp(p, t.typ, d)
   case skipTypes(t[0].typ, abstractVarRange).kind
   of tyString:
-    genStringCase(p, t, d)
+    genStringCase(p, t, tyString, d)
+  of tyCstring:
+    genStringCase(p, t, tyCstring, d)
   of tyFloat..tyFloat128:
     genCaseGeneric(p, t, d, "if ($1 >= $2 && $1 <= $3) goto $4;$n",
                             "if ($1 == $2) goto $3;$n")
diff --git a/compiler/cgen.nim b/compiler/cgen.nim
index c0b94ebbb..3896a46ca 100644
--- a/compiler/cgen.nim
+++ b/compiler/cgen.nim
@@ -966,7 +966,7 @@ proc allPathsAsgnResult(n: PNode): InitResultEnum =
     if containsResult(n[0]): return InitRequired
     result = InitSkippable
     var exhaustive = skipTypes(n[0].typ,
-        abstractVarRange-{tyTypeDesc}).kind notin {tyFloat..tyFloat128, tyString}
+        abstractVarRange-{tyTypeDesc}).kind notin {tyFloat..tyFloat128, tyString, tyCstring}
     for i in 1..<n.len:
       let it = n[i]
       allPathsInBranch(it.lastSon)
diff --git a/compiler/condsyms.nim b/compiler/condsyms.nim
index f1d16c4e4..d811f2d75 100644
--- a/compiler/condsyms.nim
+++ b/compiler/condsyms.nim
@@ -142,3 +142,4 @@ proc initDefines*(symbols: StringTableRef) =
   defineSymbol("nimHasEnforceNoRaises")
   defineSymbol("nimHasTopDownInference")
   defineSymbol("nimHasTemplateRedefinitionPragma")
+  defineSymbol("nimHasCstringCase")
diff --git a/compiler/dfa.nim b/compiler/dfa.nim
index 5b048ff6e..669207151 100644
--- a/compiler/dfa.nim
+++ b/compiler/dfa.nim
@@ -461,7 +461,7 @@ proc genCase(c: var Con; n: PNode) =
   #    elsePart
   #  Lend:
   let isExhaustive = skipTypes(n[0].typ,
-    abstractVarRange-{tyTypeDesc}).kind notin {tyFloat..tyFloat128, tyString}
+    abstractVarRange-{tyTypeDesc}).kind notin {tyFloat..tyFloat128, tyString, tyCstring}
 
   # we generate endings as a set of chained gotos, this is a bit awkward but it
   # ensures when recursively traversing the CFG for various analysis, we don't
diff --git a/compiler/jsgen.nim b/compiler/jsgen.nim
index bbc64a2f0..a221b3127 100644
--- a/compiler/jsgen.nim
+++ b/compiler/jsgen.nim
@@ -873,8 +873,9 @@ proc genCaseJS(p: PProc, n: PNode, r: var TCompRes) =
     totalRange = 0
   genLineDir(p, n)
   gen(p, n[0], cond)
-  let stringSwitch = skipTypes(n[0].typ, abstractVar).kind == tyString
-  if stringSwitch:
+  let typeKind = skipTypes(n[0].typ, abstractVar).kind
+  let anyString = typeKind in {tyString, tyCstring}
+  if typeKind == tyString:
     useMagic(p, "toJSStr")
     lineF(p, "switch (toJSStr($1)) {$n", [cond.rdLoc])
   else:
@@ -899,10 +900,11 @@ proc genCaseJS(p: PProc, n: PNode, r: var TCompRes) =
             lineF(p, "case $1:$n", [cond.rdLoc])
             inc(v.intVal)
         else:
-          if stringSwitch:
+          if anyString:
             case e.kind
             of nkStrLit..nkTripleStrLit: lineF(p, "case $1:$n",
                 [makeJSString(e.strVal, false)])
+            of nkNilLit: lineF(p, "case null:$n", [])
             else: internalError(p.config, e.info, "jsgen.genCaseStmt: 2")
           else:
             gen(p, e, cond)
diff --git a/compiler/optimizer.nim b/compiler/optimizer.nim
index f484fdbf5..10b092e11 100644
--- a/compiler/optimizer.nim
+++ b/compiler/optimizer.nim
@@ -178,7 +178,7 @@ proc analyse(c: var Con; b: var BasicBlock; n: PNode) =
 
   of nkCaseStmt:
     let isExhaustive = skipTypes(n[0].typ,
-      abstractVarRange-{tyTypeDesc}).kind notin {tyFloat..tyFloat128, tyString} or
+      abstractVarRange-{tyTypeDesc}).kind notin {tyFloat..tyFloat128, tyString, tyCstring} or
       n[^1].kind == nkElse
 
     analyse(c, b, n[0])
diff --git a/compiler/sempass2.nim b/compiler/sempass2.nim
index ae6d638e4..df767630d 100644
--- a/compiler/sempass2.nim
+++ b/compiler/sempass2.nim
@@ -684,7 +684,7 @@ proc trackCase(tracked: PEffects, n: PNode) =
   let oldState = tracked.init.len
   let oldFacts = tracked.guards.s.len
   let stringCase = n[0].typ != nil and skipTypes(n[0].typ,
-        abstractVarRange-{tyTypeDesc}).kind in {tyFloat..tyFloat128, tyString}
+        abstractVarRange-{tyTypeDesc}).kind in {tyFloat..tyFloat128, tyString, tyCstring}
   let interesting = not stringCase and interestingCaseExpr(n[0]) and
         tracked.config.hasWarn(warnProveField)
   var inter: TIntersection = @[]
diff --git a/compiler/semstmts.nim b/compiler/semstmts.nim
index 43687cd2e..6ced487ce 100644
--- a/compiler/semstmts.nim
+++ b/compiler/semstmts.nim
@@ -1067,7 +1067,7 @@ proc semCase(c: PContext, n: PNode; flags: TExprFlags; expectedType: PType = nil
   of tyRange:
     if skipTypes(caseTyp[0], abstractInst).kind in shouldChckCovered:
       chckCovered = true
-  of tyFloat..tyFloat128, tyString, tyError:
+  of tyFloat..tyFloat128, tyString, tyCstring, tyError:
     discard
   else:
     popCaseContext(c)
diff --git a/compiler/semtypes.nim b/compiler/semtypes.nim
index 8ddc2196a..e807e0870 100644
--- a/compiler/semtypes.nim
+++ b/compiler/semtypes.nim
@@ -603,7 +603,9 @@ proc semCaseBranch(c: PContext, t, branch: PNode, branchIndex: int,
         checkMinSonsLen(t, 1, c.config)
         var tmp = fitNode(c, t[0].typ, r, r.info)
         # the call to fitNode may introduce a call to a converter
-        if tmp.kind in {nkHiddenCallConv}: tmp = semConstExpr(c, tmp)
+        if tmp.kind == nkHiddenCallConv or
+            (tmp.kind == nkHiddenStdConv and t[0].typ.kind == tyCstring):
+          tmp = semConstExpr(c, tmp)
         branch[i] = skipConv(tmp)
         inc(covered)
       else:
diff --git a/doc/manual.md b/doc/manual.md
index 49539a48b..2bb91f745 100644
--- a/doc/manual.md
+++ b/doc/manual.md
@@ -1432,6 +1432,8 @@ it can be modified:
   s[0] = 'u' # This is ok
   ```
 
+`cstring` values may also be used in case statements like strings.
+
 Structured types
 ----------------
 A variable of a structured type can hold multiple values at the same
@@ -3098,6 +3100,9 @@ This holds only for expressions of ordinal types.
 "All possible values" of `expr` are determined by `expr`'s type.
 To suppress the static error an `else: discard` should be used.
 
+Only ordinal types, floats, strings and cstrings are allowed as values
+in case statements.
+
 For non-ordinal types, it is not possible to list every possible value and so
 these always require an `else` part.
 An exception to this rule is for the `string` type, which currently doesn't
diff --git a/lib/system/strmantle.nim b/lib/system/strmantle.nim
index 9cf4f9e55..feaac7817 100644
--- a/lib/system/strmantle.nim
+++ b/lib/system/strmantle.nim
@@ -43,6 +43,29 @@ proc hashString(s: string): int {.compilerproc.} =
   h = h + h shl 15
   result = cast[int](h)
 
+proc eqCstrings(a, b: cstring): bool {.inline, compilerproc.} =
+  if pointer(a) == pointer(b): result = true
+  elif a.isNil or b.isNil: result = false
+  else: result = c_strcmp(a, b) == 0
+
+proc hashCstring(s: cstring): int {.compilerproc.} =
+  # the compiler needs exactly the same hash function!
+  # this used to be used for efficient generation of cstring case statements
+  if s.isNil: return 0
+  var h : uint = 0
+  var i = 0
+  while true:
+    let c = s[i]
+    if c == '\0': break
+    h = h + uint(c)
+    h = h + h shl 10
+    h = h xor (h shr 6)
+    inc i
+  h = h + h shl 3
+  h = h xor (h shr 11)
+  h = h + h shl 15
+  result = cast[int](h)
+
 proc c_strtod(buf: cstring, endptr: ptr cstring): float64 {.
   importc: "strtod", header: "<stdlib.h>", noSideEffect.}
 
diff --git a/tests/casestmt/tcstring.nim b/tests/casestmt/tcstring.nim
new file mode 100644
index 000000000..288373402
--- /dev/null
+++ b/tests/casestmt/tcstring.nim
@@ -0,0 +1,52 @@
+discard """
+  targets: "c cpp js"
+"""
+
+type Result = enum none, a, b, c, d, e, f
+
+proc foo1(x: cstring): Result =
+  const y = cstring"hash"
+  const arr = [cstring"it", cstring"finally"]
+  result = none
+  case x
+  of "Andreas", "Rumpf": result = a
+  of cstring"aa", "bb": result = b
+  of "cc", y, "when": result = c
+  of "will", arr, "be", "generated": result = d
+  of nil: result = f
+
+var results = [
+  foo1("Rumpf"), foo1("Andreas"),
+  foo1("aa"), foo1(cstring"bb"),
+  foo1("cc"), foo1("hash"),
+  foo1("finally"), foo1("generated"),
+  foo1("no"), foo1("another no"),
+  foo1(nil)]
+doAssert results == [a, a, b, b, c, c, d, d, none, none, f], $results
+
+proc foo2(x: cstring): Result =
+  const y = cstring"hash"
+  const arr = [cstring"it", cstring"finally"]
+  doAssert not (compiles do:
+    result = case x
+    of "Andreas", "Rumpf": a
+    of cstring"aa", "bb": b
+    of "cc", y, "when": c
+    of "will", arr, "be", "generated": d)
+  case x
+  of "Andreas", "Rumpf": a
+  of cstring"aa", "bb": b
+  of "cc", y, "when": c
+  of "will", arr, "be", "generated": d
+  of nil: f
+  else: e
+
+results = [
+  foo2("Rumpf"), foo2("Andreas"),
+  foo2("aa"), foo2(cstring"bb"),
+  foo2("cc"), foo2("hash"),
+  foo2("finally"), foo2("generated"),
+  foo2("no"), foo2("another no"),
+  foo2(nil)]
+
+doAssert results == [a, a, b, b, c, c, d, d, e, e, f], $results