summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--changelog.md17
-rw-r--r--lib/pure/asyncdispatch.nim71
-rw-r--r--lib/pure/asyncnet.nim44
-rw-r--r--tests/stdlib/tfdleak.nim36
4 files changed, 130 insertions, 38 deletions
diff --git a/changelog.md b/changelog.md
index bb87f817c..88619b1c2 100644
--- a/changelog.md
+++ b/changelog.md
@@ -13,20 +13,21 @@
 - `deques.peekFirst` and `deques.peekLast` now have `var Deque[T] -> var T` overloads.
 - File handles created from high-level abstractions in the stdlib will no longer
   be inherited by child processes. In particular, these modules are affected:
-  `system`, `nativesockets`, `net` and `selectors`.
+  `asyncdispatch`, `asyncnet`, `system`, `nativesockets`, `net` and `selectors`.
 
-  For `net` and `nativesockets`, an `inheritable` flag has been added to all
-  `proc`s that create sockets, allowing the user to control whether the
-  resulting socket is inheritable. This flag is provided to ease the writing of
-  multi-process servers, where sockets inheritance is desired.
+  For `asyncdispatch`, `asyncnet`, `net` and `nativesockets`, an `inheritable`
+  flag has been added to all `proc`s that create sockets, allowing the user to
+  control whether the resulting socket is inheritable. This flag is provided to
+  ease the writing of multi-process servers, where sockets inheritance is
+  desired.
 
   For a transistion period, define `nimInheritHandles` to enable file handle
   inheritance by default. This flag does **not** affect the `selectors` module
   due to the differing semantics between operating systems.
 
-  `system.setInheritable` and `nativesockets.setInheritable` is also introduced
-  for setting file handle or socket inheritance. Not all platform have these
-  `proc`s defined.
+  `asyncdispatch.setInheritable`, `system.setInheritable` and
+  `nativesockets.setInheritable` is also introduced for setting file handle or
+  socket inheritance. Not all platform have these `proc`s defined.
 
 - The file descriptors created for internal bookkeeping by `ioselector_kqueue`
   and `ioselector_epoll` will no longer be leaked to child processes.
diff --git a/lib/pure/asyncdispatch.nim b/lib/pure/asyncdispatch.nim
index 91778a5bc..0800cb638 100644
--- a/lib/pure/asyncdispatch.nim
+++ b/lib/pure/asyncdispatch.nim
@@ -228,6 +228,16 @@ proc initCallSoonProc =
   if asyncfutures.getCallSoonProc().isNil:
     asyncfutures.setCallSoonProc(callSoon)
 
+template implementSetInheritable() {.dirty.} =
+  when declared(setInheritable):
+    proc setInheritable*(fd: AsyncFD, inheritable: bool): bool =
+      ## Control whether a file handle can be inherited by child processes.
+      ## Returns ``true`` on success.
+      ##
+      ## This procedure is not guaranteed to be available for all platforms.
+      ## Test for availability with `declared()`_.
+      fd.FileHandle.setInheritable(inheritable)
+
 when defined(windows) or defined(nimdoc):
   import winlean, sets, hashes
   type
@@ -695,7 +705,8 @@ when defined(windows) or defined(nimdoc):
           retFuture.complete(bytesReceived)
     return retFuture
 
-  proc acceptAddr*(socket: AsyncFD, flags = {SocketFlag.SafeDisconn}):
+  proc acceptAddr*(socket: AsyncFD, flags = {SocketFlag.SafeDisconn},
+                   inheritable = defined(nimInheritHandles)):
       owned(Future[tuple[address: string, client: AsyncFD]]) =
     ## Accepts a new connection. Returns a future containing the client socket
     ## corresponding to that connection and the remote address of the client.
@@ -704,6 +715,9 @@ when defined(windows) or defined(nimdoc):
     ## The resulting client socket is automatically registered to the
     ## dispatcher.
     ##
