summary refs log tree commit diff stats
path: root/lib/std/enumutils.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/std/enumutils.nim')
-rw-r--r--lib/std/enumutils.nim40
1 files changed, 24 insertions, 16 deletions
diff --git a/lib/std/enumutils.nim b/lib/std/enumutils.nim
index 9d4ff1bcf..9c338817d 100644
--- a/lib/std/enumutils.nim
+++ b/lib/std/enumutils.nim
@@ -7,8 +7,8 @@
 #    distribution, for details about the copyright.
 #
 
-import macros
-from typetraits import OrdinalEnum, HoleyEnum
+import std/macros
+from std/typetraits import OrdinalEnum, HoleyEnum
 
 when defined(nimPreviewSlimSystem):
   import std/assertions
@@ -21,32 +21,34 @@ macro genEnumCaseStmt*(typ: typedesc, argSym: typed, default: typed,
   # Generates a case stmt, which assigns the correct enum field given
   # a normalized string comparison to the `argSym` input.
   # string normalization is done using passed normalizer.
-  # NOTE: for an enum with fields Foo, Bar, ... we cannot generate
-  # `of "Foo".nimIdentNormalize: Foo`.
-  # This will fail, if the enum is not defined at top level (e.g. in a block).
-  # Thus we check for the field value of the (possible holed enum) and convert
-  # the integer value to the generic argument `typ`.
   let typ = typ.getTypeInst[1]
-  let impl = typ.getImpl[2]
+  let typSym = typ.getTypeImpl.getTypeInst # skip aliases etc to get type sym
+  let impl = typSym.getImpl[2]
   expectKind impl, nnkEnumTy
   let normalizerNode = quote: `normalizer`
   expectKind normalizerNode, nnkSym
   result = nnkCaseStmt.newTree(newCall(normalizerNode, argSym))
   # stores all processed field strings to give error msg for ambiguous enums
   var foundFields: seq[string] = @[]
+  var fVal = ""
   var fStr = "" # string of current field
   var fNum = BiggestInt(0) # int value of current field
   for f in impl:
     case f.kind
     of nnkEmpty: continue # skip first node of `enumTy`
-    of nnkSym, nnkIdent: fStr = f.strVal
+    of nnkSym, nnkIdent:
+      fVal = f.strVal
+      fStr = fVal
     of nnkAccQuoted:
-      fStr = ""
+      fVal = ""
       for ch in f:
-        fStr.add ch.strVal
+        fVal.add ch.strVal
+      fStr = fVal
     of nnkEnumFieldDef:
+      fVal = f[0].strVal
       case f[1].kind
-      of nnkStrLit: fStr = f[1].strVal
+      of nnkStrLit:
+        fStr = f[1].strVal
       of nnkTupleConstr:
         fStr = f[1][1].strVal
         fNum = f[1][0].intVal
@@ -64,7 +66,7 @@ macro genEnumCaseStmt*(typ: typedesc, argSym: typed, default: typed,
     if fNum >= userMin and fNum <= userMax:
       fStr = normalizer(fStr)
       if fStr notin foundFields:
-        result.add nnkOfBranch.newTree(newLit fStr,  nnkCall.newTree(typ, newLit fNum))
+        result.add nnkOfBranch.newTree(newLit fStr, newDotExpr(typ, ident fVal))
         foundFields.add fStr
       else:
         error("Ambiguous enums cannot be parsed, field " & $fStr &
@@ -80,7 +82,7 @@ macro genEnumCaseStmt*(typ: typedesc, argSym: typed, default: typed,
     result.add nnkElse.newTree(default)
 
 macro enumFullRange(a: typed): untyped =
-  newNimNode(nnkCurly).add(a.getType[1][1..^1])
+  newNimNode(nnkBracket).add(a.getType[1][1..^1])
 
 macro enumNames(a: typed): untyped =
   # this could be exported too; in particular this could be useful for enum with holes.
@@ -112,9 +114,9 @@ const invalidSlot = uint8.high
 
 proc genLookup[T: typedesc[HoleyEnum]](_: T): auto =
   const n = span(T)
-  var ret: array[n, uint8]
   var i = 0
   assert n <= invalidSlot.int
+  var ret {.noinit.}: array[n, uint8]
   for ai in mitems(ret): ai = invalidSlot
   for ai in items(T):
     ret[ai.ord - T.low.ord] = uint8(i)
@@ -172,6 +174,9 @@ template symbolRank*[T: enum](a: T): int =
   when T is Ordinal: ord(a) - T.low.ord.static
   else: symbolRankImpl(a)
 
+proc rangeBase(T: typedesc): typedesc {.magic: "TypeTrait".}
+  # skip one level of range; return the base type of a range type
+
 func symbolName*[T: enum](a: T): string =
   ## Returns the symbol name of an enum.
   ##
@@ -190,5 +195,8 @@ func symbolName*[T: enum](a: T): string =
       c1 = 4
       c2 = 20
     assert c1.symbolName == "c1"
-  const names = enumNames(T)
+  when T is range:
+    const names = enumNames(rangeBase T)
+  else:
+    const names = enumNames(T)
   names[a.symbolRank]