summary refs log tree commit diff stats
path: root/lib/pure/asyncmacro.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/asyncmacro.nim')
-rw-r--r--lib/pure/asyncmacro.nim179
1 files changed, 100 insertions, 79 deletions
diff --git a/lib/pure/asyncmacro.nim b/lib/pure/asyncmacro.nim
index eecec4278..d4e72c28a 100644
--- a/lib/pure/asyncmacro.nim
+++ b/lib/pure/asyncmacro.nim
@@ -7,10 +7,14 @@
 #    distribution, for details about the copyright.
 #
 
-## `asyncdispatch` module depends on the `asyncmacro` module to work properly.
+## Implements the `async` and `multisync` macros for `asyncdispatch`.
 
-import macros, strutils, asyncfutures
+import std/[macros, strutils, asyncfutures]
 
+type
+  Context = ref object
+    inTry: int
+    hasRet: bool
 
 # TODO: Ref https://github.com/nim-lang/Nim/issues/5617
 # TODO: Add more line infos
@@ -21,9 +25,8 @@ proc newCallWithLineInfo(fromNode: NimNode; theProc: NimNode, args: varargs[NimN
 template createCb(retFutureSym, iteratorNameSym,
                   strName, identName, futureVarCompletions: untyped) =
   bind finished
-
   var nameIterVar = iteratorNameSym
-  proc identName {.closure.} =
+  proc identName {.closure, stackTrace: off.} =
     try:
       if not nameIterVar.finished:
         var next = nameIterVar()
@@ -35,14 +38,11 @@ template createCb(retFutureSym, iteratorNameSym,
 
         if next == nil:
           if not retFutureSym.finished:
-            let msg = "Async procedure ($1) yielded `nil`, are you await'ing a " &
-                    "`nil` Future?"
+            let msg = "Async procedure ($1) yielded `nil`, are you await'ing a `nil` Future?"
             raise newException(AssertionDefect, msg % strName)
         else:
           {.gcsafe.}:
-            {.push hint[ConvFromXtoItselfNotNeeded]: off.}
             next.addCallback cast[proc() {.closure, gcsafe.}](identName)
-            {.pop.}
     except:
       futureVarCompletions
       if retFutureSym.finished:
@@ -68,10 +68,7 @@ proc createFutureVarCompletions(futureVarIdents: seq[NimNode], fromNode: NimNode
       )
     )
 
-proc processBody(node, retFutureSym: NimNode,
-                 subTypeIsVoid: bool,
-                 futureVarIdents: seq[NimNode]): NimNode =
-  #echo(node.treeRepr)
+proc processBody(ctx: Context; node, needsCompletionSym, retFutureSym: NimNode, futureVarIdents: seq[NimNode]): NimNode =
   result = node
   case node.kind
   of nnkReturnStmt:
@@ -80,29 +77,53 @@ proc processBody(node, retFutureSym: NimNode,
     # As I've painfully found out, the order here really DOES matter.
     result.add createFutureVarCompletions(futureVarIdents, node)
 
+    ctx.hasRet = true
     if node[0].kind == nnkEmpty:
-      if not subTypeIsVoid:
-        result.add newCall(newIdentNode("complete"), retFutureSym,
-            newIdentNode("result"))
+      if ctx.inTry == 0:
+        result.add newCallWithLineInfo(node, newIdentNode("complete"), retFutureSym, newIdentNode("result"))
       else:
-        result.add newCall(newIdentNode("complete"), retFutureSym)
+        result.add newAssignment(needsCompletionSym, newLit(true))
     else:
-      let x = node[0].processBody(retFutureSym, subTypeIsVoid,
-                                  futureVarIdents)
+      let x = processBody(ctx, node[0], needsCompletionSym, retFutureSym, futureVarIdents)
       if x.kind == nnkYieldStmt: result.add x
+      elif ctx.inTry == 0:
+        result.add newCallWithLineInfo(node, newIdentNode("complete"), retFutureSym, x)
       else:
-        result.add newCall(newIdentNode("complete"), retFutureSym, x)
+        result.add newAssignment(newIdentNode("result"), x)
+        result.add newAssignment(needsCompletionSym, newLit(true))
 
     result.add newNimNode(nnkReturnStmt, node).add(newNilLit())
     return # Don't process the children of this return stmt
   of RoutineNodes-{nnkTemplateDef}:
     # skip all the nested procedure definitions
     return
-  else: discard
-
-  for i in 0 ..< result.len:
-    result[i] = processBody(result[i], retFutureSym, subTypeIsVoid,
-                            futureVarIdents)
+  of nnkTryStmt:
+    if result[^1].kind == nnkFinally:
+      inc ctx.inTry
+      result[0] = processBody(ctx, result[0], needsCompletionSym, retFutureSym, futureVarIdents)
+      dec ctx.inTry
+      for i in 1 ..< result.len:
+        result[i] = processBody(ctx, result[i], needsCompletionSym, retFutureSym, futureVarIdents)
+      if ctx.inTry == 0 and ctx.hasRet:
+        let finallyNode = copyNimNode(result[^1])
+        let stmtNode = newNimNode(nnkStmtList)
+        for child in result[^1]:
+          stmtNode.add child
+        stmtNode.add newIfStmt(
+          ( needsCompletionSym,
+            newCallWithLineInfo(node, newIdentNode("complete"), retFutureSym,
+            newIdentNode("result")
+            )
+          )
+        )
+        finallyNode.add stmtNode
+        result[^1] = finallyNode
+    else:
+      for i in 0 ..< result.len:
+        result[i] = processBody(ctx, result[i], needsCompletionSym, retFutureSym, futureVarIdents)
+  else:
+    for i in 0 ..< result.len:
+      result[i] = processBody(ctx, result[i], needsCompletionSym, retFutureSym, futureVarIdents)
 
   # echo result.repr
 
@@ -139,9 +160,21 @@ template await*(f: typed): untyped {.used.} =
     error "await expects Future[T], got " & $typeof(f)
 
 template await*[T](f: Future[T]): auto {.used.} =
-  var internalTmpFuture: FutureBase = f
-  yield internalTmpFuture
-  (cast[typeof(f)](internalTmpFuture)).read()
+  when not defined(nimHasTemplateRedefinitionPragma):
+    {.pragma: redefine.}
+  template yieldFuture {.redefine.} = yield FutureBase()
+
+  when compiles(yieldFuture):
+    var internalTmpFuture: FutureBase = f
+    yield internalTmpFuture
+    (cast[typeof(f)](internalTmpFuture)).read()
+  else:
+    macro errorAsync(futureError: Future[T]) =
+      error(
+        "Can only 'await' inside a proc marked as 'async'. Use " &
+        "'waitFor' when calling an 'async' proc in a non-async scope instead",
+        futureError)
+    errorAsync(f)
 
 proc asyncSingleProc(prc: NimNode): NimNode =
   ## This macro transforms a single procedure into a closure iterator.
@@ -149,9 +182,13 @@ proc asyncSingleProc(prc: NimNode): NimNode =
   if prc.kind == nnkProcTy:
     result = prc
     if prc[0][0].kind == nnkEmpty:
-      result[0][0] = parseExpr("Future[void]")
+      result[0][0] = quote do: Future[void]
     return result
 
+  if prc.kind in RoutineNodes and prc.name.kind != nnkEmpty:
+    # Only non anonymous functions need/can have stack trace disabled
+    prc.addPragma(nnkExprColonExpr.newTree(ident"stackTrace", ident"off"))
+
   if prc.kind notin {nnkProcDef, nnkLambda, nnkMethodDef, nnkDo}:
     error("Cannot transform this node kind into an async proc." &
           " proc/method definition or lambda node expected.", prc)
@@ -183,11 +220,7 @@ proc asyncSingleProc(prc: NimNode): NimNode =
   else:
     verifyReturnType(repr(returnType), returnType)
 
-  let subtypeIsVoid = returnType.kind == nnkEmpty or
-        (baseType.kind == nnkIdent and returnType[1].eqIdent("void"))
-
   let futureVarIdents = getFutureVarIdents(prc.params)
-
   var outerProcBody = newNimNode(nnkStmtList, prc.body)
 
   # Extract the documentation comment from the original procedure declaration.
@@ -214,43 +247,42 @@ proc asyncSingleProc(prc: NimNode): NimNode =
   # ->   {.pop.}
   # ->   <proc_body>
   # ->   complete(retFuture, result)
-  var iteratorNameSym = genSym(nskIterator, $prcName & "Iter")
-  var procBody = prc.body.processBody(retFutureSym, subtypeIsVoid,
-                                    futureVarIdents)
+  var iteratorNameSym = genSym(nskIterator, $prcName & " (Async)")
+  var needsCompletionSym = genSym(nskVar, "needsCompletion")
+  var ctx = Context()
+  var procBody = processBody(ctx, prc.body, needsCompletionSym, retFutureSym, futureVarIdents)
   # don't do anything with forward bodies (empty)
   if procBody.kind != nnkEmpty:
     # fix #13899, defer should not escape its original scope
-    procBody = newStmtList(newTree(nnkBlockStmt, newEmptyNode(), procBody))
-
+    let blockStmt = newStmtList(newTree(nnkBlockStmt, newEmptyNode(), procBody))
+    procBody = newStmtList()
+    let resultIdent = ident"result"
+    procBody.add quote do:
+      # Check whether there is an implicit return
+      when typeof(`blockStmt`) is void:
+        `blockStmt`
+      else:
+        `resultIdent` = `blockStmt`
     procBody.add(createFutureVarCompletions(futureVarIdents, nil))
+    procBody.insert(0): quote do:
+      {.push warning[resultshadowed]: off.}
+      when `subRetType` isnot void:
+        var `resultIdent`: `subRetType`
+      else:
+        var `resultIdent`: Future[void]
+      {.pop.}
 
-    if not subtypeIsVoid:
-      procBody.insert(0, newNimNode(nnkPragma).add(newIdentNode("push"),
-        newNimNode(nnkExprColonExpr).add(newNimNode(nnkBracketExpr).add(
-          newIdentNode("warning"), newIdentNode("resultshadowed")),
-        newIdentNode("off")))) # -> {.push warning[resultshadowed]: off.}
-
-      procBody.insert(1, newNimNode(nnkVarSection, prc.body).add(
-        newIdentDefs(newIdentNode("result"), baseType))) # -> var result: T
-
-      procBody.insert(2, newNimNode(nnkPragma).add(
-        newIdentNode("pop"))) # -> {.pop.})
-
-      procBody.add(
-        newCall(newIdentNode("complete"),
-          retFutureSym, newIdentNode("result"))) # -> complete(retFuture, result)
-    else:
-      # -> complete(retFuture)
-      procBody.add(newCall(newIdentNode("complete"), retFutureSym))
+      var `needsCompletionSym` = false
+    procBody.add quote do:
+      complete(`retFutureSym`, `resultIdent`)
 
-    var closureIterator = newProc(iteratorNameSym, [parseExpr("owned(FutureBase)")],
+    var closureIterator = newProc(iteratorNameSym, [quote do: owned(FutureBase)],
                                   procBody, nnkIteratorDef)
     closureIterator.pragma = newNimNode(nnkPragma, lineInfoFrom = prc.body)
     closureIterator.addPragma(newIdentNode("closure"))
 
     # If proc has an explicit gcsafe pragma, we add it to iterator as well.
-    if prc.pragma.findChild(it.kind in {nnkSym, nnkIdent} and $it ==
-        "gcsafe") != nil:
+    if prc.pragma.findChild(it.kind in {nnkSym, nnkIdent} and $it == "gcsafe") != nil:
       closureIterator.addPragma(newIdentNode("gcsafe"))
     outerProcBody.add(closureIterator)
 
@@ -259,21 +291,20 @@ proc asyncSingleProc(prc: NimNode): NimNode =
     # friendlier stack traces:
     var cbName = genSym(nskProc, prcName & NimAsyncContinueSuffix)
     var procCb = getAst createCb(retFutureSym, iteratorNameSym,
-                         newStrLitNode(prcName),
-                         cbName,
-                         createFutureVarCompletions(futureVarIdents, nil))
+                          newStrLitNode(prcName),
+                          cbName,
+                          createFutureVarCompletions(futureVarIdents, nil)
+                        )
     outerProcBody.add procCb
 
     # -> return retFuture
     outerProcBody.add newNimNode(nnkReturnStmt, prc.body[^1]).add(retFutureSym)
 
   result = prc
-
-  if subtypeIsVoid:
-    # Add discardable pragma.
-    if returnType.kind == nnkEmpty:
-      # Add Future[void]
-      result.params[0] = parseExpr("owned(Future[void])")
+  # Add discardable pragma.
+  if returnType.kind == nnkEmpty:
+    # xxx consider removing `owned`? it's inconsistent with non-void case
+    result.params[0] = quote do: owned(Future[void])
 
   # based on the yglukhov's patch to chronos: https://github.com/status-im/nim-chronos/pull/47
   if procBody.kind != nnkEmpty:
@@ -281,10 +312,6 @@ proc asyncSingleProc(prc: NimNode): NimNode =
       `outerProcBody`
     result.body = body2
 
-  #echo(treeRepr(result))
-  #if prcName == "recvLineInto":
-  #  echo(toStrLit(result))
-
 macro async*(prc: untyped): untyped =
   ## Macro which processes async procedures into the appropriate
   ## iterators and yield statements.
@@ -300,8 +327,8 @@ macro async*(prc: untyped): untyped =
 proc splitParamType(paramType: NimNode, async: bool): NimNode =
   result = paramType
   if paramType.kind == nnkInfix and paramType[0].strVal in ["|", "or"]:
-    let firstAsync = "async" in paramType[1].strVal.normalize
-    let secondAsync = "async" in paramType[2].strVal.normalize
+    let firstAsync = "async" in paramType[1].toStrLit().strVal.normalize
+    let secondAsync = "async" in paramType[2].toStrLit().strVal.normalize
 
     if firstAsync:
       result = paramType[if async: 1 else: 2]
@@ -354,9 +381,3 @@ macro multisync*(prc: untyped): untyped =
   result = newStmtList()
   result.add(asyncSingleProc(asyncPrc))
   result.add(sync)
-  # echo result.repr
-
-# overload for await as a fallback handler, based on the yglukhov's patch to chronos: https://github.com/status-im/nim-chronos/pull/47
-# template await*(f: typed): untyped =
-  # static:
-    # error "await only available within {.async.}"