+    ## If ``inheritable`` is false (the default), the resulting client socket will
+    ## not be inheritable by child processes.
+    ##
     ## The ``accept`` call may result in an error if the connecting socket
     ## disconnects during the duration of the ``accept``. If the ``SafeDisconn``
     ## flag is specified then this error will not be raised and instead
@@ -711,7 +725,7 @@ when defined(windows) or defined(nimdoc):
     verifyPresence(socket)
     var retFuture = newFuture[tuple[address: string, client: AsyncFD]]("acceptAddr")
 
-    var clientSock = createNativeSocket()
+    var clientSock = createNativeSocket(inheritable = inheritable)
     if clientSock == osInvalidSocket: raiseOSError(osLastError())
 
     const lpOutputLen = 1024
@@ -788,6 +802,8 @@ when defined(windows) or defined(nimdoc):
 
     return retFuture
 
+  implementSetInheritable()
+
   proc closeSocket*(socket: AsyncFD) =
     ## Closes a socket and ensures that it is unregistered.
     socket.SocketHandle.close()
@@ -1090,6 +1106,9 @@ else:
   import selectors
   from posix import EINTR, EAGAIN, EINPROGRESS, EWOULDBLOCK, MSG_PEEK,
                     MSG_NOSIGNAL
+  when declared(posix.accept4):
+    from posix import accept4, SOCK_CLOEXEC
+
   const
     InitCallbackListSize = 4         # initial size of callbacks sequence,
                                      # associated with file/socket descriptor.
@@ -1263,6 +1282,8 @@ else:
       # descriptor was unregistered in callback via `unregister()`.
       discard
 
+  implementSetInheritable()
+
   proc closeSocket*(sock: AsyncFD) =
     let selector = getGlobalDispatcher().selector
     if sock.SocketHandle notin selector:
@@ -1484,7 +1505,8 @@ else:
     addRead(socket, cb)
     return retFuture
 
-  proc acceptAddr*(socket: AsyncFD, flags = {SocketFlag.SafeDisconn}):
+  proc acceptAddr*(socket: AsyncFD, flags = {SocketFlag.SafeDisconn},
+                   inheritable = defined(nimInheritHandles)):
       owned(Future[tuple[address: string, client: AsyncFD]]) =
     var retFuture = newFuture[tuple[address: string,
         client: AsyncFD]]("acceptAddr")
@@ -1492,8 +1514,21 @@ else:
       result = true
       var sockAddress: Sockaddr_storage
       var addrLen = sizeof(sockAddress).SockLen
-      var client = accept(sock.SocketHandle,
-                          cast[ptr SockAddr](addr(sockAddress)), addr(addrLen))
+      var client =
+        when declared(accept4):
+          accept4(sock.SocketHandle, cast[ptr SockAddr](addr(sockAddress)),
+                  addr(addrLen), if inheritable: 0 else: SOCK_CLOEXEC)
+        else:
+          accept(sock.SocketHandle, cast[ptr SockAddr](addr(sockAddress)),
+                 addr(addrLen))
+      when declared(setInheritable) and not declared(accept4):
+        if client != osInvalidSocket and not setInheritable(client, inheritable):
+          # Set failure first because close() itself can fail,
+          # altering osLastError().
+          retFuture.fail(newOSError(osLastError()))
+          close client
+          return false
+
       if client == osInvalidSocket:
         let lastError = osLastError()
         assert lastError.int32 != EWOULDBLOCK and lastError.int32 != EAGAIN
@@ -1578,8 +1613,9 @@ proc poll*(timeout = 500) =
   ## `epoll`:idx: or `kqueue`:idx: primitive only once.
   discard runOnce(timeout)
 
