summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--lib/upcoming/asyncdispatch.nim96
1 files changed, 60 insertions, 36 deletions
diff --git a/lib/upcoming/asyncdispatch.nim b/lib/upcoming/asyncdispatch.nim
index 731ef52dc..e7dc4abcc 100644
--- a/lib/upcoming/asyncdispatch.nim
+++ b/lib/upcoming/asyncdispatch.nim
@@ -9,7 +9,7 @@
 
 include "system/inclrtl"
 
-import os, oids, tables, strutils, times, heapqueue
+import os, oids, tables, strutils, times, heapqueue, lists
 
 import nativesockets, net, queues
 
@@ -729,7 +729,7 @@ when defined(windows) or defined(nimdoc):
     var lpOutputBuf = newString(lpOutputLen)
     var dwBytesReceived: Dword
     let dwReceiveDataLength = 0.Dword # We don't want any data to be read.
-    let dwLocalAddressLength = Dword(sizeof (Sockaddr_in) + 16)
+    let dwLocalAddressLength = Dword(sizeof(Sockaddr_in) + 16)
     let dwRemoteAddressLength = Dword(sizeof(Sockaddr_in) + 16)
 
     template completeAccept() {.dirty.} =
@@ -1095,9 +1095,11 @@ else:
     AsyncFD* = distinct cint
     Callback = proc (fd: AsyncFD): bool {.closure,gcsafe.}
 
+    DoublyLinkedListRef = ref DoublyLinkedList[Callback]
+
     AsyncData = object
-      readCB: Callback
-      writeCB: Callback
+      readCBs: DoublyLinkedListRef
+      writeCBs: DoublyLinkedListRef
 
     AsyncEvent* = distinct SelectEvent
 
@@ -1121,7 +1123,10 @@ else:
 
   proc register*(fd: AsyncFD) =
     let p = getGlobalDispatcher()
