summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--lib/pure/collections/sharedtables.nim6
-rw-r--r--lib/pure/selectors.nim92
2 files changed, 61 insertions, 37 deletions
diff --git a/lib/pure/collections/sharedtables.nim b/lib/pure/collections/sharedtables.nim
index 2abc314d7..20e1bb7a9 100644
--- a/lib/pure/collections/sharedtables.nim
+++ b/lib/pure/collections/sharedtables.nim
@@ -16,12 +16,12 @@ import
   hashes, math, locks
 
 type
-  KeyValuePair[A, B] = tuple[hcode: THash, key: A, val: B]
+  KeyValuePair[A, B] = tuple[hcode: Hash, key: A, val: B]
   KeyValuePairSeq[A, B] = ptr array[10_000_000, KeyValuePair[A, B]]
   SharedTable* [A, B] = object ## generic hash SharedTable
     data: KeyValuePairSeq[A, B]
     counter, dataLen: int
-    lock: TLock
+    lock: Lock
 
 template maxHash(t): expr = t.dataLen-1
 
@@ -49,7 +49,7 @@ proc mget*[A, B](t: var SharedTable[A, B], key: A): var B =
   ## retrieves the value at ``t[key]``. The value can be modified.
   ## If `key` is not in `t`, the ``KeyError`` exception is raised.
   withLock t:
-    var hc: THash
+    var hc: Hash
     var index = rawGet(t, key, hc)
     let hasKey = index >= 0
     if hasKey: result = t.data[index].val
diff --git a/lib/pure/selectors.nim b/lib/pure/selectors.nim
index 9802684fd..aa8ad39d1 100644
--- a/lib/pure/selectors.nim
+++ b/lib/pure/selectors.nim
@@ -9,7 +9,7 @@
 
 # TODO: Docs.
 
-import tables, os, unsigned, hashes
+import os, unsigned, hashes
 
 when defined(linux):
   import posix, epoll
@@ -18,6 +18,17 @@ elif defined(windows):
 else:
   import posix
 
+const MultiThreaded = defined(useStdlibThreading)
+
+when MultiThreaded:
+  import sharedtables
+
+  type SelectorData = pointer
+else:
+  import tables
+
+  type SelectorData = RootRef
+
 proc hash*(x: SocketHandle): Hash {.borrow.}
 proc `$`*(x: SocketHandle): string {.borrow.}
 
@@ -28,12 +39,10 @@ type
   SelectorKey* = object
     fd*: SocketHandle
     events*: set[Event] ## The events which ``fd`` listens for.
-    data*: pointer ## User object.
+    data*: SelectorData ## User object.
 
   ReadyInfo* = tuple[key: SelectorKey, events: set[Event]]
 
-
-
 when defined(nimdoc):
   type
     Selector* = ref object
@@ -41,7 +50,7 @@ when defined(nimdoc):
       ## status.
 
   proc register*(s: Selector, fd: SocketHandle, events: set[Event],
-                 data: RootRef): SelectorKey {.discardable.} =
+                 data: SelectorData): SelectorKey {.discardable.} =
     ## Registers file descriptor ``fd`` to selector ``s`` with a set of Event
     ## ``events``.
 
@@ -76,7 +85,10 @@ elif defined(linux):
     Selector* = object
       epollFD: cint
       events: array[64, epoll_event]
-      fds: SharedTable[SocketHandle, SelectorKey]
+      when MultiThreaded:
+        fds: SharedTable[SocketHandle, SelectorKey]
+      else:
+        fds: Table[SocketHandle, SelectorKey]
 
   proc createEventStruct(events: set[Event], fd: SocketHandle): epoll_event =
     if EvRead in events:
@@ -86,8 +98,8 @@ elif defined(linux):
     result.events = result.events or EPOLLRDHUP
     result.data.fd = fd.cint
 
-  proc register*(s: Selector, fd: SocketHandle, events: set[Event],
-      data: pointer) =
+  proc register*(s: var Selector, fd: SocketHandle, events: set[Event],
+      data: SelectorData) =
     var event = createEventStruct(events, fd)
     if events != {}:
       if epoll_ctl(s.epollFD, EPOLL_CTL_ADD, fd, addr(event)) != 0:
@@ -129,7 +141,7 @@ elif defined(linux):
     s.fds.del(fd)
 
   proc close*(s: var Selector) =
-    deinitSharedTable(s.fds)
+    when MultiThreaded: deinitSharedTable(s.fds)
     if s.epollFD.close() != 0: raiseOSError(osLastError())
 
   proc epollHasFd(s: Selector, fd: SocketHandle): bool =
@@ -163,7 +175,7 @@ elif defined(linux):
       if (s.events[i].events and EPOLLIN) != 0: evSet = evSet + {EvRead}
       if (s.events[i].events and EPOLLOUT) != 0: evSet = evSet + {EvWrite}
       let selectorKey = s.fds[fd]
-      assert selectorKey != nil
+      assert selectorKey.fd != 0.SocketHandle
       result.add((selectorKey, evSet))
 
       #echo("Epoll: ", result[i].key.fd, " ", result[i].events, " ", result[i].key.events)
@@ -172,7 +184,10 @@ elif defined(linux):
     result.epollFD = epoll_create(64)
     if result.epollFD < 0:
       raiseOSError(osLastError())
