summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--changelog.md2
-rw-r--r--compiler/ast.nim59
-rw-r--r--compiler/semstmts.nim122
-rw-r--r--compiler/semtypes.nim48
-rw-r--r--doc/manual.rst118
-rw-r--r--lib/pure/asyncmacro.nim6
-rw-r--r--tests/pragmas/tcustom_pragma.nim88
7 files changed, 303 insertions, 140 deletions
diff --git a/changelog.md b/changelog.md
index 42f469567..78c707bf4 100644
--- a/changelog.md
+++ b/changelog.md
@@ -165,6 +165,8 @@ echo f
 - `var a {.foo.}: MyType = expr` now lowers to `foo(a, MyType, expr)` for non builtin pragmas,
   enabling things like lvalue references, see `pragmas.byaddr`
 
+- `macro pragmas` can now be used in type sections.
+
 ## Language changes
 
 - Unsigned integer operators have been fixed to allow promotion of the first operand.
diff --git a/compiler/ast.nim b/compiler/ast.nim
index 33edcc953..8b000ecb9 100644
--- a/compiler/ast.nim
+++ b/compiler/ast.nim
@@ -1597,48 +1597,43 @@ proc transitionToLet*(s: PSym) =
   s.bitsize = obj.bitsize
   s.alignment = obj.alignment
 
-proc shallowCopy*(src: PNode): PNode =
-  # does not copy its sons, but provides space for them:
-  if src == nil: return nil
-  result = newNode(src.kind)
-  result.info = src.info
-  result.typ = src.typ
-  result.flags = src.flags * PersistentNodeFlags
-  result.comment = src.comment
+template copyNodeImpl(dst, src, processSonsStmt) =
+  if src == nil: return
+  dst = newNode(src.kind)
+  dst.info = src.info
+  dst.typ = src.typ
+  dst.flags = src.flags * PersistentNodeFlags
+  dst.comment = src.comment
   when defined(useNodeIds):
-    if result.id == nodeIdToDebug:
+    if dst.id == nodeIdToDebug:
       echo "COMES FROM ", src.id
   case src.kind
-  of nkCharLit..nkUInt64Lit: result.intVal = src.intVal
-  of nkFloatLiterals: result.floatVal = src.floatVal
-  of nkSym: result.sym = src.sym
-  of nkIdent: result.ident = src.ident
-  of nkStrLit..nkTripleStrLit: result.strVal = src.strVal
-  else: newSeq(result.sons, src.len)
+  of nkCharLit..nkUInt64Lit: dst.intVal = src.intVal
+  of nkFloatLiterals: dst.floatVal = src.floatVal
+  of nkSym: dst.sym = src.sym
+  of nkIdent: dst.ident = src.ident
+  of nkStrLit..nkTripleStrLit: dst.strVal = src.strVal
+  else: processSonsStmt
+
+proc shallowCopy*(src: PNode): PNode =
+  # does not copy its sons, but provides space for them:
+  copyNodeImpl(result, src):
+    newSeq(result.sons, src.len)
 
 proc copyTree*(src: PNode): PNode =
   # copy a whole syntax tree; performs deep copying
-  if src == nil:
-    return nil
-  result = newNode(src.kind)
-  result.info = src.info
-  result.typ = src.typ
-  result.flags = src.flags * PersistentNodeFlags
-  result.comment = src.comment
-  when defined(useNodeIds):
-    if result.id == nodeIdToDebug:
-      echo "COMES FROM ", src.id
-  case src.kind
-  of nkCharLit..nkUInt64Lit: result.intVal = src.intVal
-  of nkFloatLiterals: result.floatVal = src.floatVal
-  of nkSym: result.sym = src.sym
-  of nkIdent: result.ident = src.ident
-  of nkStrLit..nkTripleStrLit: result.strVal = src.strVal
-  else:
+  copyNodeImpl(result, src):
     newSeq(result.sons, src.len)
     for i in 0..<src.len:
       result[i] = copyTree(src[i])
 
+proc copyTreeWithoutNode*(src, skippedNode: PNode): PNode =
+  copyNodeImpl(result, src):
+    result.sons = newSeqOfCap[PNode](src.len)
+    for n in src.sons:
+      if n != skippedNode:
+        result.sons.add copyTreeWithoutNode(n, skippedNode)
+
 proc hasSonWith*(n: PNode, kind: TNodeKind): bool =
   for i in 0..<n.len:
     if n[i].kind == kind:
diff --git a/compiler/semstmts.nim b/compiler/semstmts.nim
index d0cab48a5..d99a02183 100644
--- a/compiler/semstmts.nim
+++ b/compiler/semstmts.nim
@@ -1051,6 +1051,68 @@ proc typeSectionTypeName(c: PContext; n: PNode): PNode =
     result = n
   if result.kind != nkSym: illFormedAst(n, c.config)
 
+proc typeDefLeftSidePass(c: PContext, typeSection: PNode, i: int) =
+  let typeDef= typeSection[i]
+  checkSonsLen(typeDef, 3, c.config)
+  var name = typeDef[0]
+  var s: PSym
+  if name.kind == nkDotExpr and typeDef[2].kind == nkObjectTy:
+    let pkgName = considerQuotedIdent(c, name[0])
+    let typName = considerQuotedIdent(c, name[1])
+    let pkg = c.graph.packageSyms.strTableGet(pkgName)
+    if pkg.isNil or pkg.kind != skPackage:
+      localError(c.config, name.info, "unknown package name: " & pkgName.s)
+    else:
+      let typsym = pkg.tab.strTableGet(typName)
+      if typsym.isNil:
+        s = semIdentDef(c, name[1], skType)
+        onDef(name[1].info, s)
+        s.typ = newTypeS(tyObject, c)
+        s.typ.sym = s
+        s.flags.incl sfForward
+        pkg.tab.strTableAdd s
+        addInterfaceDecl(c, s)
+      elif typsym.kind == skType and sfForward in typsym.flags:
+        s = typsym
+        addInterfaceDecl(c, s)
+      else:
+        localError(c.config, name.info, typsym.name.s & " is not a type that can be forwarded")
+        s = typsym
+  else:
+    s = semIdentDef(c, name, skType)
+    onDef(name.info, s)
+    s.typ = newTypeS(tyForward, c)
+    s.typ.sym = s             # process pragmas:
+    if name.kind == nkPragmaExpr:
+      let rewritten = applyTypeSectionPragmas(c, name[1], typeDef)
+      if rewritten != nil:
+        typeSection[i] = rewritten
+        typeDefLeftSidePass(c, typeSection, i)
+        return
+      pragma(c, s, name[1], typePragmas)
+    if sfForward in s.flags:
+      # check if the symbol already exists:
+      let pkg = c.module.owner
+      if not isTopLevel(c) or pkg.isNil:
+        localError(c.config, name.info, "only top level types in a package can be 'package'")
+      else:
+        let typsym = pkg.tab.strTableGet(s.name)
+        if typsym != nil:
+          if sfForward notin typsym.flags or sfNoForward notin typsym.flags:
+            typeCompleted(typsym)
+            typsym.info = s.info
+          else:
+            localError(c.config, name.info, "cannot complete type '" & s.name.s & "' twice; " &
+                    "previous type completion was here: " & c.config$typsym.info)
+          s = typsym
+    # add it here, so that recursive types are possible:
+    if sfGenSym notin s.flags: addInterfaceDecl(c, s)
+    elif s.owner == nil: s.owner = getCurrOwner(c)
+
+  if name.kind == nkPragmaExpr:
+    typeDef[0][0] = newSymNode(s)
+  else:
+    typeDef[0] = newSymNode(s)
 
 proc typeSectionLeftSidePass(c: PContext, n: PNode) =
   # process the symbols on the left side for the whole type section, before
@@ -1064,61 +1126,7 @@ proc typeSectionLeftSidePass(c: PContext, n: PNode) =
         dec c.inTypeContext
     if a.kind == nkCommentStmt: continue
     if a.kind != nkTypeDef: illFormedAst(a, c.config)
