summary refs log tree commit diff stats
path: root/lib/upcoming/asyncdispatch.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/upcoming/asyncdispatch.nim')
-rw-r--r--lib/upcoming/asyncdispatch.nim106
1 files changed, 65 insertions, 41 deletions
diff --git a/lib/upcoming/asyncdispatch.nim b/lib/upcoming/asyncdispatch.nim
index 731ef52dc..68ecbe81e 100644
--- a/lib/upcoming/asyncdispatch.nim
+++ b/lib/upcoming/asyncdispatch.nim
@@ -9,9 +9,9 @@
 
 include "system/inclrtl"
 
-import os, oids, tables, strutils, times, heapqueue
+import os, oids, tables, strutils, times, heapqueue, lists
 
-import nativesockets, net, queues
+import nativesockets, net, deques
 
 export Port, SocketFlag
 
@@ -135,7 +135,7 @@ include "../includes/asyncfutures"
 type
   PDispatcherBase = ref object of RootRef
     timers: HeapQueue[tuple[finishAt: float, fut: Future[void]]]
-    callbacks: Queue[proc ()]
+    callbacks: Deque[proc ()]
 
 proc processTimers(p: PDispatcherBase) {.inline.} =
   while p.timers.len > 0 and epochTime() >= p.timers[0].finishAt:
@@ -143,7 +143,7 @@ proc processTimers(p: PDispatcherBase) {.inline.} =
 
 proc processPendingCallbacks(p: PDispatcherBase) =
   while p.callbacks.len > 0:
-    var cb = p.callbacks.dequeue()
+    var cb = p.callbacks.popFirst()
     cb()
 
 proc adjustedTimeout(p: PDispatcherBase, timeout: int): int {.inline.} =
@@ -729,7 +729,7 @@ when defined(windows) or defined(nimdoc):
     var lpOutputBuf = newString(lpOutputLen)
     var dwBytesReceived: Dword
     let dwReceiveDataLength = 0.Dword # We don't want any data to be read.
-    let dwLocalAddressLength = Dword(sizeof (Sockaddr_in) + 16)
+    let dwLocalAddressLength = Dword(sizeof(Sockaddr_in) + 16)
     let dwRemoteAddressLength = Dword(sizeof(Sockaddr_in) + 16)
 
     template completeAccept() {.dirty.} =
@@ -1095,9 +1095,11 @@ else:
     AsyncFD* = distinct cint
     Callback = proc (fd: AsyncFD): bool {.closure,gcsafe.}
 
+    DoublyLinkedListRef = ref DoublyLinkedList[Callback]
+
     AsyncData = object
-      readCB: Callback
-      writeCB: Callback
+      readCBs: DoublyLinkedListRef
+      writeCBs: DoublyLinkedListRef
 
     AsyncEvent* = distinct SelectEvent
 
@@ -1112,7 +1114,7 @@ else:
     new result
     result.selector = newSelector[AsyncData]()
     result.timers.newHeapQueue()
-    result.callbacks = initQueue[proc ()](64)
+    result.callbacks = initDeque[proc ()](64)
 
   var gDisp{.threadvar.}: PDispatcher ## Global dispatcher
   proc getGlobalDispatcher*(): PDispatcher =
@@ -1121,7 +1123,10 @@ else:
 
   proc register*(fd: AsyncFD) =
     let p = getGlobalDispatcher()
