summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--lib/core/macros.nim78
-rw-r--r--tests/pragmas/tcustom_pragma.nim10
2 files changed, 50 insertions, 38 deletions
diff --git a/lib/core/macros.nim b/lib/core/macros.nim
index fc68781da..79c3ba28b 100644
--- a/lib/core/macros.nim
+++ b/lib/core/macros.nim
@@ -1536,6 +1536,14 @@ proc getPragmaNodeFromTypeSym(sym: NimNode): NimNode =
     if pragmaExpr.kind == nnkPragmaExpr:
       result = pragmaExpr[1]
 
+proc getPragmaNodeFromType(typ: NimNode): NimNode =
+  case typ.kind
+  of nnkSym:
+    result = getPragmaNodeFromTypeSym(typ)
+  of nnkProcTy:
+    result = typ[1]
+  else: error("illegal typ kind for argument: " & $typ.kind, typ)
+
 proc getPragmaNodeFromVarLetSym(sym: NimNode): NimNode =
   sym.expectKind nnkSym
   if sym.symKind notin {nskVar, nskLet}: error("expected var/let sym", sym)
@@ -1546,73 +1554,64 @@ proc getPragmaNodeFromVarLetSym(sym: NimNode): NimNode =
   if pragmaExpr.kind == nnkPragmaExpr:
     result = pragmaExpr[1]
 
-proc getPragmaByName(pragmaExpr: NimNode, name: string): NimNode =
+proc getPragmasByName(pragmaExpr: NimNode, name: string): seq[NimNode] =
   if pragmaExpr.kind == nnkPragma:
     for it in pragmaExpr:
       if it.kind in nnkPragmaCallKinds:
         if eqIdent(it[0], name):
-          return it
+          result.add it
       elif it.kind == nnkSym:
         if eqIdent(it, name):
-          return it
+          result.add it
 
-proc getCustomPragmaNode(sym: NimNode, name: string): NimNode =
+proc getCustomPragmaNodes(sym: NimNode, name: string): seq[NimNode] =
   sym.expectKind nnkSym
   case sym.symKind
   of nskField:
-    result = getPragmaNodeFromObjFieldSym(sym).getPragmaByName(name)
+    result = getPragmaNodeFromObjFieldSym(sym).getPragmasByName(name)
   of nskProc:
-    result = getPragmaNodeFromProcSym(sym).getPragmaByName(name)
+    result = getPragmaNodeFromProcSym(sym).getPragmasByName(name)
   of nskType:
-    result = getPragmaNodeFromTypeSym(sym).getPragmaByName(name)
+    result = getPragmaNodeFromTypeSym(sym).getPragmasByName(name)
   of nskParam:
     # When a typedesc parameter is passed to the macro, it will be of nskParam.
     let typeInst = getTypeInst(sym)
     if typeInst.kind == nnkBracketExpr and eqIdent(typeInst[0], "typeDesc"):
-      result = getPragmaNodeFromTypeSym(typeInst[1]).getPragmaByName(name)
+      result = getPragmaNodeFromTypeSym(typeInst[1]).getPragmasByName(name)
     else:
       error("illegal sym kind for argument: " & $sym.symKind, sym)
   of nskVar, nskLet:
-    # I think it is a bad idea to fall back to the typeSym. The API
-    # explicity requests a var/let symbol, not a type symbol.
-    result = getPragmaNodeFromVarLetSym(sym).getPragmaByName(name) or
-             getPragmaNodeFromTypeSym(sym.getTypeInst).getPragmaByName(name)
+    # This checks the type of the sym too, this is consistent with how
+    # field expressions are handled too. If this is changed, make sure to
+    # change it for fields expressions too.
+    result = getPragmaNodeFromType(sym.getTypeInst).getPragmasByName(name)
+    result.add getPragmaNodeFromVarLetSym(sym).getPragmasByName(name)
   else:
     error("illegal sym kind for argument: " & $sym.symKind, sym)
 
 since (1, 5):
-  export getCustomPragmaNode
+  export getCustomPragmaNodes
 
 proc hasCustomPragma*(n: NimNode, name: string): bool =
   n.expectKind nnkSym
-  let pragmaNode = getCustomPragmaNode(n, name)
-  result = pragmaNode != nil
+  result = getCustomPragmaNodes(n, name).len > 0
 
-proc getCustomPragmaNodeSmart(n: NimNode, name: string): NimNode =
+proc getCustomPragmaNodesSmart(n: NimNode, name: string): seq[NimNode] =
   case n.kind
   of nnkDotExpr:
-    result = getCustomPragmaNode(n[1], name)
+    result = getCustomPragmaNodes(n[1], name)
   of nnkCheckedFieldExpr:
     expectKind n[0], nnkDotExpr
-    result = getCustomPragmaNode(n[0][1], name)
+    result = getCustomPragmaNodes(n[0][1], name)
   of nnkSym:
-    result = getCustomPragmaNode(n, name)
+    result = getCustomPragmaNodes(n, name)
   of nnkTypeOfExpr:
-    var typeSym = n.getTypeInst
-    while typeSym.kind == nnkBracketExpr and typeSym[0].eqIdent "typeDesc":
-      typeSym = typeSym[1]
-    case typeSym.kind:
-    of nnkSym:
-      result = getCustomPragmaNode(typeSym, name)
-    of nnkProcTy:
-      # It is a bad idea to support this. The annotation can't be part
-      # of a symbol.
-      let pragmaExpr = typeSym[1]
-      result = getPragmaByName(pragmaExpr, name)
-    else:
-      typeSym.expectKind nnkSym
+    var typ = n.getTypeInst
+    while typ.kind == nnkBracketExpr and typ[0].eqIdent "typeDesc":
+      typ = typ[1]
+    result = getPragmaNodeFromType(typ).getPragmasByName(name)
   of nnkBracketExpr:
-    result = nil #false
+    discard
   else:
     n.expectKind({nnkDotExpr, nnkCheckedFieldExpr, nnkSym, nnkTypeOfExpr})
 
@@ -1633,7 +1632,7 @@ macro hasCustomPragma*(n: typed, cp: typed{nkSym}): bool =
   ##   var o: MyObj
   ##   assert(o.myField.hasCustomPragma(myAttr))
   ##   assert(myProc.hasCustomPragma(myAttr))
-  result = newLit(getCustomPragmaNodeSmart(n, $cp) != nil)
+  result = newLit(getCustomPragmaNodesSmart(n, $cp).len > 0)
 
 iterator iterOverFormalArgs(f: NimNode): tuple[name, typ, val: NimNode] =
   f.expectKind nnkFormalParams
@@ -1644,7 +1643,7 @@ iterator iterOverFormalArgs(f: NimNode): tuple[name, typ, val: NimNode] =
     for j in 0..<f[i].len-2:
       yield (f[i][j], typ, val)
 
-macro getCustomPragmaVal*(n: typed, cp: typed{nkSym}): untyped =
+macro getCustomPragmaVal*(n: typed, cp: typed): untyped =
   ## Expands to value of custom pragma `cp` of expression `n` which is expected
   ## to be `nnkDotExpr`, a proc or a type.
   ##
@@ -1659,12 +1658,15 @@ macro getCustomPragmaVal*(n: typed, cp: typed{nkSym}): untyped =
   ##   assert(o.myField.getCustomPragmaVal(serializationKey) == "mf")
   ##   assert(o.getCustomPragmaVal(serializationKey) == "mo")
   ##   assert(MyObj.getCustomPragmaVal(serializationKey) == "mo")
-  n.expectKind({nnkDotExpr, nnkCheckedFieldExpr, nnkSym, nnkTypeOfExpr})
-  let pragmaNode = getCustomPragmaNodeSmart(n, $cp)
+  n.expectKind {nnkDotExpr, nnkCheckedFieldExpr, nnkSym, nnkTypeOfExpr}
+  cp.expectKind {nnkSym, nnkOpenSymChoice, nnkClosedSymChoice}
+  let pragmaNodes = getCustomPragmaNodesSmart(n, $cp)
+  if pragmaNodes.len == 0:
+    error(n.repr & " doesn't have any custom pragmas")
+  let pragmaNode = pragmaNodes[^1]
 
   case pragmaNode.kind
   of nnkPragmaCallKinds:
-    assert pragmaNode[0] == cp
     if pragmaNode.len == 2:
       result = pragmaNode[1]
     else:
diff --git a/tests/pragmas/tcustom_pragma.nim b/tests/pragmas/tcustom_pragma.nim
index 30b08b44e..3ef5bdcb6 100644
--- a/tests/pragmas/tcustom_pragma.nim
+++ b/tests/pragmas/tcustom_pragma.nim
@@ -165,11 +165,14 @@ type
 let a {.defaultValue(4).}: proc(x: int)  = nil
 var b: MyAnnotatedProcType = nil
 var c: proc(x: int): void {.defaultValue(5).}  = nil
+var d {.defaultValue(44).}: MyAnnotatedProcType = nil
 static:
   doAssert hasCustomPragma(a, defaultValue)
   doAssert hasCustomPragma(MyAnnotatedProcType, defaultValue)
   doAssert hasCustomPragma(b, defaultValue)
   doAssert hasCustomPragma(typeof(c), defaultValue)
+  doAssert getCustomPragmaVal(d, defaultValue) == 44
+  doAssert getCustomPragmaVal(typeof(d), defaultValue) == 4
 
 # bug #8371
 template thingy {.pragma.}
@@ -405,3 +408,10 @@ template hehe(key, val: string, haha) {.pragma.}
 type A {.haha, hoho, haha, hehe("hi", "hu", "he").} = int
 
 assert A.getCustomPragmaVal(hehe) == (key: "hi", val: "hu", haha: "he")
+
+template hehe(key, val: int) {.pragma.}
+
+var bb {.haha, hoho, hehe(1, 2), haha, hehe("hi", "hu", "he").} = 3
+
+# left-to-right priority/override order for getCustomPragmaVal
+assert bb.getCustomPragmaVal(hehe) == (key: "hi", val: "hu", haha: "he")