-    checkSonsLen(a, 3, c.config)
-    let name = a[0]
-    var s: PSym
-    if name.kind == nkDotExpr and a[2].kind == nkObjectTy:
-      let pkgName = considerQuotedIdent(c, name[0])
-      let typName = considerQuotedIdent(c, name[1])
-      let pkg = c.graph.packageSyms.strTableGet(pkgName)
-      if pkg.isNil or pkg.kind != skPackage:
-        localError(c.config, name.info, "unknown package name: " & pkgName.s)
-      else:
-        let typsym = pkg.tab.strTableGet(typName)
-        if typsym.isNil:
-          s = semIdentDef(c, name[1], skType)
-          onDef(name[1].info, s)
-          s.typ = newTypeS(tyObject, c)
-          s.typ.sym = s
-          s.flags.incl sfForward
-          pkg.tab.strTableAdd s
-          addInterfaceDecl(c, s)
-        elif typsym.kind == skType and sfForward in typsym.flags:
-          s = typsym
-          addInterfaceDecl(c, s)
-        else:
-          localError(c.config, name.info, typsym.name.s & " is not a type that can be forwarded")
-          s = typsym
-    else:
-      s = semIdentDef(c, name, skType)
-      onDef(name.info, s)
-      s.typ = newTypeS(tyForward, c)
-      s.typ.sym = s             # process pragmas:
-      if name.kind == nkPragmaExpr:
-        pragma(c, s, name[1], typePragmas)
-      if sfForward in s.flags:
-        # check if the symbol already exists:
-        let pkg = c.module.owner
-        if not isTopLevel(c) or pkg.isNil:
-          localError(c.config, name.info, "only top level types in a package can be 'package'")
-        else:
-          let typsym = pkg.tab.strTableGet(s.name)
-          if typsym != nil:
-            if sfForward notin typsym.flags or sfNoForward notin typsym.flags:
-              typeCompleted(typsym)
-              typsym.info = s.info
-            else:
-              localError(c.config, name.info, "cannot complete type '" & s.name.s & "' twice; " &
-                      "previous type completion was here: " & c.config$typsym.info)
-            s = typsym
-      # add it here, so that recursive types are possible:
-      if sfGenSym notin s.flags: addInterfaceDecl(c, s)
-      elif s.owner == nil: s.owner = getCurrOwner(c)
-
-    if name.kind == nkPragmaExpr:
-      a[0][0] = newSymNode(s)
-    else:
-      a[0] = newSymNode(s)
+    typeDefLeftSidePass(c, n, i)
 
 proc checkCovariantParamsUsages(c: PContext; genericType: PType) =
   var body = genericType[^1]
@@ -1453,8 +1461,8 @@ proc semProcAnnotation(c: PContext, prc: PNode;
   var n = prc[pragmasPos]
   if n == nil or n.kind == nkEmpty: return
   for i in 0..<n.len:
-    var it = n[i]
-    var key = if it.kind in nkPragmaCallKinds and it.len >= 1: it[0] else: it
+    let it = n[i]
+    let key = if it.kind in nkPragmaCallKinds and it.len >= 1: it[0] else: it
 
     if whichPragma(it) != wInvalid:
       # Not a custom pragma
diff --git a/compiler/semtypes.nim b/compiler/semtypes.nim
index 6e423181c..7cddc7520 100644
--- a/compiler/semtypes.nim
+++ b/compiler/semtypes.nim
@@ -1567,9 +1567,44 @@ proc semTypeClass(c: PContext, n: PNode, prev: PType): PType =
   result.n[3] = semConceptBody(c, n[3])
   closeScope(c)
 
+proc applyTypeSectionPragmas(c: PContext; pragmas, operand: PNode): PNode =
+  for p in pragmas:
+    let key = if p.kind in nkPragmaCallKinds and p.len >= 1: p[0] else: p
+
+    if p.kind == nkEmpty or whichPragma(p) != wInvalid:
+      discard "builtin pragma"
+    elif strTableGet(c.userPragmas, considerQuotedIdent(c, key)) != nil:
+      discard "User-defined pragma"
+    else:
+      # we transform ``(arg1, arg2: T) {.m, rest.}`` into ``m((arg1, arg2: T) {.rest.})`` and
+      # let the semantic checker deal with it:
+      var x = newNodeI(nkCall, key.info)
+      x.add(key)
+      if p.kind in nkPragmaCallKinds and p.len > 1:
+        # pass pragma arguments to the macro too:
+        for i in 1 ..< p.len:
+          x.add(p[i])
+      # Also pass the node the pragma has been applied to
+      x.add(operand.copyTreeWithoutNode(p))
+      # recursion assures that this works for multiple macro annotations too:
+      var r = semOverloadedCall(c, x, x, {skMacro, skTemplate}, {efNoUndeclared})
+      if r != nil:
+        doAssert r[0].kind == nkSym
+        let m = r[0].sym
+        case m.kind
+        of skMacro: return semMacroExpr(c, r, r, m, {efNoSemCheck})
+        of skTemplate: return semTemplateExpr(c, r, m, {efNoSemCheck})
+        else: doAssert(false, "cannot happen")
+
 proc semProcTypeWithScope(c: PContext, n: PNode,
-                        prev: PType, kind: TSymKind): PType =
+                          prev: PType, kind: TSymKind): PType =
   checkSonsLen(n, 2, c.config)
+
+  if n[1].kind != nkEmpty and n[1].len > 0:
+    let macroEval = applyTypeSectionPragmas(c, n[1], n)
+    if macroEval != nil:
+      return semTypeNode(c, macroEval, prev)
+
   openScope(c)
   result = semProcTypeNode(c, n[0], nil, prev, kind, isType=true)
   # start with 'ccClosure', but of course pragmas can overwrite this:
@@ -1845,11 +1880,12 @@ proc semTypeNode(c: PContext, n: PNode, prev: PType): PType =
       result.addSonSkipIntLit(child)
     else:
       result = semProcTypeWithScope(c, n, prev, skIterator)
-      result.flags.incl(tfIterator)
-      if n.lastSon.kind == nkPragma and hasPragma(n.lastSon, wInline):
-        result.callConv = ccInline
-      else:
-        result.callConv = ccClosure
+      if result.kind == tyProc:
+        result.flags.incl(tfIterator)
+        if n.lastSon.kind == nkPragma and hasPragma(n.lastSon, wInline):
+          result.callConv = ccInline
+        else:
+          result.callConv = ccClosure
   of nkProcTy:
     if n.len == 0:
       result = newConstraint(c, tyProc)
diff --git a/doc/manual.rst b/doc/manual.rst
index bffb9e60c..4c1d5c659 100644
--- a/doc/manual.rst
+++ b/doc/manual.rst
@@ -5353,26 +5353,6 @@ powerful programming construct that still suffices. So the "check list" is:
 (4) Else: Use a macro.
 
 
-Macros as pragmas
------------------
-
-Whole routines (procs, iterators etc.) can also be passed to a template or
-a macro via the pragma notation:
-
-.. code-block:: nim
-  template m(s: untyped) = discard
-
-  proc p() {.m.} = discard
-
-This is a simple syntactic transformation into:
-
-.. code-block:: nim
-  template m(s: untyped) = discard
-
-  m:
-    proc p() = discard
-
-
 For Loop Macro
 --------------
 
@@ -6408,29 +6388,6 @@ the created global variables within a module is not defined, but all of them
 will be initialized after any top-level variables in their originating module
 and before any variable in a module that imports it.
 
-pragma pragma
--------------
-
-The ``pragma`` pragma can be used to declare user defined pragmas. This is
-useful because Nim's templates and macros do not affect pragmas. User
-defined pragmas are in a different module-wide scope than all other symbols.
-They cannot be imported from a module.
-
-Example:
-
-.. code-block:: nim
-  when appType == "lib":
-    {.pragma: rtl, exportc, dynlib, cdecl.}
-  else:
-    {.pragma: rtl, importc, dynlib: "client.dll", cdecl.}
-
-  proc p*(a, b: int): int {.rtl.} =
-    result = a+b
-
-In the example a new pragma named ``rtl`` is introduced that either imports
-a symbol from a dynamic library or exports the symbol for dynamic library
-generation.
-
 Disabling certain messages
 --------------------------
 Nim generates some warnings and hints ("line too long") that may annoy the
@@ -7127,6 +7084,34 @@ used. To see if a value was provided, `defined(FooBar)` can be used.
 
 The syntax `-d:flag` is actually just a shortcut for `-d:flag=true`.
 
+User-defined pragmas
+====================
+
+
+pragma pragma
+-------------
+
+The ``pragma`` pragma can be used to declare user defined pragmas. This is
+useful because Nim's templates and macros do not affect pragmas. User
+defined pragmas are in a different module-wide scope than all other symbols.
+They cannot be imported from a module.
+
+Example:
+
+.. code-block:: nim
+  when appType == "lib":
+    {.pragma: rtl, exportc, dynlib, cdecl.}
+  else:
+    {.pragma: rtl, importc, dynlib: "client.dll", cdecl.}
+
+  proc p*(a, b: int): int {.rtl.} =
+    result = a+b
+
+In the example a new pragma named ``rtl`` is introduced that either imports
+a symbol from a dynamic library or exports the symbol for dynamic library
+generation.
+
+
 Custom annotations
 ------------------
 It is possible to define custom typed pragmas. Custom pragmas do not effect
@@ -7193,6 +7178,51 @@ More examples with custom pragmas:
     alpha {.editRange: [0.0..1.0], animatable.}: float32
 
 
+Macro pragmas
+-------------
+
+All macros and templates can also be used as pragmas. They can be attached
+to routines (procs, iterators, etc), type names or type expressions. The
+compiler will perform the following simple syntactic transformations:
+
+.. code-block:: nim
+  template command(name: string, def: untyped) = discard
+
+  proc p() {.command("print").} = discard
+
+This is translated to:
+
+.. code-block:: nim
+  command("print"):
+    proc p() = discard
+
+------
+
+.. code-block:: nim
+  type
+    AsyncEventHandler = proc (x: Event) {.async.}
+
+This is translated to:
+
+.. code-block:: nim
+  type
+    AsyncEventHandler = async(proc (x: Event))
+
+------
+
+.. code-block:: nim
+  type
+    MyObject {.schema: "schema.protobuf".} = object
+
+This is translated to a call to the ``schema`` macro with a `nnkTypeDef`
+AST node capturing both the left-hand side and right-hand side of the
+definition. The macro can return a potentially modified `nnkTypeDef` tree
+will replace the original row in the type section.
+
+When multiple macro pragmas are applied to the same definition, the
+compiler will apply them consequently from left to right. Each macro
+will receive as input the output of the previous one.
+
 
 
 Foreign function interface
