summary refs log tree commit diff stats
path: root/lib/pure/asyncdispatch.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/asyncdispatch.nim')
-rw-r--r--lib/pure/asyncdispatch.nim117
1 files changed, 73 insertions, 44 deletions
diff --git a/lib/pure/asyncdispatch.nim b/lib/pure/asyncdispatch.nim
index ea35a444d..382d9d44a 100644
--- a/lib/pure/asyncdispatch.nim
+++ b/lib/pure/asyncdispatch.nim
@@ -568,7 +568,7 @@ when defined(windows) or defined(nimdoc):
             if flags.isDisconnectionError(errcode):
               retFuture.complete()
             else:
-              retFuture.fail(newException(OSError, osErrorMsg(errcode)))
+              retFuture.fail(newOSError(errcode))
     )
 
     let ret = WSASend(socket.SocketHandle, addr dataBuf, 1, addr bytesReceived,
@@ -1134,11 +1134,6 @@ else:
     var data = newAsyncData()
     p.selector.registerHandle(fd.SocketHandle, {}, data)
 
-  proc closeSocket*(sock: AsyncFD) =
-    let disp = getGlobalDispatcher()
-    disp.selector.unregister(sock.SocketHandle)
-    sock.SocketHandle.close()
-
   proc unregister*(fd: AsyncFD) =
     getGlobalDispatcher().selector.unregister(fd.SocketHandle)
 
@@ -1174,7 +1169,9 @@ else:
     let p = getGlobalDispatcher()
     not p.selector.isEmpty() or p.timers.len != 0 or p.callbacks.len != 0
 
-  template processBasicCallbacks(ident, rwlist: untyped) =
+  proc processBasicCallbacks(
+    fd: AsyncFD, event: Event
+  ): tuple[readCbListCount, writeCbListCount: int] =
     # Process pending descriptor and AsyncEvent callbacks.
     #
     # Invoke every callback stored in `rwlist`, until one
@@ -1187,32 +1184,46 @@ else:
     # or it can be possible to fall into endless cycle.
     var curList: seq[Callback]
 
-    withData(p.selector, ident, adata) do:
-      shallowCopy(curList, adata.rwlist)
-      adata.rwlist = newSeqOfCap[Callback](InitCallbackListSize)
+    let selector = getGlobalDispatcher().selector
+    withData(selector, fd.int, fdData):
+      case event
+      of Event.Read:
+        shallowCopy(curList, fdData.readList)
+        fdData.readList = newSeqOfCap[Callback](InitCallbackListSize)
+      of Event.Write:
+        shallowCopy(curList, fdData.writeList)
+        fdData.writeList = newSeqOfCap[Callback](InitCallbackListSize)
+      else:
+        assert false, "Cannot process callbacks for " & $event
 
     let newLength = max(len(curList), InitCallbackListSize)
     var newList = newSeqOfCap[Callback](newLength)
 
     for cb in curList:
-      if len(newList) > 0:
-        # A callback has already returned with EAGAIN, don't call any others
-        # until next `poll`.
+      if not cb(fd):
+        # Callback wants to be called again.
         newList.add(cb)
+        # This callback has returned with EAGAIN, so we don't need to
+        # call any other callbacks as they are all waiting for the same event
+        # on the same fd.
+        break
+
+    withData(selector, fd.int, fdData) do:
+      # Descriptor is still present in the queue.
+      case event
+      of Event.Read:
+        fdData.readList = newList & fdData.readList
+      of Event.Write:
+        fdData.writeList = newList & fdData.writeList
       else:
-        if not cb(fd.AsyncFD):
-          # Callback wants to be called again.
-          newList.add(cb)
+        assert false, "Cannot process callbacks for " & $event
 
-    withData(p.selector, ident, adata) do:
-      # descriptor still present in queue.
-      adata.rwlist = newList & adata.rwlist
-      rLength = len(adata.readList)
-      wLength = len(adata.writeList)
+      result.readCbListCount = len(fdData.readList)
+      result.writeCbListCount = len(fdData.writeList)
     do:
-      # descriptor was unregistered in callback via `unregister()`.
-      rLength = -1
-      wLength = -1
+      # Descriptor was unregistered in callback via `unregister()`.
+      result.readCbListCount = -1
+      result.writeCbListCount = -1
 
   template processCustomCallbacks(ident: untyped) =
     # Process pending custom event callbacks. Custom events are
@@ -1221,7 +1232,7 @@ else:
     # so there is no need to iterate over list.
     var curList: seq[Callback]
 
-    withData(p.selector, ident, adata) do:
+    withData(p.selector, ident.int, adata) do:
       shallowCopy(curList, adata.readList)
       adata.readList = newSeqOfCap[Callback](InitCallbackListSize)
 
@@ -1232,16 +1243,33 @@ else:
     if not cb(fd.AsyncFD):
       newList.add(cb)
 
-    withData(p.selector, ident, adata) do:
+    withData(p.selector, ident.int, adata) do:
       # descriptor still present in queue.
       adata.readList = newList & adata.readList
       if len(adata.readList) == 0:
         # if no callbacks registered with descriptor, unregister it.
-        p.selector.unregister(fd)
+        p.selector.unregister(fd.int)
     do:
       # descriptor was unregistered in callback via `unregister()`.
       discard
 
+  proc closeSocket*(sock: AsyncFD) =
+    let selector = getGlobalDispatcher().selector
+    if sock.SocketHandle notin selector:
+      raise newException(ValueError, "File descriptor not registered.")
+
+    let data = selector.getData(sock.SocketHandle)
+    sock.unregister()
+    sock.SocketHandle.close()
+    # We need to unblock the read and write callbacks which could still be
+    # waiting for the socket to become readable and/or writeable.
+    for cb in data.readList & data.writeList:
+      if not cb(sock):
+        raise newException(
+          ValueError, "Expecting async operations to stop when fd has closed."
+        )
+
+
   proc runOnce(timeout = 500): bool =
     let p = getGlobalDispatcher()
     when ioselSupportedPlatform:
@@ -1257,41 +1285,42 @@ else:
     let nextTimer = processTimers(p, result)
     var count = p.selector.selectInto(adjustTimeout(timeout, nextTimer), keys)
     for i in 0..<count:
-      var custom = false
-      let fd = keys[i].fd
+      let fd = keys[i].fd.AsyncFD
       let events = keys[i].events
-      var rLength = 0 # len(data.readList) after callback
-      var wLength = 0 # len(data.writeList) after callback
+      var (readCbListCount, writeCbListCount) = (0, 0)
 
       if Event.Read in events or events == {Event.Error}:
-        processBasicCallbacks(fd, readList)
+        (readCbListCount, writeCbListCount) =
+          processBasicCallbacks(fd, Event.Read)
         result = true
 
       if Event.Write in events or events == {Event.Error}:
-        processBasicCallbacks(fd, writeList)
+        (readCbListCount, writeCbListCount) =
+          processBasicCallbacks(fd, Event.Write)
         result = true
 
+      var isCustomEvent = false
       if Event.User in events:
-        processBasicCallbacks(fd, readList)
-        custom = true
-        if rLength == 0:
-          p.selector.unregister(fd)
+        (readCbListCount, writeCbListCount) =
+          processBasicCallbacks(fd, Event.Read)
+        isCustomEvent = true
+        if readCbListCount == 0:
+          p.selector.unregister(fd.int)
         result = true
 
       when ioselSupportedPlatform:
         if (customSet * events) != {}:
-          custom = true
+          isCustomEvent = true
           processCustomCallbacks(fd)
           result = true
 
       # because state `data` can be modified in callback we need to update
       # descriptor events with currently registered callbacks.
-      if not custom:
+      if not isCustomEvent and (readCbListCount != -1 and writeCbListCount != -1):
         var newEvents: set[Event] = {}
-        if rLength != -1 and wLength != -1:
-          if rLength > 0: incl(newEvents, Event.Read)
-          if wLength > 0: incl(newEvents, Event.Write)
-          p.selector.updateHandle(SocketHandle(fd), newEvents)
+        if readCbListCount > 0: incl(newEvents, Event.Read)
+        if writeCbListCount > 0: incl(newEvents, Event.Write)
+        p.selector.updateHandle(SocketHandle(fd), newEvents)
 
     # Timer processing.
     discard processTimers(p, result)
@@ -1370,7 +1399,7 @@ else:
           if flags.isDisconnectionError(lastError):
             retFuture.complete()
           else:
-            retFuture.fail(newException(OSError, osErrorMsg(lastError)))
+            retFuture.fail(newOSError(lastError))
         else:
           result = false # We still want this callback to be called.
       else: