summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--changelog.md3
-rw-r--r--compiler/semstmts.nim21
-rw-r--r--lib/core/macros.nim85
-rw-r--r--tests/pragmas/tcustom_pragma.nim101
4 files changed, 174 insertions, 36 deletions
diff --git a/changelog.md b/changelog.md
index 60e993359..7569bf133 100644
--- a/changelog.md
+++ b/changelog.md
@@ -41,6 +41,9 @@
   now escapes the content of string literals consistently.
 - ``macros.NimSym`` and ``macros.NimIdent`` is now deprecated in favor
   of the more general ``NimNode``.
+- ``macros.getImpl`` now includes the pragmas of types, instead of omitting them.
+- ``macros.hasCustomPragma`` and ``macros.getCustomPragmaVal`` now
+  also support ``ref`` and ``ptr`` types, pragmas on types and variant fields.
 
 ### Language additions
 
diff --git a/compiler/semstmts.nim b/compiler/semstmts.nim
index 3de26344c..85fe3d793 100644
--- a/compiler/semstmts.nim
+++ b/compiler/semstmts.nim
@@ -765,6 +765,15 @@ proc addGenericParamListToScope(c: PContext, n: PNode) =
     if a.kind == nkSym: addDecl(c, a.sym)
     else: illFormedAst(a)
 
+proc typeSectionTypeName(n: PNode): PNode =
+  if n.kind == nkPragmaExpr:
+    if n.len == 0: illFormedAst(n)
+    result = n.sons[0]
+  else:
+    result = n
+  if result.kind != nkSym: illFormedAst(n)
+  
+
 proc typeSectionLeftSidePass(c: PContext, n: PNode) =
   # process the symbols on the left side for the whole type section, before
   # we even look at the type definitions on the right
@@ -825,7 +834,10 @@ proc typeSectionLeftSidePass(c: PContext, n: PNode) =
       # add it here, so that recursive types are possible:
       if sfGenSym notin s.flags: addInterfaceDecl(c, s)
 
-    a.sons[0] = newSymNode(s)
+    if name.kind == nkPragmaExpr:
+      a.sons[0].sons[0] = newSymNode(s)
+    else:
+      a.sons[0] = newSymNode(s)
 
 proc checkCovariantParamsUsages(genericType: PType) =
   var body = genericType[^1]
@@ -914,8 +926,7 @@ proc typeSectionRightSidePass(c: PContext, n: PNode) =
     if a.kind == nkCommentStmt: continue
     if (a.kind != nkTypeDef): illFormedAst(a)
     checkSonsLen(a, 3)
-    let name = a.sons[0]
-    if (name.kind != nkSym): illFormedAst(a)
+    let name = typeSectionTypeName(a.sons[0])
     var s = name.sym
     if s.magic == mNone and a.sons[2].kind == nkEmpty:
       localError(a.info, errImplOfXexpected, s.name.s)
@@ -1021,8 +1032,8 @@ proc typeSectionFinalPass(c: PContext, n: PNode) =
   for i in countup(0, sonsLen(n) - 1):
     var a = n.sons[i]
     if a.kind == nkCommentStmt: continue
-    if a.sons[0].kind != nkSym: illFormedAst(a)
-    var s = a.sons[0].sym
+    let name = typeSectionTypeName(a.sons[0])
+    var s = name.sym
     # compute the type's size and check for illegal recursions:
     if a.sons[1].kind == nkEmpty:
       var x = a[2]
diff --git a/lib/core/macros.nim b/lib/core/macros.nim
index a4c819a34..e71b7cdc8 100644
--- a/lib/core/macros.nim
+++ b/lib/core/macros.nim
@@ -1232,33 +1232,73 @@ macro expandMacros*(body: typed): untyped =
   echo result.toStrLit
 
 proc customPragmaNode(n: NimNode): NimNode =