@@ -7269,7 +7299,6 @@ In the example the external name of ``p`` is set to ``prefixp``. Only ``$1``
 is available and a literal dollar sign must be written as ``$$``.
 
 
-
 Bycopy pragma
 -------------
 
@@ -7313,6 +7342,7 @@ checked.
 **Future directions**: GC'ed memory should be allowed in unions and the GC
 should scan unions conservatively.
 
+
 Packed pragma
 -------------
 The ``packed`` pragma can be applied to any ``object`` type. It ensures
diff --git a/lib/pure/asyncmacro.nim b/lib/pure/asyncmacro.nim
index ce84491eb..c6f4ae04a 100644
--- a/lib/pure/asyncmacro.nim
+++ b/lib/pure/asyncmacro.nim
@@ -212,6 +212,12 @@ proc verifyReturnType(typeName: string) {.compileTime.} =
 proc asyncSingleProc(prc: NimNode): NimNode {.compileTime.} =
   ## This macro transforms a single procedure into a closure iterator.
   ## The ``async`` macro supports a stmtList holding multiple async procedures.
+  if prc.kind == nnkProcTy:
+    result = prc
+    if prc[0][0].kind == nnkEmpty:
+      result[0][0] = parseExpr("Future[void]")
+    return result
+
   if prc.kind notin {nnkProcDef, nnkLambda, nnkMethodDef, nnkDo}:
     error("Cannot transform this node kind into an async proc." &
           " proc/method definition or lambda node expected.")
diff --git a/tests/pragmas/tcustom_pragma.nim b/tests/pragmas/tcustom_pragma.nim
index 719af4d50..2d970aba1 100644
--- a/tests/pragmas/tcustom_pragma.nim
+++ b/tests/pragmas/tcustom_pragma.nim
@@ -1,6 +1,6 @@
 {.experimental: "notnil".}
 
-import macros
+import macros, asyncmacro, asyncfutures
 
 block:
   template myAttr() {.pragma.}
@@ -249,3 +249,89 @@ block:
   var e {.fooBar("foo", 123, 'u').}: int
   doAssert(hasCustomPragma(e, fooBar))
   doAssert(getCustomPragmaVal(e, fooBar).c == 123)
+
+block:
+  macro expectedAst(expectedRepr: static[string], input: untyped): untyped =
+    assert input.treeRepr & "\n" == expectedRepr
+    return input
+
+  const procTypeAst = """
+ProcTy
+  FormalParams
+    Empty
+    IdentDefs
+      Ident "x"
+      Ident "int"
+      Empty
+  Pragma
+    Ident "async"
+"""
+
+  type
+    Foo = proc (x: int) {.expectedAst(procTypeAst), async.}
+
+  static: assert Foo is proc(x: int): Future[void]
+
+  const asyncProcTypeAst = """
+ProcTy
+  FormalParams
+    BracketExpr
+      Ident "Future"
+      Ident "void"
+    IdentDefs
+      Ident "s"
+      Ident "string"
+      Empty
+  Pragma
+"""
+
+  type
+    Bar = proc (s: string) {.async, expectedAst(asyncProcTypeAst).}
+
+  static: assert Bar is proc(x: string): Future[void]
+
+  const typeAst = """
+TypeDef
+  PragmaExpr
+    Ident "Baz"
+    Pragma
+  Empty
+  ObjectTy
+    Empty
+    Empty
+    RecList
+      IdentDefs
+        Ident "x"
+        Ident "string"
+        Empty
+"""
+
+  type
+    Baz {.expectedAst(typeAst).} = object
+      x: string
+
+  static: assert Baz.x is string
+
+  const procAst = """
+ProcDef
+  Ident "bar"
+  Empty
+  Empty
+  FormalParams
+    Ident "string"
+    IdentDefs
+      Ident "s"
+      Ident "string"
+      Empty
+  Empty
+  Empty
+  StmtList
+    ReturnStmt
+      Ident "s"
+"""
+
+  proc bar(s: string): string {.expectedAst(procAst).} =
+    return s
+
+  static: assert bar("x") == "x"
+