-template createAsyncNativeSocketImpl(domain, sockType, protocol) =
-  let handle = createNativeSocket(domain, sockType, protocol)
+template createAsyncNativeSocketImpl(domain, sockType, protocol: untyped,
+                                     inheritable = defined(nimInheritHandles)) =
+  let handle = createNativeSocket(domain, sockType, protocol, inheritable)
   if handle == osInvalidSocket:
     return osInvalidSocket.AsyncFD
   handle.setBlocking(false)
@@ -1589,13 +1625,15 @@ template createAsyncNativeSocketImpl(domain, sockType, protocol) =
   register(result)
 
 proc createAsyncNativeSocket*(domain: cint, sockType: cint,
-                           protocol: cint): AsyncFD =
-  createAsyncNativeSocketImpl(domain, sockType, protocol)
+                              protocol: cint,
+                              inheritable = defined(nimInheritHandles)): AsyncFD =
+  createAsyncNativeSocketImpl(domain, sockType, protocol, inheritable)
 
 proc createAsyncNativeSocket*(domain: Domain = Domain.AF_INET,
-                           sockType: SockType = SOCK_STREAM,
-                           protocol: Protocol = IPPROTO_TCP): AsyncFD =
-  createAsyncNativeSocketImpl(domain, sockType, protocol)
+                              sockType: SockType = SOCK_STREAM,
+                              protocol: Protocol = IPPROTO_TCP,
+                              inheritable = defined(nimInheritHandles)): AsyncFD =
+  createAsyncNativeSocketImpl(domain, sockType, protocol, inheritable)
 
 proc newAsyncNativeSocket*(domain: cint, sockType: cint,
     protocol: cint): AsyncFD {.deprecated: "use createAsyncNativeSocket instead".} =
@@ -1824,12 +1862,17 @@ proc withTimeout*[T](fut: Future[T], timeout: int): owned(Future[bool]) =
   return retFuture
 
 proc accept*(socket: AsyncFD,
-    flags = {SocketFlag.SafeDisconn}): owned(Future[AsyncFD]) =
+             flags = {SocketFlag.SafeDisconn},
+             inheritable = defined(nimInheritHandles)): owned(Future[AsyncFD]) =
   ## Accepts a new connection. Returns a future containing the client socket
   ## corresponding to that connection.
+  ##
+  ## If ``inheritable`` is false (the default), the resulting client socket
+  ## will not be inheritable by child processes.
+  ##
   ## The future will complete when the connection is successfully accepted.
   var retFut = newFuture[AsyncFD]("accept")
-  var fut = acceptAddr(socket, flags)
+  var fut = acceptAddr(socket, flags, inheritable)
   fut.callback =
     proc (future: Future[tuple[address: string, client: AsyncFD]]) =
       assert future.finished
diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim
index 93334d0e2..7ccb469cf 100644
--- a/lib/pure/asyncnet.nim
+++ b/lib/pure/asyncnet.nim
@@ -129,12 +129,17 @@ type
   AsyncSocket* = ref AsyncSocketDesc
 
 proc newAsyncSocket*(fd: AsyncFD, domain: Domain = AF_INET,
-    sockType: SockType = SOCK_STREAM,
-    protocol: Protocol = IPPROTO_TCP, buffered = true): owned(AsyncSocket) =
+                     sockType: SockType = SOCK_STREAM,
+                     protocol: Protocol = IPPROTO_TCP,
+                     buffered = true,
+                     inheritable = defined(nimInheritHandles)): owned(AsyncSocket) =
   ## Creates a new ``AsyncSocket`` based on the supplied params.
   ##
   ## The supplied ``fd``'s non-blocking state will be enabled implicitly.
   ##
+  ## If ``inheritable`` is false (the default), the supplied ``fd`` will not
+  ## be inheritable by child processes.
+  ##
   ## **Note**: This procedure will **NOT** register ``fd`` with the global
   ## async dispatcher. You need to do this manually. If you have used
   ## ``newAsyncNativeSocket`` to create ``fd`` then it's already registered.
