diff options
-rw-r--r-- | lib/pure/asyncdispatch.nim | 117 | ||||
-rw-r--r-- | lib/pure/asyncfutures.nim | 12 | ||||
-rw-r--r-- | lib/pure/includes/oserr.nim | 15 | ||||
-rw-r--r-- | tests/async/tasyncclosestall.nim | 99 |
4 files changed, 191 insertions, 52 deletions
diff --git a/lib/pure/asyncdispatch.nim b/lib/pure/asyncdispatch.nim index ea35a444d..382d9d44a 100644 --- a/lib/pure/asyncdispatch.nim +++ b/lib/pure/asyncdispatch.nim @@ -568,7 +568,7 @@ when defined(windows) or defined(nimdoc): if flags.isDisconnectionError(errcode): retFuture.complete() else: - retFuture.fail(newException(OSError, osErrorMsg(errcode))) + retFuture.fail(newOSError(errcode)) ) let ret = WSASend(socket.SocketHandle, addr dataBuf, 1, addr bytesReceived, @@ -1134,11 +1134,6 @@ else: var data = newAsyncData() p.selector.registerHandle(fd.SocketHandle, {}, data) - proc closeSocket*(sock: AsyncFD) = - let disp = getGlobalDispatcher() - disp.selector.unregister(sock.SocketHandle) - sock.SocketHandle.close() - proc unregister*(fd: AsyncFD) = getGlobalDispatcher().selector.unregister(fd.SocketHandle) @@ -1174,7 +1169,9 @@ else: let p = getGlobalDispatcher() not p.selector.isEmpty() or p.timers.len != 0 or p.callbacks.len != 0 - template processBasicCallbacks(ident, rwlist: untyped) = + proc processBasicCallbacks( + fd: AsyncFD, event: Event + ): tuple[readCbListCount, writeCbListCount: int] = # Process pending descriptor and AsyncEvent callbacks. # # Invoke every callback stored in `rwlist`, until one @@ -1187,32 +1184,46 @@ else: # or it can be possible to fall into endless cycle. var curList: seq[Callback] - withData(p.selector, ident, adata) do: - shallowCopy(curList, adata.rwlist) - adata.rwlist = newSeqOfCap[Callback](InitCallbackListSize) + let selector = getGlobalDispatcher().selector + withData(selector, fd.int, fdData): + case event + of Event.Read: + shallowCopy(curList, fdData.readList) + fdData.readList = newSeqOfCap[Callback](InitCallbackListSize) + of Event.Write: + shallowCopy(curList, fdData.writeList) + fdData.writeList = newSeqOfCap[Callback](InitCallbackListSize) + else: + assert false, "Cannot process callbacks for " & $event let newLength = max(len(curList), InitCallbackListSize) var newList = newSeqOfCap[Callback](newLength) for cb in curList: - if len(newList) > 0: - # A callback has already returned with EAGAIN, don't call any others - # until next `poll`. + if not cb(fd): + # Callback wants to be called again. newList.add(cb) + # This callback has returned with EAGAIN, so we don't need to + # call any other callbacks as they are all waiting for the same event + # on the same fd. + break + + withData(selector, fd.int, fdData) do: + # Descriptor is still present in the queue. + case event + of Event.Read: + fdData.readList = newList & fdData.readList + of Event.Write: + fdData.writeList = newList & fdData.writeList else: - if not cb(fd.AsyncFD): - # Callback wants to be called again. - newList.add(cb) + assert false, "Cannot process callbacks for " & $event - withData(p.selector, ident, adata) do: - # descriptor still present in queue. - adata.rwlist = newList & adata.rwlist - rLength = len(adata.readList) - wLength = len(adata.writeList) + result.readCbListCount = len(fdData.readList) + result.writeCbListCount = len(fdData.writeList) do: - # descriptor was unregistered in callback via `unregister()`. - rLength = -1 - wLength = -1 + # Descriptor was unregistered in callback via `unregister()`. + result.readCbListCount = -1 + result.writeCbListCount = -1 template processCustomCallbacks(ident: untyped) = # Process pending custom event callbacks. Custom events are @@ -1221,7 +1232,7 @@ else: # so there is no need to iterate over list. var curList: seq[Callback] - withData(p.selector, ident, adata) do: + withData(p.selector, ident.int, adata) do: shallowCopy(curList, adata.readList) adata.readList = newSeqOfCap[Callback](InitCallbackListSize) @@ -1232,16 +1243,33 @@ else: if not cb(fd.AsyncFD): newList.add(cb) - withData(p.selector, ident, adata) do: + withData(p.selector, ident.int, adata) do: # descriptor still present in queue. adata.readList = newList & adata.readList if len(adata.readList) == 0: # if no callbacks registered with descriptor, unregister it. - p.selector.unregister(fd) + p.selector.unregister(fd.int) do: # descriptor was unregistered in callback via `unregister()`. discard + proc closeSocket*(sock: AsyncFD) = + let selector = getGlobalDispatcher().selector + if sock.SocketHandle notin selector: + raise newException(ValueError, "File descriptor not registered.") + + let data = selector.getData(sock.SocketHandle) + sock.unregister() + sock.SocketHandle.close() + # We need to unblock the read and write callbacks which could still be + # waiting for the socket to become readable and/or writeable. + for cb in data.readList & data.writeList: + if not cb(sock): + raise newException( + ValueError, "Expecting async operations to stop when fd has closed." + ) + + proc runOnce(timeout = 500): bool = let p = getGlobalDispatcher() when ioselSupportedPlatform: @@ -1257,41 +1285,42 @@ else: let nextTimer = processTimers(p, result) var count = p.selector.selectInto(adjustTimeout(timeout, nextTimer), keys) for i in 0..<count: - var custom = false - let fd = keys[i].fd + let fd = keys[i].fd.AsyncFD let events = keys[i].events - var rLength = 0 # len(data.readList) after callback - var wLength = 0 # len(data.writeList) after callback + var (readCbListCount, writeCbListCount) = (0, 0) if Event.Read in events or events == {Event.Error}: - processBasicCallbacks(fd, readList) + (readCbListCount, writeCbListCount) = + processBasicCallbacks(fd, Event.Read) result = true if Event.Write in events or events == {Event.Error}: - processBasicCallbacks(fd, writeList) + (readCbListCount, writeCbListCount) = + processBasicCallbacks(fd, Event.Write) result = true + var isCustomEvent = false if Event.User in events: - processBasicCallbacks(fd, readList) - custom = true - if rLength == 0: - p.selector.unregister(fd) + (readCbListCount, writeCbListCount) = + processBasicCallbacks(fd, Event.Read) + isCustomEvent = true + if readCbListCount == 0: + p.selector.unregister(fd.int) result = true when ioselSupportedPlatform: if (customSet * events) != {}: - custom = true + isCustomEvent = true processCustomCallbacks(fd) result = true # because state `data` can be modified in callback we need to update # descriptor events with currently registered callbacks. - if not custom: + if not isCustomEvent and (readCbListCount != -1 and writeCbListCount != -1): var newEvents: set[Event] = {} - if rLength != -1 and wLength != -1: - if rLength > 0: incl(newEvents, Event.Read) - if wLength > 0: incl(newEvents, Event.Write) - p.selector.updateHandle(SocketHandle(fd), newEvents) + if readCbListCount > 0: incl(newEvents, Event.Read) + if writeCbListCount > 0: incl(newEvents, Event.Write) + p.selector.updateHandle(SocketHandle(fd), newEvents) # Timer processing. discard processTimers(p, result) @@ -1370,7 +1399,7 @@ else: if flags.isDisconnectionError(lastError): retFuture.complete() else: - retFuture.fail(newException(OSError, osErrorMsg(lastError))) + retFuture.fail(newOSError(lastError)) else: result = false # We still want this callback to be called. else: diff --git a/lib/pure/asyncfutures.nim b/lib/pure/asyncfutures.nim index 9af72f8b3..e86a34d81 100644 --- a/lib/pure/asyncfutures.nim +++ b/lib/pure/asyncfutures.nim @@ -393,11 +393,13 @@ proc asyncCheck*[T](future: Future[T]) = ## This should be used instead of ``discard`` to discard void futures, ## or use ``waitFor`` if you need to wait for the future's completion. assert(not future.isNil, "Future is nil") - future.callback = - proc () = - if future.failed: - injectStacktrace(future) - raise future.error + # TODO: We can likely look at the stack trace here and inject the location + # where the `asyncCheck` was called to give a better error stack message. + proc asyncCheckCallback() = + if future.failed: + injectStacktrace(future) + raise future.error + future.callback = asyncCheckCallback proc `and`*[T, Y](fut1: Future[T], fut2: Future[Y]): Future[void] = ## Returns a future which will complete once both ``fut1`` and ``fut2`` diff --git a/lib/pure/includes/oserr.nim b/lib/pure/includes/oserr.nim index 68ce5d95f..34d8d0085 100644 --- a/lib/pure/includes/oserr.nim +++ b/lib/pure/includes/oserr.nim @@ -57,8 +57,10 @@ proc osErrorMsg*(errorCode: OSErrorCode): string = if errorCode != OSErrorCode(0'i32): result = $c_strerror(errorCode.int32) -proc raiseOSError*(errorCode: OSErrorCode; additionalInfo = "") {.noinline.} = - ## Raises an `OSError exception <system.html#OSError>`_. +proc newOSError*( + errorCode: OSErrorCode, additionalInfo = "" +): owned(ref OSError) {.noinline.} = + ## Creates a new `OSError exception <system.html#OSError>`_. ## ## The ``errorCode`` will determine the ## message, `osErrorMsg proc <#osErrorMsg,OSErrorCode>`_ will be used @@ -82,7 +84,14 @@ proc raiseOSError*(errorCode: OSErrorCode; additionalInfo = "") {.noinline.} = e.msg.addQuoted additionalInfo if e.msg == "": e.msg = "unknown OS error" - raise e + return e + +proc raiseOSError*(errorCode: OSErrorCode, additionalInfo = "") {.noinline.} = + ## Raises an `OSError exception <system.html#OSError>`_. + ## + ## Read the description of the `newOSError proc <#newOSError,OSErrorCode,string>`_ to learn + ## how the exception object is created. + raise newOSError(errorCode, additionalInfo) {.push stackTrace:off.} proc osLastError*(): OSErrorCode {.sideEffect.} = diff --git a/tests/async/tasyncclosestall.nim b/tests/async/tasyncclosestall.nim new file mode 100644 index 000000000..e10e23074 --- /dev/null +++ b/tests/async/tasyncclosestall.nim @@ -0,0 +1,99 @@ +discard """ + outputsub: "send has errored. As expected. All good!" + exitcode: 0 +""" +import asyncdispatch, asyncnet + +when defined(windows): + from winlean import ERROR_NETNAME_DELETED +else: + from posix import EBADF + +# This reproduces a case where a socket remains stuck waiting for writes +# even when the socket is closed. +const + port = Port(50726) + timeout = 5000 + +var sent = 0 + +proc keepSendingTo(c: AsyncSocket) {.async.} = + while true: + # This write will eventually get stuck because the client is not reading + # its messages. + let sendFut = c.send("Foobar" & $sent & "\n", flags = {}) + if not await withTimeout(sendFut, timeout): + # The write is stuck. Let's simulate a scenario where the socket + # does not respond to PING messages, and we close it. The above future + # should complete after the socket is closed, not continue stalling. + echo("Socket has stalled, closing it") + c.close() + + let timeoutFut = withTimeout(sendFut, timeout) + yield timeoutFut + if timeoutFut.failed: + let errCode = ((ref OSError)(timeoutFut.error)).errorCode + # The behaviour differs across platforms. On Windows ERROR_NETNAME_DELETED + # is raised which we classif as a "diconnection error", hence we overwrite + # the flags above in the `send` call so that this error is raised. + # + # On Linux the EBADF error code is raised, this is because the socket + # is closed. + # + # This means that by default the behaviours will differ between Windows + # and Linux. I think this is fine though, it makes sense mainly because + # Windows doesn't use a IO readiness model. We can fix this later if + # necessary to reclassify ERROR_NETNAME_DELETED as not a "disconnection + # error" (TODO) + when defined(windows): + if errCode == ERROR_NETNAME_DELETED: + echo("send has errored. As expected. All good!") + quit(QuitSuccess) + else: + raise newException(ValueError, "Test failed. Send failed with code " & $errCode) + else: + if errCode == EBADF: + echo("send has errored. As expected. All good!") + quit(QuitSuccess) + else: + raise newException(ValueError, "Test failed. Send failed with code " & $errCode) + + # The write shouldn't succeed and also shouldn't be stalled. + if timeoutFut.read(): + raise newException(ValueError, "Test failed. Send was expected to fail.") + else: + raise newException(ValueError, "Test failed. Send future is still stalled.") + sent.inc(1) + +proc startClient() {.async.} = + let client = newAsyncSocket() + await client.connect("localhost", port) + echo("Connected") + + let firstLine = await client.recvLine() + echo("Received first line as a client: ", firstLine) + echo("Now not reading anymore") + while true: await sleepAsync(1000) + +proc debug() {.async.} = + while true: + echo("Sent ", sent) + await sleepAsync(1000) + +proc server() {.async.} = + var s = newAsyncSocket() + s.setSockOpt(OptReuseAddr, true) + s.bindAddr(port) + s.listen() + + # We're now ready to accept connections, so start the client + asyncCheck startClient() + asyncCheck debug() + + while true: + let client = await accept(s) + asyncCheck keepSendingTo(client) + +when isMainModule: + waitFor server() + |