-    var data = AsyncData()
+    var data = AsyncData(
+      readCBs: DoublyLinkedListRef(),
+      writeCBs: DoublyLinkedListRef()
+    )
     p.selector.registerHandle(fd.SocketHandle, {}, data)
 
   proc newAsyncNativeSocket*(domain: cint, sockType: cint,
@@ -1156,8 +1161,9 @@ else:
     let p = getGlobalDispatcher()
     var newEvents = {Event.Read}
     withData(p.selector, fd.SocketHandle, adata) do:
-      adata.readCB = cb
-      if adata.writeCB != nil:
+      adata.readCBs[].append(cb)
+      newEvents.incl(Event.Read)
+      if not isNil(adata.writeCBs.head):
         newEvents.incl(Event.Write)
     do:
       raise newException(ValueError, "File descriptor not registered.")
@@ -1167,8 +1173,9 @@ else:
     let p = getGlobalDispatcher()
     var newEvents = {Event.Write}
     withData(p.selector, fd.SocketHandle, adata) do:
-      adata.writeCB = cb
-      if adata.readCB != nil:
+      adata.writeCBs[].append(cb)
+      newEvents.incl(Event.Write)
+      if not isNil(adata.readCBs.head):
         newEvents.incl(Event.Read)
     do:
       raise newException(ValueError, "File descriptor not registered.")
@@ -1195,31 +1202,32 @@ else:
         let events = keys[i].events
 
         if Event.Read in events or events == {Event.Error}:
-          let cb = keys[i].data.readCB
-          if cb != nil:
-            if cb(fd.AsyncFD):
-              p.selector.withData(fd, adata) do:
-                if adata.readCB == cb:
-                  adata.readCB = nil
+          for node in keys[i].data.readCBs[].nodes():
+            let cb = node.value
+            if cb != nil:
+              if cb(fd.AsyncFD):
+                keys[i].data.readCBs[].remove(node)
+              else:
+                break
 
         if Event.Write in events or events == {Event.Error}:
-          let cb = keys[i].data.writeCB
-          if cb != nil:
-            if cb(fd.AsyncFD):
-              p.selector.withData(fd, adata) do:
-                if adata.writeCB == cb:
-                  adata.writeCB = nil
+          for node in keys[i].data.writeCBs[].nodes():
+            let cb = node.value
+            if cb != nil:
+              if cb(fd.AsyncFD):
+                keys[i].data.writeCBs[].remove(node)
+              else:
+                break
 
         when supportedPlatform:
           if (customSet * events) != {}:
-            let cb = keys[i].data.readCB
-            doAssert(cb != nil)
-            custom = true
-            if cb(fd.AsyncFD):
-              p.selector.withData(fd, adata) do:
-                if adata.readCB == cb:
-                  adata.readCB = nil
-                  p.selector.unregister(fd)
+            for node in keys[i].data.readCBs[].nodes():
+              let cb = node.value
+              doAssert(cb != nil)
+              custom = true
+              if cb(fd.AsyncFD):
+                keys[i].data.readCBs[].remove(node)
+                p.selector.unregister(fd)
 
         # because state `data` can be modified in callback we need to update
         # descriptor events with currently registered callbacks.
@@ -1227,8 +1235,8 @@ else:
           var update = false
           var newEvents: set[Event] = {}
           p.selector.withData(fd, adata) do:
-            if adata.readCB != nil: incl(newEvents, Event.Read)
-            if adata.writeCB != nil: incl(newEvents, Event.Write)
+            if not isNil(adata.readCBs.head): incl(newEvents, Event.Read)
+            if not isNil(adata.writeCBs.head): incl(newEvents, Event.Write)
             update = true
           if update:
             p.selector.updateHandle(fd, newEvents)
@@ -1491,21 +1499,33 @@ else:
       ## ``oneshot`` - if ``true`` only one event will be dispatched,
       ## if ``false`` continuous events every ``timeout`` milliseconds.
       let p = getGlobalDispatcher()
-      var data = AsyncData(readCB: cb)
+      var data = AsyncData(
+        readCBs: DoublyLinkedListRef(),
+        writeCBs: DoublyLinkedListRef()
+      )
+      data.readCBs[].append(cb)
       p.selector.registerTimer(timeout, oneshot, data)
 
     proc addSignal*(signal: int, cb: Callback) =
       ## Start watching signal ``signal``, and when signal appears, call the
       ## callback ``cb``.
       let p = getGlobalDispatcher()
-      var data = AsyncData(readCB: cb)
+      var data = AsyncData(
+        readCBs: DoublyLinkedListRef(),
+        writeCBs: DoublyLinkedListRef()
+      )
+      data.readCBs[].append(cb)
       p.selector.registerSignal(signal, data)
 
     proc addProcess*(pid: int, cb: Callback) =
       ## Start watching for process exit with pid ``pid``, and then call
       ## the callback ``cb``.
       let p = getGlobalDispatcher()
-      var data = AsyncData(readCB: cb)
+      var data = AsyncData(
+        readCBs: DoublyLinkedListRef(),
+        writeCBs: DoublyLinkedListRef()
+      )
+      data.readCBs[].append(cb)
       p.selector.registerProcess(pid, data)
 
   proc newAsyncEvent*(): AsyncEvent =
@@ -1524,7 +1544,11 @@ else:
     ## Start watching for event ``ev``, and call callback ``cb``, when
     ## ev will be set to signaled state.
     let p = getGlobalDispatcher()
-    var data = AsyncData(readCB: cb)
+    var data = AsyncData(
+      readCBs: DoublyLinkedListRef(),
+      writeCBs: DoublyLinkedListRef()
+    )
+    data.readCBs[].append(cb)
     p.selector.registerEvent(SelectEvent(ev), data)
 
 proc sleepAsync*(ms: int): Future[void] =
@@ -1591,7 +1615,7 @@ proc recvLine*(socket: AsyncFD): Future[string] {.async.} =
   ## **Note**: This procedure is mostly used for testing. You likely want to
   ## use ``asyncnet.recvLine`` instead.
 
-  template addNLIfEmpty(): stmt =
+  template addNLIfEmpty(): typed =
     if result.len == 0:
       result.add("\c\L")
 
@@ -1614,7 +1638,7 @@ proc recvLine*(socket: AsyncFD): Future[string] {.async.} =
 proc callSoon*(cbproc: proc ()) =
   ## Schedule `cbproc` to be called as soon as possible.
   ## The callback is called when control returns to the event loop.
-  getGlobalDispatcher().callbacks.enqueue(cbproc)
+  getGlobalDispatcher().callbacks.addLast(cbproc)
 
 proc runForever*() =
   ## Begins a never ending global dispatcher poll loop.