@@ -142,6 +147,8 @@ proc newAsyncSocket*(fd: AsyncFD, domain: Domain = AF_INET,
   new(result)
   result.fd = fd.SocketHandle
   fd.SocketHandle.setBlocking(false)
+  if not fd.SocketHandle.setInheritable(inheritable):
+    raiseOSError(osLastError())
   result.isBuffered = buffered
   result.domain = domain
   result.sockType = sockType
@@ -150,15 +157,19 @@ proc newAsyncSocket*(fd: AsyncFD, domain: Domain = AF_INET,
     result.currPos = 0
 
 proc newAsyncSocket*(domain: Domain = AF_INET, sockType: SockType = SOCK_STREAM,
-    protocol: Protocol = IPPROTO_TCP, buffered = true): owned(AsyncSocket) =
+                     protocol: Protocol = IPPROTO_TCP, buffered = true,
+                     inheritable = defined(nimInheritHandles)): owned(AsyncSocket) =
   ## Creates a new asynchronous socket.
   ##
   ## This procedure will also create a brand new file descriptor for
   ## this socket.
-  let fd = createAsyncNativeSocket(domain, sockType, protocol)
+  ##
+  ## If ``inheritable`` is false (the default), the new file descriptor will not
+  ## be inheritable by child processes.
+  let fd = createAsyncNativeSocket(domain, sockType, protocol, inheritable)
   if fd.SocketHandle == osInvalidSocket:
     raiseOSError(osLastError())
-  result = newAsyncSocket(fd, domain, sockType, protocol, buffered)
+  result = newAsyncSocket(fd, domain, sockType, protocol, buffered, inheritable)
 
 proc getLocalAddr*(socket: AsyncSocket): (string, Port) =
   ## Get the socket's local address and port number.
@@ -173,16 +184,20 @@ proc getPeerAddr*(socket: AsyncSocket): (string, Port) =
   getPeerAddr(socket.fd, socket.domain)
 
 proc newAsyncSocket*(domain, sockType, protocol: cint,
-    buffered = true): owned(AsyncSocket) =
+                     buffered = true,
+                     inheritable = defined(nimInheritHandles)): owned(AsyncSocket) =
   ## Creates a new asynchronous socket.
   ##
   ## This procedure will also create a brand new file descriptor for
   ## this socket.
-  let fd = createAsyncNativeSocket(domain, sockType, protocol)
+  ##
+  ## If ``inheritable`` is false (the default), the new file descriptor will not
+  ## be inheritable by child processes.
+  let fd = createAsyncNativeSocket(domain, sockType, protocol, inheritable)
   if fd.SocketHandle == osInvalidSocket:
     raiseOSError(osLastError())
   result = newAsyncSocket(fd, Domain(domain), SockType(sockType),
-                          Protocol(protocol), buffered)
+                          Protocol(protocol), buffered, inheritable)
 
 when defineSsl:
   proc getSslError(handle: SslPtr, err: cint): cint =
@@ -443,13 +458,18 @@ proc send*(socket: AsyncSocket, data: string,
   else:
     await send(socket.fd.AsyncFD, data, flags)
 
-proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn}):
+proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn},
+                 inheritable = defined(nimInheritHandles)):
       owned(Future[tuple[address: string, client: AsyncSocket]]) =
   ## Accepts a new connection. Returns a future containing the client socket
   ## corresponding to that connection and the remote address of the client.
+  ##
+  ## If ``inheritable`` is false (the default), the resulting client socket will
+  ## not be inheritable by child processes.
+  ##
   ## The future will complete when the connection is successfully accepted.
   var retFuture = newFuture[tuple[address: string, client: AsyncSocket]]("asyncnet.acceptAddr")
-  var fut = acceptAddr(socket.fd.AsyncFD, flags)
+  var fut = acceptAddr(socket.fd.AsyncFD, flags, inheritable)
   fut.callback =
     proc (future: Future[tuple[address: string, client: AsyncFD]]) =
       assert future.finished
