summary refs log tree commit diff stats
path: root/lib/pure/asyncdispatch.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/asyncdispatch.nim')
-rw-r--r--lib/pure/asyncdispatch.nim258
1 files changed, 223 insertions, 35 deletions
diff --git a/lib/pure/asyncdispatch.nim b/lib/pure/asyncdispatch.nim
index c5b516b39..92a737a47 100644
--- a/lib/pure/asyncdispatch.nim
+++ b/lib/pure/asyncdispatch.nim
@@ -434,6 +434,9 @@ when defined(windows) or defined(nimdoc):
       fd*: AsyncFD # TODO: Rename this.
       cb*: proc (fd: AsyncFD, bytesTransferred: Dword,
                 errcode: OSErrorCode) {.closure,gcsafe.}
+      cell*: ForeignCell # we need this `cell` to protect our `cb` environment,
+                         # when using RegisterWaitForSingleObject, because
+                         # waiting is done in different thread.
 
     PDispatcher* = ref object of PDispatcherBase
       ioPort: Handle
@@ -517,6 +520,13 @@ when defined(windows) or defined(nimdoc):
 
       customOverlapped.data.cb(customOverlapped.data.fd,
           lpNumberOfBytesTransferred, OSErrorCode(-1))
+
+      # If cell.data != nil, then system.protect(rawEnv(cb)) was called,
+      # so we need to dispose our `cb` environment, because it is not needed
+      # anymore.
+      if customOverlapped.data.cell.data != nil:
+        system.dispose(customOverlapped.data.cell)
+
       GC_unref(customOverlapped)
     else:
       let errCode = osLastError()
@@ -524,6 +534,8 @@ when defined(windows) or defined(nimdoc):
         assert customOverlapped.data.fd == lpCompletionKey.AsyncFD
         customOverlapped.data.cb(customOverlapped.data.fd,
             lpNumberOfBytesTransferred, errCode)
+        if customOverlapped.data.cell.data != nil:
+          system.dispose(customOverlapped.data.cell)
         GC_unref(customOverlapped)
       else:
         if errCode.int32 == WAIT_TIMEOUT:
@@ -850,6 +862,101 @@ when defined(windows) or defined(nimdoc):
       # free ``ol``.
     return retFuture
 