-  expectKind(n, {nnkSym, nnkDotExpr})
-  if n.kind == nnkSym:
-    let sym = n.getImpl()
-    sym.expectRoutine()
-    result = sym.pragma
-  elif n.kind == nnkDotExpr:
-    let typDef = getImpl(getTypeInst(n[0]))
-    typDef.expectKind(nnkTypeDef)
-    typDef[2].expectKind(nnkObjectTy)
-    let recList = typDef[2][2]
-    for identDefs in recList:
-      for i in 0 .. identDefs.len - 3:
-        if identDefs[i].kind == nnkPragmaExpr and
-           identDefs[i][0].kind == nnkIdent and $identDefs[i][0] == $n[1]:
-          return identDefs[i][1]
+  expectKind(n, {nnkSym, nnkDotExpr, nnkBracketExpr, nnkTypeOfExpr, nnkCheckedFieldExpr})
+  let
+    typ = n.getTypeInst()
+
+  if typ.typeKind == ntyTypeDesc:
+    return typ[1].getImpl()[0][1]
+
+  if n.kind == nnkSym: # either an variable or a proc
+    let impl = n.getImpl()
+    if impl.kind in RoutineNodes:
+      return impl.pragma
+    else:
+      return typ.getImpl()[0][1]
+
+  if n.kind in {nnkDotExpr, nnkCheckedFieldExpr}:
+    let name = (if n.kind == nnkCheckedFieldExpr: n[0][1] else: n[1])
+    var typDef = getImpl(getTypeInst(if n.kind == nnkCheckedFieldExpr or n[0].kind == nnkHiddenDeref: n[0][0] else: n[0]))
+    while typDef != nil:
+      typDef.expectKind(nnkTypeDef)
+      typDef[2].expectKind({nnkRefTy, nnkPtrTy, nnkObjectTy})
+      let isRef = typDef[2].kind in {nnkRefTy, nnkPtrTy}
+      if isRef and typDef[2][0].kind in {nnkSym, nnkBracketExpr}: # defines ref type for another object(e.g. X = ref X)
+        typDef = getImpl(typDef[2][0])
+      else: # object definition, maybe an object directly defined as a ref type
+        let
+          obj = (if isRef: typDef[2][0] else: typDef[2])
+        var identDefsStack = newSeq[NimNode](obj[2].len)
+        for i in 0..<identDefsStack.len: identDefsStack[i] = obj[2][i]
+        while identDefsStack.len > 0:
+          var identDefs = identDefsStack.pop()
+          if identDefs.kind == nnkRecCase:
+            identDefsStack.add(identDefs[0])
+            for i in 1..<identDefs.len:
+              if identDefs[i][1].kind == nnkIdentDefs:
+                identDefsStack.add(identDefs[i][1])
+              else: # nnkRecList
+                for j in 0..<identDefs[i][1].len:
+                  identDefsStack.add(identDefs[i][1][j])
+
+          else:
+            for i in 0 .. identDefs.len - 3:
+              if identDefs[i].kind == nnkPragmaExpr and
+                identDefs[i][0].kind == nnkIdent and $identDefs[i][0] == $name:
+                return identDefs[i][1]
+
+        if obj[1].kind == nnkOfInherit: # explore the parent object
+          typDef = getImpl(obj[1][0])
+        else:
+          typDef = nil
 
 macro hasCustomPragma*(n: typed, cp: typed{nkSym}): untyped =
   ## Expands to `true` if expression `n` which is expected to be `nnkDotExpr`
-  ## has custom pragma `cp`.
+  ## (if checking a field), a proc or a type has custom pragma `cp`.
+  ##
+  ## See also `getCustomPragmaVal`.
   ##
   ## .. code-block:: nim
   ##   template myAttr() {.pragma.}
   ##   type
   ##     MyObj = object
   ##       myField {.myAttr.}: int
+  ##
+  ##   proc myProc() {.myAttr.} = discard
+  ##
   ##   var o: MyObj
-  ##   assert(o.myField.hasCustomPragma(myAttr) == 0)
+  ##   assert(o.myField.hasCustomPragma(myAttr))
+  ##   assert(myProc.hasCustomPragma(myAttr))
   let pragmaNode = customPragmaNode(n)
   for p in pragmaNode:
     if (p.kind == nnkSym and p == cp) or
@@ -1268,20 +1308,25 @@ macro hasCustomPragma*(n: typed, cp: typed{nkSym}): untyped =
 
 macro getCustomPragmaVal*(n: typed, cp: typed{nkSym}): untyped =
   ## Expands to value of custom pragma `cp` of expression `n` which is expected
-  ## to be `nnkDotExpr`.
+  ## to be `nnkDotExpr`, a proc or a type.
+  ##
+  ## See also `hasCustomPragma`
   ##
   ## .. code-block:: nim
   ##   template serializationKey(key: string) {.pragma.}
   ##   type
-  ##     MyObj = object
+  ##     MyObj {.serializationKey: "mo".} = object
   ##       myField {.serializationKey: "mf".}: int
   ##   var o: MyObj
   ##   assert(o.myField.getCustomPragmaVal(serializationKey) == "mf")
+  ##   assert(o.getCustomPragmaVal(serializationKey) == "mo")
+  ##   assert(MyObj.getCustomPragmaVal(serializationKey) == "mo")
   let pragmaNode = customPragmaNode(n)
   for p in pragmaNode:
     if p.kind in nnkPragmaCallKinds and p.len > 0 and p[0].kind == nnkSym and p[0] == cp:
       return p[1]
-  return newEmptyNode()
+
+  error(n.repr & " doesn't have a pragma named " & cp.repr()) # returning an empty node results in most cases in a cryptic error,
 
 
 when not defined(booting):
