summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--lib/pure/asyncdispatch.nim117
-rw-r--r--lib/pure/asyncfutures.nim12
-rw-r--r--lib/pure/includes/oserr.nim15
-rw-r--r--tests/async/tasyncclosestall.nim99
4 files changed, 191 insertions, 52 deletions
diff --git a/lib/pure/asyncdispatch.nim b/lib/pure/asyncdispatch.nim
index ea35a444d..382d9d44a 100644
--- a/lib/pure/asyncdispatch.nim
+++ b/lib/pure/asyncdispatch.nim
@@ -568,7 +568,7 @@ when defined(windows) or defined(nimdoc):
             if flags.isDisconnectionError(errcode):
               retFuture.complete()
             else:
-              retFuture.fail(newException(OSError, osErrorMsg(errcode)))
+              retFuture.fail(newOSError(errcode))
     )
 
     let ret = WSASend(socket.SocketHandle, addr dataBuf, 1, addr bytesReceived,
@@ -1134,11 +1134,6 @@ else:
     var data = newAsyncData()
     p.selector.registerHandle(fd.SocketHandle, {}, data)
 
-  proc closeSocket*(sock: AsyncFD) =
-    let disp = getGlobalDispatcher()
-    disp.selector.unregister(sock.SocketHandle)
-    sock.SocketHandle.close()
-
   proc unregister*(fd: AsyncFD) =
     getGlobalDispatcher().selector.unregister(fd.SocketHandle)
 
@@ -1174,7 +1169,9 @@ else:
     let p = getGlobalDispatcher()
     not p.selector.isEmpty() or p.timers.len != 0 or p.callbacks.len != 0
 
-  template processBasicCallbacks(ident, rwlist: untyped) =
+  proc processBasicCallbacks(
+    fd: AsyncFD, event: Event
+  ): tuple[readCbListCount, writeCbListCount: int] =
     # Process pending descriptor and AsyncEvent callbacks.
     #
     # Invoke every callback stored in `rwlist`, until one
@@ -1187,32 +1184,46 @@ else:
     # or it can be possible to fall into endless cycle.
     var curList: seq[Callback]
 
-    withData(p.selector, ident, adata) do:
-      shallowCopy(curList, adata.rwlist)
-      adata.rwlist = newSeqOfCap[Callback](InitCallbackListSize)
+    let selector = getGlobalDispatcher().selector
+    withData(selector, fd.int, fdData):
+      case event
+      of Event.Read:
+        shallowCopy(curList, fdData.readList)
+        fdData.readList = newSeqOfCap[Callback](InitCallbackListSize)
+      of Event.Write:
+        shallowCopy(curList, fdData.writeList)
+        fdData.writeList = newSeqOfCap[Callback](InitCallbackListSize)
+      else:
+        assert false, "Cannot process callbacks for " & $event
 
     let newLength = max(len(curList), InitCallbackListSize)
     var newList = newSeqOfCap[Callback](newLength)
 
     for cb in curList:
-      if len(newList) > 0:
-        # A callback has already returned with EAGAIN, don't call any others
-        # until next `poll`.
+      if not cb(fd):
+        # Callback wants to be called again.
         newList.add(cb)
+        # This callback has returned with EAGAIN, so we don't need to
+        # call any other callbacks as they are all waiting for the same event
+        # on the same fd.
+        break
+
+    withData(selector, fd.int, fdData) do:
+      # Descriptor is still present in the queue.
+      case event
+      of Event.Read:
+        fdData.readList = newList & fdData.readList
+      of Event.Write:
+        fdData.writeList = newList & fdData.writeList
       else:
-        if not cb(fd.AsyncFD):
-          # Callback wants to be called again.
-          newList.add(cb)
+        assert false, "Cannot process callbacks for " & $event
 
-    withData(p.selector, ident, adata) do:
-      # descriptor still present in queue.
-      adata.rwlist = newList & adata.rwlist
-      rLength = len(adata.readList)
-      wLength = len(adata.writeList)
+      result.readCbListCount = len(fdData.readList)
+      result.writeCbListCount = len(fdData.writeList)
     do:
-      # descriptor was unregistered in callback via `unregister()`.
-      rLength = -1
-      wLength = -1
+      # Descriptor was unregistered in callback via `unregister()`.
+      result.readCbListCount = -1
+      result.writeCbListCount = -1
 
   template processCustomCallbacks(ident: untyped) =
     # Process pending custom event callbacks. Custom events are
@@ -1221,7 +1232,7 @@ else:
     # so there is no need to iterate over list.
     var curList: seq[Callback]
 
-    withData(p.selector, ident, adata) do:
+    withData(p.selector, ident.int, adata) do:
       shallowCopy(curList, adata.readList)
       adata.readList = newSeqOfCap[Callback](InitCallbackListSize)
 
@@ -1232,16 +1243,33 @@ else:
     if not cb(fd.AsyncFD):
       newList.add(cb)
 
-    withData(p.selector, ident, adata) do:
+    withData(p.selector, ident.int, adata) do:
       # descriptor still present in queue.
       adata.readList = newList & adata.readList
       if len(adata.readList) == 0:
         # if no callbacks registered with descriptor, unregister it.
-        p.selector.unregister(fd)
+        p.selector.unregister(fd.int)
     do:
       # descriptor was unregistered in callback via `unregister()`.
       discard
 
+  proc closeSocket*(sock: AsyncFD) =
+    let selector = getGlobalDispatcher().selector
+    if sock.SocketHandle notin selector:
+      raise newException(ValueError, "File descriptor not registered.")
+
+    let data = selector.getData(sock.SocketHandle)
+    sock.unregister()
+    sock.SocketHandle.close()
+    # We need to unblock the read and write callbacks which could still be
+    # waiting for the socket to become readable and/or writeable.
+    for cb in data.readList & data.writeList:
+      if not cb(sock):
+        raise newException(
+          ValueError, "Expecting async operations to stop when fd has closed."
+        )
+
+
   proc runOnce(timeout = 500): bool =
     let p = getGlobalDispatcher()
     when ioselSupportedPlatform:
@@ -1257,41 +1285,42 @@ else:
     let nextTimer = processTimers(p, result)
     var count = p.selector.selectInto(adjustTimeout(timeout, nextTimer), keys)
     for i in 0..<count:
-      var custom = false
-      let fd = keys[i].fd
+      let fd = keys[i].fd.AsyncFD
       let events = keys[i].events
-      var rLength = 0 # len(data.readList) after callback
-      var wLength = 0 # len(data.writeList) after callback
+      var (readCbListCount, writeCbListCount) = (0, 0)
 
       if Event.Read in events or events == {Event.Error}:
-        processBasicCallbacks(fd, readList)
+        (readCbListCount, writeCbListCount) =
+          processBasicCallbacks(fd, Event.Read)
         result = true
 
       if Event.Write in events or events == {Event.Error}:
-        processBasicCallbacks(fd, writeList)
+        (readCbListCount, writeCbListCount) =
+          processBasicCallbacks(fd, Event.Write)
         result = true
 
+      var isCustomEvent = false
       if Event.User in events:
-        processBasicCallbacks(fd, readList)
-        custom = true
-        if rLength == 0:
-          p.selector.unregister(fd)
+        (readCbListCount, writeCbListCount) =
+          processBasicCallbacks(fd, Event.Read)
+        isCustomEvent = true
+        if readCbListCount == 0:
+          p.selector.unregister(fd.int)
         result = true
 
       when ioselSupportedPlatform:
         if (customSet * events) != {}:
-          custom = true
+          isCustomEvent = true
           processCustomCallbacks(fd)
           result = true
 
       # because state `data` can be modified in callback we need to update
       # descriptor events with currently registered callbacks.
-      if not custom:
+      if not isCustomEvent and (readCbListCount != -1 and writeCbListCount != -1):
         var newEvents: set[Event] = {}
-        if rLength != -1 and wLength != -1:
-          if rLength > 0: incl(newEvents, Event.Read)
-          if wLength > 0: incl(newEvents, Event.Write)
-          p.selector.updateHandle(SocketHandle(fd), newEvents)
+        if readCbListCount > 0: incl(newEvents, Event.Read)
+        if writeCbListCount > 0: incl(newEvents, Event.Write)
+        p.selector.updateHandle(SocketHandle(fd), newEvents)
 
     # Timer processing.
     discard processTimers(p, result)
@@ -1370,7 +1399,7 @@ else:
           if flags.isDisconnectionError(lastError):
             retFuture.complete()
           else:
-            retFuture.fail(newException(OSError, osErrorMsg(lastError)))
+            retFuture.fail(newOSError(lastError))
         else:
           result = false # We still want this callback to be called.
       else:
diff --git a/lib/pure/asyncfutures.nim b/lib/pure/asyncfutures.nim
index 9af72f8b3..e86a34d81 100644
--- a/lib/pure/asyncfutures.nim
+++ b/lib/pure/asyncfutures.nim
@@ -393,11 +393,13 @@ proc asyncCheck*[T](future: Future[T]) =
   ## This should be used instead of ``discard`` to discard void futures,
   ## or use ``waitFor`` if you need to wait for the future's completion.
   assert(not future.isNil, "Future is nil")
-  future.callback =
-    proc () =
-      if future.failed:
-        injectStacktrace(future)
-        raise future.error
+  # TODO: We can likely look at the stack trace here and inject the location
+  # where the `asyncCheck` was called to give a better error stack message.
+  proc asyncCheckCallback() =
+    if future.failed:
+      injectStacktrace(future)
+      raise future.error
+  future.callback = asyncCheckCallback
 
 proc `and`*[T, Y](fut1: Future[T], fut2: Future[Y]): Future[void] =
   ## Returns a future which will complete once both ``fut1`` and ``fut2``
diff --git a/lib/pure/includes/oserr.nim b/lib/pure/includes/oserr.nim
index 68ce5d95f..34d8d0085 100644
--- a/lib/pure/includes/oserr.nim
+++ b/lib/pure/includes/oserr.nim
@@ -57,8 +57,10 @@ proc osErrorMsg*(errorCode: OSErrorCode): string =
     if errorCode != OSErrorCode(0'i32):
       result = $c_strerror(errorCode.int32)
 
-proc raiseOSError*(errorCode: OSErrorCode; additionalInfo = "") {.noinline.} =
-  ## Raises an `OSError exception <system.html#OSError>`_.
+proc newOSError*(
+  errorCode: OSErrorCode, additionalInfo = ""
+): owned(ref OSError) {.noinline.} =
+  ## Creates a new `OSError exception <system.html#OSError>`_.
   ##
   ## The ``errorCode`` will determine the
   ## message, `osErrorMsg proc <#osErrorMsg,OSErrorCode>`_ will be used
@@ -82,7 +84,14 @@ proc raiseOSError*(errorCode: OSErrorCode; additionalInfo = "") {.noinline.} =
     e.msg.addQuoted additionalInfo
   if e.msg == "":
     e.msg = "unknown OS error"
-  raise e
+  return e
+
+proc raiseOSError*(errorCode: OSErrorCode, additionalInfo = "") {.noinline.} =
+  ## Raises an `OSError exception <system.html#OSError>`_.
+  ##
+  ## Read the description of the `newOSError proc <#newOSError,OSErrorCode,string>`_ to learn
+  ## how the exception object is created.
+  raise newOSError(errorCode, additionalInfo)
 
 {.push stackTrace:off.}
 proc osLastError*(): OSErrorCode {.sideEffect.} =
diff --git a/tests/async/tasyncclosestall.nim b/tests/async/tasyncclosestall.nim
new file mode 100644
index 000000000..e10e23074
--- /dev/null
+++ b/tests/async/tasyncclosestall.nim
@@ -0,0 +1,99 @@
+discard """
+  outputsub: "send has errored. As expected. All good!"
+  exitcode: 0
+"""
+import asyncdispatch, asyncnet
+
+when defined(windows):
+  from winlean import ERROR_NETNAME_DELETED
+else:
+  from posix import EBADF
+
+# This reproduces a case where a socket remains stuck waiting for writes
+# even when the socket is closed.
+const
+  port = Port(50726)
+  timeout = 5000
+
+var sent = 0
+
+proc keepSendingTo(c: AsyncSocket) {.async.} =
+  while true:
+    # This write will eventually get stuck because the client is not reading
+    # its messages.
+    let sendFut = c.send("Foobar" & $sent & "\n", flags = {})
+    if not await withTimeout(sendFut, timeout):
+      # The write is stuck. Let's simulate a scenario where the socket
+      # does not respond to PING messages, and we close it. The above future
+      # should complete after the socket is closed, not continue stalling.
+      echo("Socket has stalled, closing it")
+      c.close()
+
+      let timeoutFut = withTimeout(sendFut, timeout)
+      yield timeoutFut
+      if timeoutFut.failed:
+        let errCode = ((ref OSError)(timeoutFut.error)).errorCode
+        # The behaviour differs across platforms. On Windows ERROR_NETNAME_DELETED
+        # is raised which we classif as a "diconnection error", hence we overwrite
+        # the flags above in the `send` call so that this error is raised.
+        #
+        # On Linux the EBADF error code is raised, this is because the socket
+        # is closed.
+        #
+        # This means that by default the behaviours will differ between Windows
+        # and Linux. I think this is fine though, it makes sense mainly because
+        # Windows doesn't use a IO readiness model. We can fix this later if
+        # necessary to reclassify ERROR_NETNAME_DELETED as not a "disconnection
+        # error" (TODO)
+        when defined(windows):
+          if errCode == ERROR_NETNAME_DELETED:
+            echo("send has errored. As expected. All good!")
+            quit(QuitSuccess)
+          else:
+            raise newException(ValueError, "Test failed. Send failed with code " & $errCode)
+        else:
+          if errCode == EBADF:
+            echo("send has errored. As expected. All good!")
+            quit(QuitSuccess)
+          else:
+            raise newException(ValueError, "Test failed. Send failed with code " & $errCode)
+
+      # The write shouldn't succeed and also shouldn't be stalled.
+      if timeoutFut.read():
+        raise newException(ValueError, "Test failed. Send was expected to fail.")
+      else:
+        raise newException(ValueError, "Test failed. Send future is still stalled.")
+    sent.inc(1)
+
+proc startClient() {.async.} =
+  let client = newAsyncSocket()
+  await client.connect("localhost", port)
+  echo("Connected")
+
+  let firstLine = await client.recvLine()
+  echo("Received first line as a client: ", firstLine)
+  echo("Now not reading anymore")
+  while true: await sleepAsync(1000)
+
+proc debug() {.async.} =
+  while true:
+    echo("Sent ", sent)
+    await sleepAsync(1000)
+
+proc server() {.async.} =
+  var s = newAsyncSocket()
+  s.setSockOpt(OptReuseAddr, true)
+  s.bindAddr(port)
+  s.listen()
+
+  # We're now ready to accept connections, so start the client
+  asyncCheck startClient()
+  asyncCheck debug()
+
+  while true:
+    let client = await accept(s)
+    asyncCheck keepSendingTo(client)
+
+when isMainModule:
+  waitFor server()
+