-    result.fds = initSharedTable[SocketHandle, SelectorKey]()
+    when MultiThreaded:
+      result.fds = initSharedTable[SocketHandle, SelectorKey]()
+    else:
+      result.fds = initTable[SocketHandle, SelectorKey]()
 
   proc contains*(s: Selector, fd: SocketHandle): bool =
     ## Determines whether selector contains a file descriptor.
@@ -191,14 +206,15 @@ elif defined(linux):
 
 elif not defined(nimdoc):
   # TODO: kqueue for bsd/mac os x.
-  import sharedtables
-
   type
     Selector* = object
-      fds: SharedTable[SocketHandle, SelectorKey]
+      when MultiThreaded:
+        fds: SharedTable[SocketHandle, SelectorKey]
+      else:
+        fds: Table[SocketHandle, SelectorKey]
 
   proc register*(s: var Selector, fd: SocketHandle, events: set[Event],
-                 data: pointer) =
+                 data: SelectorData) =
     let result = SelectorKey(fd: fd, events: events, data: data)
     if s.fds.hasKeyOrPut(fd, result):
       raise newException(ValueError, "File descriptor already exists.")
@@ -211,7 +227,8 @@ elif not defined(nimdoc):
   proc unregister*(s: var Selector, fd: SocketHandle) =
     s.fds.del(fd)
 
-  proc close*(s: var Selector) = deinitSharedTable(s.fds)
+  proc close*(s: var Selector) =
+    when MultiThreaded: deinitSharedTable(s.fds)
 
   proc timeValFromMilliseconds(timeout: int): TimeVal =
     if timeout != -1:
@@ -219,10 +236,9 @@ elif not defined(nimdoc):
       result.tv_sec = seconds.int32
       result.tv_usec = ((timeout - seconds * 1000) * 1000).int32
 
-  proc createFdSet(rd, wr: var TFdSet, fds: SharedTable[SocketHandle, SelectorKey],
-      m: var int) =
+  proc createFdSet(rd, wr: var TFdSet, s: Selector, m: var int) =
     FD_ZERO(rd); FD_ZERO(wr)
-    for k, v in pairs(fds):
+    for k, v in pairs(s.fds):
       if EvRead in v.events:
         m = max(m, int(k))
         FD_SET(k, rd)
@@ -231,9 +247,9 @@ elif not defined(nimdoc):
         FD_SET(k, wr)
 
   proc getReadyFDs(rd, wr: var TFdSet,
-                   fds: SharedTable[SocketHandle, SelectorKey]): seq[ReadyInfo] =
+                   s: var Selector): seq[ReadyInfo] =
     result = @[]
-    for k, v in pairs(fds):
+    for k, v in pairs(s.fds):
       var events: set[Event] = {}
       if FD_ISSET(k, rd) != 0'i32:
         events = events + {EvRead}
@@ -241,13 +257,12 @@ elif not defined(nimdoc):
         events = events + {EvWrite}
       result.add((v, events))
 
-  proc select(fds: var SharedTable[SocketHandle, SelectorKey],
-              timeout = 500): seq[ReadyInfo] =
+  proc select*(s: var Selector, timeout: int): seq[ReadyInfo] =
     var tv {.noInit.}: TimeVal = timeValFromMilliseconds(timeout)
 
     var rd, wr: TFdSet
     var m = 0
-    createFdSet(rd, wr, fds, m)
+    createFdSet(rd, wr, s, m)
 
     var retCode = 0
     if timeout != -1:
@@ -260,13 +275,13 @@ elif not defined(nimdoc):
     elif retCode == 0:
       return @[]
     else:
-      return getReadyFDs(rd, wr, fds)
-
-  proc select*(s: Selector, timeout: int): seq[ReadyInfo] =
-    result = select(s.fds, timeout)
+      return getReadyFDs(rd, wr, s)
 
   proc newSelector*(): Selector =
-    result.fds = initSharedTable[SocketHandle, SelectorKey]()
+    when MultiThreaded:
+      result.fds = initSharedTable[SocketHandle, SelectorKey]()
+    else:
+      result.fds = initTable[SocketHandle, SelectorKey]()
 
   proc contains*(s: Selector, fd: SocketHandle): bool =
     return s.fds.hasKey(fd)
@@ -289,9 +304,15 @@ proc contains*(s: Selector, key: SelectorKey): bool =
 when not defined(testing) and isMainModule and not defined(nimdoc):
   # Select()
   import sockets
-  type
-    SockWrapper = object
-      sock: Socket
+
+  when MultiThreaded:
+    type
+      SockWrapper = object
+        sock: Socket
+  else:
+    type
+      SockWrapper = ref object of RootObj
+        sock: Socket
 
   var sock = socket()
   if sock == sockets.invalidSocket: raiseOSError(osLastError())
@@ -300,7 +321,10 @@ when not defined(testing) and isMainModule and not defined(nimdoc):
 
   var selector = newSelector()
   var data = SockWrapper(sock: sock)
-  let key = selector.register(sock.getFD, {EvWrite}, addr data)
+  when MultiThreaded:
+    selector.register(sock.getFD, {EvWrite}, addr data)
+  else:
+    selector.register(sock.getFD, {EvWrite}, data)
   var i = 0
   while true:
     let ready = selector.select(1000)
@@ -308,6 +332,6 @@ when not defined(testing) and isMainModule and not defined(nimdoc):
     if ready.len > 0: echo ready[0].events
     i.inc
     if i == 6:
-      assert selector.unregister(sock.getFD).fd == sock.getFD
+      selector.unregister(sock.getFD)
       selector.close()
       break