+  proc sendTo*(socket: AsyncFD, data: pointer, size: int, saddr: ptr SockAddr,
+               saddrLen: Socklen,
+               flags = {SocketFlag.SafeDisconn}): Future[void] =
+    ## Sends ``data`` to specified destination ``saddr``, using
+    ## socket ``socket``. The returned future will complete once all data
+    ## has been sent.
+    verifyPresence(socket)
+    var retFuture = newFuture[void]("sendTo")
+    var dataBuf: TWSABuf
+    dataBuf.buf = cast[cstring](data)
+    dataBuf.len = size.ULONG
+    var bytesSent = 0.Dword
+    var lowFlags = 0.Dword
+
+    # we will preserve address in our stack
+    var staddr: array[128, char] # SOCKADDR_STORAGE size is 128 bytes
+    var stalen: cint = cint(saddrLen)
+    zeroMem(addr(staddr[0]), 128)
+    copyMem(addr(staddr[0]), saddr, saddrLen)
+
+    var ol = PCustomOverlapped()
+    GC_ref(ol)
+    ol.data = CompletionData(fd: socket, cb:
+      proc (fd: AsyncFD, bytesCount: Dword, errcode: OSErrorCode) =
+        if not retFuture.finished:
+          if errcode == OSErrorCode(-1):
+            retFuture.complete()
+          else:
+            retFuture.fail(newException(OSError, osErrorMsg(errcode)))
+    )
+
+    let ret = WSASendTo(socket.SocketHandle, addr dataBuf, 1, addr bytesSent,
+                        lowFlags, cast[ptr SockAddr](addr(staddr[0])),
+                        stalen, cast[POVERLAPPED](ol), nil)
+    if ret == -1:
+      let err = osLastError()
+      if err.int32 != ERROR_IO_PENDING:
+        GC_unref(ol)
+        retFuture.fail(newException(OSError, osErrorMsg(err)))
+    else:
+      retFuture.complete()
+      # We don't deallocate ``ol`` here because even though this completed
+      # immediately poll will still be notified about its completion and it will
+      # free ``ol``.
+    return retFuture
+
+  proc recvFromInto*(socket: AsyncFD, data: pointer, size: int,
+                     saddr: ptr SockAddr, saddrLen: ptr SockLen,
+                     flags = {SocketFlag.SafeDisconn}): Future[int] =
+    ## Receives a datagram data from ``socket`` into ``buf``, which must
+    ## be at least of size ``size``, address of datagram's sender will be
+    ## stored into ``saddr`` and ``saddrLen``. Returned future will complete
+    ## once one datagram has been received, and will return size of packet
+    ## received.
+    verifyPresence(socket)
+    var retFuture = newFuture[int]("recvFromInto")
+
+    var dataBuf = TWSABuf(buf: cast[cstring](data), len: size.ULONG)
+
+    var bytesReceived = 0.Dword
+    var lowFlags = 0.Dword
+
+    var ol = PCustomOverlapped()
+    GC_ref(ol)
+    ol.data = CompletionData(fd: socket, cb:
+      proc (fd: AsyncFD, bytesCount: Dword, errcode: OSErrorCode) =
+        if not retFuture.finished:
+          if errcode == OSErrorCode(-1):
+            assert bytesCount <= size
+            retFuture.complete(bytesCount)
+          else:
+            # datagram sockets don't have disconnection,
+            # so we can just raise an exception
+            retFuture.fail(newException(OSError, osErrorMsg(errcode)))
+    )
+
+    let res = WSARecvFrom(socket.SocketHandle, addr dataBuf, 1,
+                          addr bytesReceived, addr lowFlags,
+                          saddr, cast[ptr cint](saddrLen),
+                          cast[POVERLAPPED](ol), nil)
+    if res == -1:
+      let err = osLastError()
+      if err.int32 != ERROR_IO_PENDING:
+        GC_unref(ol)
+        retFuture.fail(newException(OSError, osErrorMsg(err)))
+    else:
+      # Request completed immediately.
+      if bytesReceived != 0:
+        assert bytesReceived <= size
+        retFuture.complete(bytesReceived)
+      else:
+        if hasOverlappedIoCompleted(cast[POVERLAPPED](ol)):
+          retFuture.complete(bytesReceived)
+    return retFuture
+
   proc acceptAddr*(socket: AsyncFD, flags = {SocketFlag.SafeDisconn}):
       Future[tuple[address: string, client: AsyncFD]] =
     ## Accepts a new connection. Returns a future containing the client socket
@@ -1026,6 +1133,10 @@ when defined(windows) or defined(nimdoc):
               # poll()
               GC_ref(pcd.ovl)
     )
+    # We need to protect our callback environment value, so GC will not free it
+    # accidentally.
+    ol.data.cell = system.protect(rawEnv(ol.data.cb))
+
     # This is main part of `hacky way` is using WSAEventSelect, so `hEvent`
     # will be signaled when appropriate `mask` events will be triggered.
     if wsaEventSelect(fd.SocketHandle, hEvent, mask) != 0:
@@ -1343,6 +1454,60 @@ else:
     addWrite(socket, cb)
     return retFuture
 