-    var data = AsyncData()
+    var data = AsyncData(
+      readCBs: DoublyLinkedListRef(),
+      writeCBs: DoublyLinkedListRef()
+    )
     p.selector.registerHandle(fd.SocketHandle, {}, data)
 
   proc newAsyncNativeSocket*(domain: cint, sockType: cint,
@@ -1156,8 +1161,9 @@ else:
     let p = getGlobalDispatcher()
     var newEvents = {Event.Read}
     withData(p.selector, fd.SocketHandle, adata) do:
-      adata.readCB = cb
-      if adata.writeCB != nil:
+      adata.readCBs[].append(cb)
+      newEvents.incl(Event.Read)
+      if not isNil(adata.writeCBs.head):
         newEvents.incl(Event.Write)
     do:
       raise newException(ValueError, "File descriptor not registered.")
@@ -1167,8 +1173,9 @@ else:
     let p = getGlobalDispatcher()
     var newEvents = {Event.Write}
     withData(p.selector, fd.SocketHandle, adata) do:
-      adata.writeCB = cb
-      if adata.readCB != nil:
+      adata.writeCBs[].append(cb)
+      newEvents.incl(Event.Write)
+      if not isNil(adata.readCBs.head):
         newEvents.incl(Event.Read)
     do:
       raise newException(ValueError, "File descriptor not registered.")
@@ -1195,31 +1202,32 @@ else:
         let events = keys[i].events
 
         if Event.Read in events or events == {Event.Error}:
-          let cb = keys[i].data.readCB
-          if cb != nil:
-            if cb(fd.AsyncFD):
-              p.selector.withData(fd, adata) do:
-                if adata.readCB == cb:
-                  adata.readCB = nil
+          for node in keys[i].data.readCBs[].nodes():
+            let cb = node.value
+            if cb != nil:
+              if cb(fd.AsyncFD):
+                keys[i].data.readCBs[].remove(node)
+              else:
+                break
 
         if Event.Write in events or events == {Event.Error}:
-          let cb = keys[i].data.writeCB
-          if cb != nil:
-            if cb(fd.AsyncFD):
-              p.selector.withData(fd, adata) do:
-                if adata.writeCB == cb:
-                  adata.writeCB = nil
+          for node in keys[i].data.writeCBs[].nodes():
+            let cb = node.value
+            if cb != nil:
+              if cb(fd.AsyncFD):
+                keys[i].data.writeCBs[].remove(node)
+              else:
+                break
 
         when supportedPlatform:
           if (customSet * events) != {}:
-            let cb = keys[i].data.readCB
-            doAssert(cb != nil)
-            custom = true
-            if cb(fd.AsyncFD):
-              p.selector.withData(fd, adata) do:
-                if adata.readCB == cb:
-                  adata.readCB = nil
-                  p.selector.unregister(fd)
+            for node in keys[i].data.readCBs[].nodes():
+              let cb = node.value
+              doAssert(cb != nil)
+              custom = true
+              if cb(fd.AsyncFD):
+                keys[i].data.readCBs[].remove(node)
+                p.selector.unregister(fd)
 
         # because state `data` can be modified in callback we need to update
         # descriptor events with currently registered callbacks.
@@ -1227,8 +1235,8 @@ else:
           var update = false
           var newEvents: set[Event] = {}
           p.selector.withData(fd, adata) do:
-            if adata.readCB != nil: incl(newEvents, Event.Read)
-            if adata.writeCB != nil: incl(newEvents, Event.Write)
+            if not isNil(adata.readCBs.head): incl(newEvents, Event.Read)
+            if not isNil(adata.writeCBs.head): incl(newEvents, Event.Write)
             update = true
           if update:
             p.selector.updateHandle(fd, newEvents)
@@ -1491,21 +1499,33 @@ else:
       ## ``oneshot`` - if ``true`` only one event will be dispatched,
       ## if ``false`` continuous events every ``timeout`` milliseconds.
       let p = getGlobalDispatcher()
-      var data = AsyncData(readCB: cb)
+      var data = AsyncData(
+        readCBs: DoublyLinkedListRef(),
+        writeCBs: DoublyLinkedListRef()
+      )
+      data.readCBs[].append(cb)
       p.selector.registerTimer(timeout, oneshot, data)
 
     proc addSignal*(signal: int, cb: Callback) =
       ## Start watching signal ``signal``, and when signal appears, call the
       ## callback ``cb``.
       let p = getGlobalDispatcher()
-      var data = AsyncData(readCB: cb)
+      var data = AsyncData(
+        readCBs: DoublyLinkedListRef(),
+        writeCBs: DoublyLinkedListRef()
+      )
+      data.readCBs[].append(cb)
       p.selector.registerSignal(signal, data)
 
     proc addProcess*(pid: int, cb: Callback) =
       ## Start watching for process exit with pid ``pid``, and then call
       ## the callback ``cb``.
       let p = getGlobalDispatcher()
-      var data = AsyncData(readCB: cb)
+      var data = AsyncData(
+        readCBs: DoublyLinkedListRef(),
+        writeCBs: DoublyLinkedListRef()
+      )
+      data.readCBs[].append(cb)
       p.selector.registerProcess(pid, data)
 
   proc newAsyncEvent*(): AsyncEvent =
@@ -1524,7 +1544,11 @@ else:
     ## Start watching for event ``ev``, and call callback ``cb``, when
     ## ev will be set to signaled state.
     let p = getGlobalDispatcher()
-    var data = AsyncData(readCB: cb)
+    var data = AsyncData(
+      readCBs: DoublyLinkedListRef(),
+      writeCBs: DoublyLinkedListRef()
+    )
+    data.readCBs[].append(cb)
     p.selector.registerEvent(SelectEvent(ev), data)
 
 proc sleepAsync*(ms: int): Future[void] =
@@ -1591,7 +1615,7 @@ proc recvLine*(socket: AsyncFD): Future[string] {.async.} =
   ## **Note**: This procedure is mostly used for testing. You likely want to
   ## use ``asyncnet.recvLine`` instead.
 
-  template addNLIfEmpty(): stmt =
+  template addNLIfEmpty(): typed =
     if result.len == 0:
       result.add("\c\L")