summary refs log tree commit diff stats
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/pure/asyncdispatch.nim25
-rw-r--r--lib/pure/asyncfile.nim15
2 files changed, 29 insertions, 11 deletions
diff --git a/lib/pure/asyncdispatch.nim b/lib/pure/asyncdispatch.nim
index 4c96aa614..7211fabc7 100644
--- a/lib/pure/asyncdispatch.nim
+++ b/lib/pure/asyncdispatch.nim
@@ -219,7 +219,7 @@ when defined(windows) or defined(nimdoc):
 
     PCustomOverlapped* = ref CustomOverlapped
 
-    AsyncFD* = distinct int
+    AsyncFD* = distinct int ## An FD that is registered in the dispatcher.
 
     PostCallbackData = object
       ioPort: Handle
@@ -262,13 +262,22 @@ when defined(windows) or defined(nimdoc):
       setGlobalDispatcher(newDispatcher())
     result = gDisp
 
-  proc register*(fd: AsyncFD) =
+  proc register*(fd: cint | SocketHandle | AsyncFD): AsyncFD {.discardable.} =
     ## Registers ``fd`` with the dispatcher.
+    ##
+    ## By convention, an ``AsyncFD`` is said to be already registered in the
+    ## dispatcher. This procedure will raise an exception if ``fd`` has already
+    ## been registered, but only if the type of the ``fd`` isn't ``AsyncFD``.
     let p = getGlobalDispatcher()
+    when fd is AsyncFD:
+      if fd in p.handles:
+        return
+
     if createIoCompletionPort(fd.Handle, p.ioPort,
                               cast[CompletionKey](fd), 1) == 0:
       raiseOSError(osLastError())
     p.handles.incl(fd)
+    return fd.AsyncFD
 
   proc verifyPresence(fd: AsyncFD) =
     ## Ensures that file descriptor has been registered with the dispatcher.
@@ -753,6 +762,9 @@ when defined(windows) or defined(nimdoc):
     ## Unregisters ``fd``.
     getGlobalDispatcher().handles.excl(fd)
 
+  proc contains*(disp: PDispatcher, fd: AsyncFd | SocketHandle): bool =
+    return fd.SocketHandle in disp.handles
+
   {.push stackTrace:off.}
   proc waitableCallback(param: pointer,
                         timerOrWaitFired: WINBOOL): void {.stdcall.} =
@@ -1091,10 +1103,14 @@ else:
       setGlobalDispatcher(newDispatcher())
     result = gDisp
 
-  proc register*(fd: AsyncFD) =
+  proc register*(fd: cint | SocketHandle | AsyncFD): AsyncFD {.discardable.} =
     let p = getGlobalDispatcher()
+    when fd is AsyncFD:
+      if fd.SocketHandle in p.selector:
+        return
     var data = newAsyncData()
     p.selector.registerHandle(fd.SocketHandle, {}, data)
+    return fd.AsyncFD
 
   proc closeSocket*(sock: AsyncFD) =
     let disp = getGlobalDispatcher()
@@ -1106,6 +1122,9 @@ else:
 
   proc unregister*(ev: AsyncEvent) =
     getGlobalDispatcher().selector.unregister(SelectEvent(ev))
+  
+  proc contains*(disp: PDispatcher, fd: AsyncFd | SocketHandle): bool =
+    return fd.SocketHandle in disp.selector
 
   proc addRead*(fd: AsyncFD, cb: Callback) =
     let p = getGlobalDispatcher()
diff --git a/lib/pure/asyncfile.nim b/lib/pure/asyncfile.nim
index 9f4da16a3..6cd62efa4 100644
--- a/lib/pure/asyncfile.nim
+++ b/lib/pure/asyncfile.nim
@@ -81,11 +81,10 @@ proc getFileSize*(f: AsyncFile): int64 =
   else:
     result = lseek(f.fd.cint, 0, SEEK_END)
 
-proc newAsyncFile*(fd: AsyncFd): AsyncFile =
+proc newAsyncFile*(fd: cint | AsyncFd): AsyncFile =
   ## Creates `AsyncFile` with a previously opened file descriptor `fd`.
   new result
-  result.fd = fd
-  register(result.fd)
+  result.fd = register(result.fd)
 
 proc openAsync*(filename: string, mode = fmRead): AsyncFile =
   ## Opens a file specified by the path in ``filename`` using
@@ -97,16 +96,16 @@ proc openAsync*(filename: string, mode = fmRead): AsyncFile =
     when useWinUnicode:
       let fd = createFileW(newWideCString(filename), desiredAccess,
           FILE_SHARE_READ,
-          nil, creationDisposition, flags, 0).AsyncFd
+          nil, creationDisposition, flags, 0)
     else:
       let fd = createFileA(filename, desiredAccess,
           FILE_SHARE_READ,
-          nil, creationDisposition, flags, 0).AsyncFd
+          nil, creationDisposition, flags, 0)
 
-    if fd.Handle == INVALID_HANDLE_VALUE:
+    if fd == INVALID_HANDLE_VALUE:
       raiseOSError(osLastError())
 
-    result = newAsyncFile(fd)
+    result = newAsyncFile(fd.cint)
 
     if mode == fmAppend:
       result.offset = getFileSize(result)
@@ -115,7 +114,7 @@ proc openAsync*(filename: string, mode = fmRead): AsyncFile =
     let flags = getPosixFlags(mode)
     # RW (Owner), RW (Group), R (Other)
     let perm = S_IRUSR or S_IWUSR or S_IRGRP or S_IWGRP or S_IROTH
-    let fd = open(filename, flags, perm).AsyncFD
+    let fd = open(filename, flags, perm)
     if fd.cint == -1:
       raiseOSError(osLastError())