+  proc sendTo*(socket: AsyncFD, data: pointer, size: int, saddr: ptr SockAddr,
+               saddrLen: SockLen,
+               flags = {SocketFlag.SafeDisconn}): Future[void] =
+    ## Sends ``data`` of size ``size`` in bytes to specified destination
+    ## (``saddr`` of size ``saddrLen`` in bytes, using socket ``socket``.
+    ## The returned future will complete once all data has been sent.
+    var retFuture = newFuture[void]("sendTo")
+
+    # we will preserve address in our stack
+    var staddr: array[128, char] # SOCKADDR_STORAGE size is 128 bytes
+    var stalen = saddrLen
+    zeroMem(addr(staddr[0]), 128)
+    copyMem(addr(staddr[0]), saddr, saddrLen)
+
+    proc cb(sock: AsyncFD): bool =
+      result = true
+      let res = sendto(sock.SocketHandle, data, size, MSG_NOSIGNAL,
+                       cast[ptr SockAddr](addr(staddr[0])), stalen)
+      if res < 0:
+        let lastError = osLastError()
+        if lastError.int32 notin {EINTR, EWOULDBLOCK, EAGAIN}:
+          retFuture.fail(newException(OSError, osErrorMsg(lastError)))
+        else:
+          result = false # We still want this callback to be called.
+      else:
+        retFuture.complete()
+
+    addWrite(socket, cb)
+    return retFuture
+
+  proc recvFromInto*(socket: AsyncFD, data: pointer, size: int,
+                     saddr: ptr SockAddr, saddrLen: ptr SockLen,
+                     flags = {SocketFlag.SafeDisconn}): Future[int] =
+    ## Receives a datagram data from ``socket`` into ``data``, which must
+    ## be at least of size ``size`` in bytes, address of datagram's sender
+    ## will be stored into ``saddr`` and ``saddrLen``. Returned future will
+    ## complete once one datagram has been received, and will return size
+    ## of packet received.
+    var retFuture = newFuture[int]("recvFromInto")
+    proc cb(sock: AsyncFD): bool =
+      result = true
+      let res = recvfrom(sock.SocketHandle, data, size.cint, flags.toOSFlags(),
+                         saddr, saddrLen)
+      if res < 0:
+        let lastError = osLastError()
+        if lastError.int32 notin {EINTR, EWOULDBLOCK, EAGAIN}:
+          retFuture.fail(newException(OSError, osErrorMsg(lastError)))
+        else:
+          result = false
+      else:
+        retFuture.complete(res)
+    addRead(socket, cb)
+    return retFuture
+
   proc acceptAddr*(socket: AsyncFD, flags = {SocketFlag.SafeDisconn}):
       Future[tuple[address: string, client: AsyncFD]] =
     var retFuture = newFuture[tuple[address: string,
@@ -1377,6 +1542,24 @@ proc sleepAsync*(ms: int): Future[void] =
   p.timers.push((epochTime() + (ms / 1000), retFuture))
   return retFuture
 
+proc withTimeout*[T](fut: Future[T], timeout: int): Future[bool] =
+  ## Returns a future which will complete once ``fut`` completes or after
+  ## ``timeout`` milliseconds has elapsed.
+  ##
+  ## If ``fut`` completes first the returned future will hold true,
+  ## otherwise, if ``timeout`` milliseconds has elapsed first, the returned
+  ## future will hold false.
+
+  var retFuture = newFuture[bool]("asyncdispatch.`withTimeout`")
+  var timeoutFuture = sleepAsync(timeout)
+  fut.callback =
+    proc () =
+      if not retFuture.finished: retFuture.complete(true)
+  timeoutFuture.callback =
+    proc () =
+      if not retFuture.finished: retFuture.complete(false)
+  return retFuture
+
 proc accept*(socket: AsyncFD,
     flags = {SocketFlag.SafeDisconn}): Future[AsyncFD] =
   ## Accepts a new connection. Returns a future containing the client socket
@@ -1516,15 +1699,17 @@ proc processBody(node, retFutureSym: NimNode,
       else:
         result.add newCall(newIdentNode("complete"), retFutureSym)
     else:
-      result.add newCall(newIdentNode("complete"), retFutureSym,
-        node[0].processBody(retFutureSym, subTypeIsVoid, tryStmt))
+      let x = node[0].processBody(retFutureSym, subTypeIsVoid, tryStmt)
+      if x.kind == nnkYieldStmt: result.add x
+      else:
+        result.add newCall(newIdentNode("complete"), retFutureSym, x)
 
     result.add newNimNode(nnkReturnStmt, node).add(newNilLit())
     return # Don't process the children of this return stmt
   of nnkCommand, nnkCall:
     if node[0].kind == nnkIdent and node[0].ident == !"await":
       case node[1].kind
-      of nnkIdent, nnkInfix:
+      of nnkIdent, nnkInfix, nnkDotExpr:
         # await x
         # await x or y
         result = newNimNode(nnkYieldStmt, node).add(node[1]) # -> yield x
@@ -1687,38 +1872,40 @@ proc asyncSingleProc(prc: NimNode): NimNode {.compileTime.} =
   # ->   complete(retFuture, result)
   var iteratorNameSym = genSym(nskIterator, $prc[0].getName & "Iter")
   var procBody = prc[6].processBody(retFutureSym, subtypeIsVoid, nil)
-  if not subtypeIsVoid:
-    procBody.insert(0, newNimNode(nnkPragma).add(newIdentNode("push"),
-      newNimNode(nnkExprColonExpr).add(newNimNode(nnkBracketExpr).add(
-        newIdentNode("warning"), newIdentNode("resultshadowed")),
-      newIdentNode("off")))) # -> {.push warning[resultshadowed]: off.}
-
-    procBody.insert(1, newNimNode(nnkVarSection, prc[6]).add(
-      newIdentDefs(newIdentNode("result"), baseType))) # -> var result: T
-
-    procBody.insert(2, newNimNode(nnkPragma).add(
-      newIdentNode("pop"))) # -> {.pop.})
-
-    procBody.add(
-      newCall(newIdentNode("complete"),
-        retFutureSym, newIdentNode("result"))) # -> complete(retFuture, result)
-  else:
-    # -> complete(retFuture)
-    procBody.add(newCall(newIdentNode("complete"), retFutureSym))
+  # don't do anything with forward bodies (empty)
+  if procBody.kind != nnkEmpty:
+    if not subtypeIsVoid:
+      procBody.insert(0, newNimNode(nnkPragma).add(newIdentNode("push"),
+        newNimNode(nnkExprColonExpr).add(newNimNode(nnkBracketExpr).add(
+          newIdentNode("warning"), newIdentNode("resultshadowed")),
+        newIdentNode("off")))) # -> {.push warning[resultshadowed]: off.}
+
+      procBody.insert(1, newNimNode(nnkVarSection, prc[6]).add(
+        newIdentDefs(newIdentNode("result"), baseType))) # -> var result: T
+
+      procBody.insert(2, newNimNode(nnkPragma).add(
+        newIdentNode("pop"))) # -> {.pop.})
+
+      procBody.add(
+        newCall(newIdentNode("complete"),
+          retFutureSym, newIdentNode("result"))) # -> complete(retFuture, result)
+    else:
+      # -> complete(retFuture)
+      procBody.add(newCall(newIdentNode("complete"), retFutureSym))
 
-  var closureIterator = newProc(iteratorNameSym, [newIdentNode("FutureBase")],
-                                procBody, nnkIteratorDef)
-  closureIterator[4] = newNimNode(nnkPragma, prc[6]).add(newIdentNode("closure"))
-  outerProcBody.add(closureIterator)
+    var closureIterator = newProc(iteratorNameSym, [newIdentNode("FutureBase")],
+                                  procBody, nnkIteratorDef)
+    closureIterator[4] = newNimNode(nnkPragma, prc[6]).add(newIdentNode("closure"))
+    outerProcBody.add(closureIterator)
 
-  # -> createCb(retFuture)
-  #var cbName = newIdentNode("cb")
-  var procCb = newCall(bindSym"createCb", retFutureSym, iteratorNameSym,
-                       newStrLitNode(prc[0].getName))
-  outerProcBody.add procCb
+    # -> createCb(retFuture)
+    #var cbName = newIdentNode("cb")
+    var procCb = getAst createCb(retFutureSym, iteratorNameSym,
+                         newStrLitNode(prc[0].getName))
+    outerProcBody.add procCb
 
-  # -> return retFuture
-  outerProcBody.add newNimNode(nnkReturnStmt, prc[6][prc[6].len-1]).add(retFutureSym)
+    # -> return retFuture
+    outerProcBody.add newNimNode(nnkReturnStmt, prc[6][prc[6].len-1]).add(retFutureSym)
 
   result = prc
 
@@ -1732,9 +1919,8 @@ proc asyncSingleProc(prc: NimNode): NimNode {.compileTime.} =
     if returnType.kind == nnkEmpty:
       # Add Future[void]
       result[3][0] = parseExpr("Future[void]")
-
-  result[6] = outerProcBody
-
+  if procBody.kind != nnkEmpty:
+    result[6] = outerProcBody
   #echo(treeRepr(result))
   #if prc[0].getName == "testInfix":
   #  echo(toStrLit(result))
@@ -1748,6 +1934,8 @@ macro async*(prc: stmt): stmt {.immediate.} =
       result.add asyncSingleProc(oneProc)
   else:
     result = asyncSingleProc(prc)
+  when defined(nimDumpAsync):
+    echo repr result
 
 proc recvLine*(socket: AsyncFD): Future[string] {.async.} =
   ## Reads a line of data from ``socket``. Returned future will complete once