summary refs log tree commit diff stats
path: root/lib/std/enumutils.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/std/enumutils.nim')
-rw-r--r--lib/std/enumutils.nim17
1 files changed, 12 insertions, 5 deletions
diff --git a/lib/std/enumutils.nim b/lib/std/enumutils.nim
index 0386c2589..9c338817d 100644
--- a/lib/std/enumutils.nim
+++ b/lib/std/enumutils.nim
@@ -7,8 +7,8 @@
 #    distribution, for details about the copyright.
 #
 
-import macros
-from typetraits import OrdinalEnum, HoleyEnum
+import std/macros
+from std/typetraits import OrdinalEnum, HoleyEnum
 
 when defined(nimPreviewSlimSystem):
   import std/assertions
@@ -22,7 +22,8 @@ macro genEnumCaseStmt*(typ: typedesc, argSym: typed, default: typed,
   # a normalized string comparison to the `argSym` input.
   # string normalization is done using passed normalizer.
   let typ = typ.getTypeInst[1]
-  let impl = typ.getImpl[2]
+  let typSym = typ.getTypeImpl.getTypeInst # skip aliases etc to get type sym
+  let impl = typSym.getImpl[2]
   expectKind impl, nnkEnumTy
   let normalizerNode = quote: `normalizer`
   expectKind normalizerNode, nnkSym
@@ -81,7 +82,7 @@ macro genEnumCaseStmt*(typ: typedesc, argSym: typed, default: typed,
     result.add nnkElse.newTree(default)
 
 macro enumFullRange(a: typed): untyped =
-  newNimNode(nnkCurly).add(a.getType[1][1..^1])
+  newNimNode(nnkBracket).add(a.getType[1][1..^1])
 
 macro enumNames(a: typed): untyped =
   # this could be exported too; in particular this could be useful for enum with holes.
@@ -173,6 +174,9 @@ template symbolRank*[T: enum](a: T): int =
   when T is Ordinal: ord(a) - T.low.ord.static
   else: symbolRankImpl(a)
 
+proc rangeBase(T: typedesc): typedesc {.magic: "TypeTrait".}
+  # skip one level of range; return the base type of a range type
+
 func symbolName*[T: enum](a: T): string =
   ## Returns the symbol name of an enum.
   ##
@@ -191,5 +195,8 @@ func symbolName*[T: enum](a: T): string =
       c1 = 4
       c2 = 20
     assert c1.symbolName == "c1"
-  const names = enumNames(T)
+  when T is range:
+    const names = enumNames(rangeBase T)
+  else:
+    const names = enumNames(T)
   names[a.symbolRank]