@@ -458,7 +478,7 @@ proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn}):
       else:
         let resultTup = (future.read.address,
                          newAsyncSocket(future.read.client, socket.domain,
-                         socket.sockType, socket.protocol, socket.isBuffered))
+                         socket.sockType, socket.protocol, socket.isBuffered, inheritable))
         retFuture.complete(resultTup)
   return retFuture
 
@@ -466,6 +486,8 @@ proc accept*(socket: AsyncSocket,
     flags = {SocketFlag.SafeDisconn}): owned(Future[AsyncSocket]) =
   ## Accepts a new connection. Returns a future containing the client socket
   ## corresponding to that connection.
+  ## If ``inheritable`` is false (the default), the resulting client socket will
+  ## not be inheritable by child processes.
   ## The future will complete when the connection is successfully accepted.
   var retFut = newFuture[AsyncSocket]("asyncnet.accept")
   var fut = acceptAddr(socket, flags)
diff --git a/tests/stdlib/tfdleak.nim b/tests/stdlib/tfdleak.nim
index 08ef06da3..c4f144db5 100644
--- a/tests/stdlib/tfdleak.nim
+++ b/tests/stdlib/tfdleak.nim
@@ -4,13 +4,15 @@ discard """
   matrix: "; -d:nimInheritHandles"
 """
 
-import os, osproc, strutils, nativesockets, net, selectors, memfiles
+import os, osproc, strutils, nativesockets, net, selectors, memfiles,
+       asyncdispatch, asyncnet
 when defined(windows):
   import winlean
 else:
   import posix
 
-proc leakCheck(f: int | FileHandle | SocketHandle, msg: string, expectLeak = defined(nimInheritHandles)) =
+proc leakCheck(f: AsyncFD | int | FileHandle | SocketHandle, msg: string,
+               expectLeak = defined(nimInheritHandles)) =
   discard startProcess(
     getAppFilename(),
     args = @[$f.int, msg, $expectLeak],
@@ -40,14 +42,14 @@ proc main() =
     let sock = createNativeSocket()
     defer: close sock
     leakCheck(sock, "createNativeSocket()")
-    if sock.setInheritable(true):
-      leakCheck(sock, "createNativeSocket()", true)
+    if sock.setInheritable(not defined(nimInheritHandles)):
+      leakCheck(sock, "createNativeSocket()", not defined(nimInheritHandles))
     else:
       raiseOSError osLastError()
 
     let server = newSocket()
     defer: close server
-    server.bindAddr()
+    server.bindAddr(address = "127.0.0.1")
     server.listen()
     let (_, port) = server.getLocalAddr
 
@@ -74,6 +76,30 @@ proc main() =
       leakCheck(mf.mapHandle, "memfiles.open().mapHandle", false)
     else:
       leakCheck(mf.handle, "memfiles.open().handle", false)
+
+    let sockAsync = createAsyncNativeSocket()
+    defer: closeSocket sockAsync
+    leakCheck(sockAsync, "createAsyncNativeSocket()")
+    if sockAsync.setInheritable(not defined(nimInheritHandles)):
+      leakCheck(sockAsync, "createAsyncNativeSocket()", not defined(nimInheritHandles))
+    else:
+      raiseOSError osLastError()
+
+    let serverAsync = newAsyncSocket()
+    defer: close serverAsync
+    serverAsync.bindAddr(address = "127.0.0.1")
+    serverAsync.listen()
+    let (_, portAsync) = serverAsync.getLocalAddr
+
+    leakCheck(serverAsync.getFd, "newAsyncSocket()")
+
+    let clientAsync = newAsyncSocket()
+    defer: close clientAsync
+    waitFor clientAsync.connect("127.0.0.1", portAsync)
+
+    let inputAsync = waitFor serverAsync.accept()
+
+    leakCheck(inputAsync.getFd, "accept() async")
   else:
     let
       fd = parseInt(paramStr 1)