diff --git a/tests/pragmas/tcustom_pragma.nim b/tests/pragmas/tcustom_pragma.nim
index 415ae6a32..28a8713ce 100644
--- a/tests/pragmas/tcustom_pragma.nim
+++ b/tests/pragmas/tcustom_pragma.nim
@@ -1,12 +1,12 @@
 import macros
- 
+
 block:
   template myAttr() {.pragma.}
 
   proc myProc():int {.myAttr.} = 2
-  const myAttrIdx = myProc.hasCustomPragma(myAttr)
-  static: 
-    assert(myAttrIdx)
+  const hasMyAttr = myProc.hasCustomPragma(myAttr)
+  static:
+    assert(hasMyAttr)
 
 block:
   template myAttr(a: string) {.pragma.}
@@ -14,14 +14,14 @@ block:
   type MyObj = object
     myField1, myField2 {.myAttr: "hi".}: int
   var o: MyObj
-  static: 
+  static:
     assert o.myField2.hasCustomPragma(myAttr)
     assert(not o.myField1.hasCustomPragma(myAttr))
 
-import custom_pragma 
+import custom_pragma
 block: # A bit more advanced case
-  type 
-    Subfield = object
+  type
+    Subfield {.defaultValue: "catman".} = object
       c {.serializationKey: "cc".}: float
 
     MySerializable = object
@@ -29,10 +29,9 @@ block: # A bit more advanced case
       b {.custom_pragma.defaultValue"hello".} : int
       field: Subfield
       d {.alternativeKey("df", 5).}: float
-      e {.alternativeKey(V = 5).}: seq[bool] 
-
+      e {.alternativeKey(V = 5).}: seq[bool]
 
-  proc myproc(x: int, s: string) {.alternativeKey(V = 5), serializationKey"myprocSS".} = 
+  proc myproc(x: int, s: string) {.alternativeKey(V = 5), serializationKey"myprocSS".} =
     echo x, s
 
 
@@ -51,3 +50,83 @@ block: # A bit more advanced case
   static: assert(procSerKey == "myprocSS")
 
   static: assert(hasCustomPragma(myproc, alternativeKey))
+
+  # pragma on an object
+  static:
+    assert Subfield.hasCustomPragma(defaultValue)
+    assert(Subfield.getCustomPragmaVal(defaultValue) == "catman")
+
+    assert hasCustomPragma(type(s.field), defaultValue)
+
+block: # ref types
+  type
+    Node = object of RootObj
+      left {.serializationKey:"l".}, right {.serializationKey:"r".}: NodeRef
+    NodeRef = ref Node
+    NodePtr = ptr Node
+
+    SpecialNodeRef = ref object of NodeRef
+      data {.defaultValue"none".}: string
+
+    MyFile {.defaultValue: "closed".} = ref object
+      path {.defaultValue: "invalid".}: string
+
+  var s = NodeRef()
+
+  const
+    leftSerKey = getCustomPragmaVal(s.left, serializationKey)
+    rightSerKey = getCustomPragmaVal(s.right, serializationKey)
+  static:
+    assert leftSerKey == "l"
+    assert rightSerKey == "r"
+
+  var specS = SpecialNodeRef()
+
+  const
+    dataDefVal = hasCustomPragma(specS.data, defaultValue)
+    specLeftSerKey = hasCustomPragma(specS.left, serializationKey)
+  static:
+    assert dataDefVal == true
+    assert specLeftSerKey == true
+
+  var ptrS = NodePtr(nil)
+  const
+    ptrRightSerKey = getCustomPragmaVal(s.right, serializationKey)
+  static:
+    assert ptrRightSerKey == "r"
+
+  var f = MyFile()
+  const
+    fileDefVal = f.getCustomPragmaVal(defaultValue)
+    filePathDefVal = f.path.getCustomPragmaVal(defaultValue)
+  static:
+    assert fileDefVal == "closed"
+    assert filePathDefVal == "invalid"
+
+block:
+  type
+    VariantKind = enum
+      variInt,
+      variFloat
+      variString
+      variNestedCase
+    Variant = object
+      case kind: VariantKind
+      of variInt: integer {.serializationKey: "int".}: BiggestInt
+      of variFloat: floatp: BiggestFloat
+      of variString: str {.serializationKey: "string".}: string
+      of variNestedCase:
+        case nestedKind: VariantKind
+        of variInt..variNestedCase: nestedItem {.defaultValue: "Nimmers of the world, unite!".}: int
+
+  let vari = Variant(kind: variInt)
+
+  const
+    hasIntSerKey = vari.integer.hasCustomPragma(serializationKey)
+    strSerKey = vari.str.getCustomPragmaVal(serializationKey)
+    nestedItemDefVal = vari.nestedItem.getCustomPragmaVal(defaultValue)
+
+  static:
+    assert hasIntSerKey
+    assert strSerKey == "string"
+    assert nestedItemDefVal == "Nimmers of the world, unite!"
\ No newline at end of file