summary refs log tree commit diff stats
path: root/lib
diff options
context:
space:
mode:
authoralaviss <leorize+oss@disroot.org>2020-05-20 07:42:55 +0000
committerGitHub <noreply@github.com>2020-05-20 09:42:55 +0200
commit4ae341353de5c58dc339e47b0eec2bbb4649dc10 (patch)
tree70933c5a55b47bbce8a5044ff034536d19aa4aa5 /lib
parent1450924b1e68ad3cd2dc8db2c54f9741315ca212 (diff)
downloadNim-4ae341353de5c58dc339e47b0eec2bbb4649dc10.tar.gz
asyncdispatch, asyncnet: add inheritance control (#14362)
* asyncdispatch, asyncnet: add inheritance control

* asyncnet, asyncdispatch: cleanup
Diffstat (limited to 'lib')
-rw-r--r--lib/pure/asyncdispatch.nim71
-rw-r--r--lib/pure/asyncnet.nim44
2 files changed, 90 insertions, 25 deletions
diff --git a/lib/pure/asyncdispatch.nim b/lib/pure/asyncdispatch.nim
index 91778a5bc..0800cb638 100644
--- a/lib/pure/asyncdispatch.nim
+++ b/lib/pure/asyncdispatch.nim
@@ -228,6 +228,16 @@ proc initCallSoonProc =
   if asyncfutures.getCallSoonProc().isNil:
     asyncfutures.setCallSoonProc(callSoon)
 
+template implementSetInheritable() {.dirty.} =
+  when declared(setInheritable):
+    proc setInheritable*(fd: AsyncFD, inheritable: bool): bool =
+      ## Control whether a file handle can be inherited by child processes.
+      ## Returns ``true`` on success.
+      ##
+      ## This procedure is not guaranteed to be available for all platforms.
+      ## Test for availability with `declared()`_.
+      fd.FileHandle.setInheritable(inheritable)
+
 when defined(windows) or defined(nimdoc):
   import winlean, sets, hashes
   type
@@ -695,7 +705,8 @@ when defined(windows) or defined(nimdoc):
           retFuture.complete(bytesReceived)
     return retFuture
 
-  proc acceptAddr*(socket: AsyncFD, flags = {SocketFlag.SafeDisconn}):
+  proc acceptAddr*(socket: AsyncFD, flags = {SocketFlag.SafeDisconn},
+                   inheritable = defined(nimInheritHandles)):
       owned(Future[tuple[address: string, client: AsyncFD]]) =
     ## Accepts a new connection. Returns a future containing the client socket
     ## corresponding to that connection and the remote address of the client.
@@ -704,6 +715,9 @@ when defined(windows) or defined(nimdoc):
     ## The resulting client socket is automatically registered to the
     ## dispatcher.
     ##
+    ## If ``inheritable`` is false (the default), the resulting client socket will
+    ## not be inheritable by child processes.
+    ##
     ## The ``accept`` call may result in an error if the connecting socket
     ## disconnects during the duration of the ``accept``. If the ``SafeDisconn``
     ## flag is specified then this error will not be raised and instead
@@ -711,7 +725,7 @@ when defined(windows) or defined(nimdoc):
     verifyPresence(socket)
     var retFuture = newFuture[tuple[address: string, client: AsyncFD]]("acceptAddr")
 
-    var clientSock = createNativeSocket()
+    var clientSock = createNativeSocket(inheritable = inheritable)
     if clientSock == osInvalidSocket: raiseOSError(osLastError())
 
     const lpOutputLen = 1024
@@ -788,6 +802,8 @@ when defined(windows) or defined(nimdoc):
 
     return retFuture
 
+  implementSetInheritable()
+
   proc closeSocket*(socket: AsyncFD) =
     ## Closes a socket and ensures that it is unregistered.
     socket.SocketHandle.close()
@@ -1090,6 +1106,9 @@ else:
   import selectors
   from posix import EINTR, EAGAIN, EINPROGRESS, EWOULDBLOCK, MSG_PEEK,
                     MSG_NOSIGNAL
+  when declared(posix.accept4):
+    from posix import accept4, SOCK_CLOEXEC
+
   const
     InitCallbackListSize = 4         # initial size of callbacks sequence,
                                      # associated with file/socket descriptor.
@@ -1263,6 +1282,8 @@ else:
       # descriptor was unregistered in callback via `unregister()`.
       discard
 
+  implementSetInheritable()
+
   proc closeSocket*(sock: AsyncFD) =
     let selector = getGlobalDispatcher().selector
     if sock.SocketHandle notin selector:
@@ -1484,7 +1505,8 @@ else:
     addRead(socket, cb)
     return retFuture
 
-  proc acceptAddr*(socket: AsyncFD, flags = {SocketFlag.SafeDisconn}):
+  proc acceptAddr*(socket: AsyncFD, flags = {SocketFlag.SafeDisconn},
+                   inheritable = defined(nimInheritHandles)):
       owned(Future[tuple[address: string, client: AsyncFD]]) =
     var retFuture = newFuture[tuple[address: string,
         client: AsyncFD]]("acceptAddr")
@@ -1492,8 +1514,21 @@ else:
       result = true
       var sockAddress: Sockaddr_storage
       var addrLen = sizeof(sockAddress).SockLen
-      var client = accept(sock.SocketHandle,
-                          cast[ptr SockAddr](addr(sockAddress)), addr(addrLen))
+      var client =
+        when declared(accept4):
+          accept4(sock.SocketHandle, cast[ptr SockAddr](addr(sockAddress)),
+                  addr(addrLen), if inheritable: 0 else: SOCK_CLOEXEC)
+        else:
+          accept(sock.SocketHandle, cast[ptr SockAddr](addr(sockAddress)),
+                 addr(addrLen))
+      when declared(setInheritable) and not declared(accept4):
+        if client != osInvalidSocket and not setInheritable(client, inheritable):
+          # Set failure first because close() itself can fail,
+          # altering osLastError().
+          retFuture.fail(newOSError(osLastError()))
+          close client
+          return false
+
       if client == osInvalidSocket:
         let lastError = osLastError()
         assert lastError.int32 != EWOULDBLOCK and lastError.int32 != EAGAIN
@@ -1578,8 +1613,9 @@ proc poll*(timeout = 500) =
   ## `epoll`:idx: or `kqueue`:idx: primitive only once.
   discard runOnce(timeout)
 
-template createAsyncNativeSocketImpl(domain, sockType, protocol) =
-  let handle = createNativeSocket(domain, sockType, protocol)
+template createAsyncNativeSocketImpl(domain, sockType, protocol: untyped,
+                                     inheritable = defined(nimInheritHandles)) =
+  let handle = createNativeSocket(domain, sockType, protocol, inheritable)
   if handle == osInvalidSocket:
     return osInvalidSocket.AsyncFD
   handle.setBlocking(false)
@@ -1589,13 +1625,15 @@ template createAsyncNativeSocketImpl(domain, sockType, protocol) =
   register(result)
 
 proc createAsyncNativeSocket*(domain: cint, sockType: cint,
-                           protocol: cint): AsyncFD =
-  createAsyncNativeSocketImpl(domain, sockType, protocol)
+                              protocol: cint,
+                              inheritable = defined(nimInheritHandles)): AsyncFD =
+  createAsyncNativeSocketImpl(domain, sockType, protocol, inheritable)
 
 proc createAsyncNativeSocket*(domain: Domain = Domain.AF_INET,
-                           sockType: SockType = SOCK_STREAM,
-                           protocol: Protocol = IPPROTO_TCP): AsyncFD =
-  createAsyncNativeSocketImpl(domain, sockType, protocol)
+                              sockType: SockType = SOCK_STREAM,
+                              protocol: Protocol = IPPROTO_TCP,
+                              inheritable = defined(nimInheritHandles)): AsyncFD =
+  createAsyncNativeSocketImpl(domain, sockType, protocol, inheritable)
 
 proc newAsyncNativeSocket*(domain: cint, sockType: cint,
     protocol: cint): AsyncFD {.deprecated: "use createAsyncNativeSocket instead".} =
@@ -1824,12 +1862,17 @@ proc withTimeout*[T](fut: Future[T], timeout: int): owned(Future[bool]) =
   return retFuture
 
 proc accept*(socket: AsyncFD,
-    flags = {SocketFlag.SafeDisconn}): owned(Future[AsyncFD]) =
+             flags = {SocketFlag.SafeDisconn},
+             inheritable = defined(nimInheritHandles)): owned(Future[AsyncFD]) =
   ## Accepts a new connection. Returns a future containing the client socket
   ## corresponding to that connection.
+  ##
+  ## If ``inheritable`` is false (the default), the resulting client socket
+  ## will not be inheritable by child processes.
+  ##
   ## The future will complete when the connection is successfully accepted.
   var retFut = newFuture[AsyncFD]("accept")
-  var fut = acceptAddr(socket, flags)
+  var fut = acceptAddr(socket, flags, inheritable)
   fut.callback =
     proc (future: Future[tuple[address: string, client: AsyncFD]]) =
       assert future.finished
diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim
index 93334d0e2..7ccb469cf 100644
--- a/lib/pure/asyncnet.nim
+++ b/lib/pure/asyncnet.nim
@@ -129,12 +129,17 @@ type
   AsyncSocket* = ref AsyncSocketDesc
 
 proc newAsyncSocket*(fd: AsyncFD, domain: Domain = AF_INET,
-    sockType: SockType = SOCK_STREAM,
-    protocol: Protocol = IPPROTO_TCP, buffered = true): owned(AsyncSocket) =
+                     sockType: SockType = SOCK_STREAM,
+                     protocol: Protocol = IPPROTO_TCP,
+                     buffered = true,
+                     inheritable = defined(nimInheritHandles)): owned(AsyncSocket) =
   ## Creates a new ``AsyncSocket`` based on the supplied params.
   ##
   ## The supplied ``fd``'s non-blocking state will be enabled implicitly.
   ##
+  ## If ``inheritable`` is false (the default), the supplied ``fd`` will not
+  ## be inheritable by child processes.
+  ##
   ## **Note**: This procedure will **NOT** register ``fd`` with the global
   ## async dispatcher. You need to do this manually. If you have used
   ## ``newAsyncNativeSocket`` to create ``fd`` then it's already registered.
@@ -142,6 +147,8 @@ proc newAsyncSocket*(fd: AsyncFD, domain: Domain = AF_INET,
   new(result)
   result.fd = fd.SocketHandle
   fd.SocketHandle.setBlocking(false)
+  if not fd.SocketHandle.setInheritable(inheritable):
+    raiseOSError(osLastError())
   result.isBuffered = buffered
   result.domain = domain
   result.sockType = sockType
@@ -150,15 +157,19 @@ proc newAsyncSocket*(fd: AsyncFD, domain: Domain = AF_INET,
     result.currPos = 0
 
 proc newAsyncSocket*(domain: Domain = AF_INET, sockType: SockType = SOCK_STREAM,
-    protocol: Protocol = IPPROTO_TCP, buffered = true): owned(AsyncSocket) =
+                     protocol: Protocol = IPPROTO_TCP, buffered = true,
+                     inheritable = defined(nimInheritHandles)): owned(AsyncSocket) =
   ## Creates a new asynchronous socket.
   ##
   ## This procedure will also create a brand new file descriptor for
   ## this socket.
-  let fd = createAsyncNativeSocket(domain, sockType, protocol)
+  ##
+  ## If ``inheritable`` is false (the default), the new file descriptor will not
+  ## be inheritable by child processes.
+  let fd = createAsyncNativeSocket(domain, sockType, protocol, inheritable)
   if fd.SocketHandle == osInvalidSocket:
     raiseOSError(osLastError())
-  result = newAsyncSocket(fd, domain, sockType, protocol, buffered)
+  result = newAsyncSocket(fd, domain, sockType, protocol, buffered, inheritable)
 
 proc getLocalAddr*(socket: AsyncSocket): (string, Port) =
   ## Get the socket's local address and port number.
@@ -173,16 +184,20 @@ proc getPeerAddr*(socket: AsyncSocket): (string, Port) =
   getPeerAddr(socket.fd, socket.domain)
 
 proc newAsyncSocket*(domain, sockType, protocol: cint,
-    buffered = true): owned(AsyncSocket) =
+                     buffered = true,
+                     inheritable = defined(nimInheritHandles)): owned(AsyncSocket) =
   ## Creates a new asynchronous socket.
   ##
   ## This procedure will also create a brand new file descriptor for
   ## this socket.
-  let fd = createAsyncNativeSocket(domain, sockType, protocol)
+  ##
+  ## If ``inheritable`` is false (the default), the new file descriptor will not
+  ## be inheritable by child processes.
+  let fd = createAsyncNativeSocket(domain, sockType, protocol, inheritable)
   if fd.SocketHandle == osInvalidSocket:
     raiseOSError(osLastError())
   result = newAsyncSocket(fd, Domain(domain), SockType(sockType),
-                          Protocol(protocol), buffered)
+                          Protocol(protocol), buffered, inheritable)
 
 when defineSsl:
   proc getSslError(handle: SslPtr, err: cint): cint =
@@ -443,13 +458,18 @@ proc send*(socket: AsyncSocket, data: string,
   else:
     await send(socket.fd.AsyncFD, data, flags)
 
-proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn}):
+proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn},
+                 inheritable = defined(nimInheritHandles)):
       owned(Future[tuple[address: string, client: AsyncSocket]]) =
   ## Accepts a new connection. Returns a future containing the client socket
   ## corresponding to that connection and the remote address of the client.
+  ##
+  ## If ``inheritable`` is false (the default), the resulting client socket will
+  ## not be inheritable by child processes.
+  ##
   ## The future will complete when the connection is successfully accepted.
   var retFuture = newFuture[tuple[address: string, client: AsyncSocket]]("asyncnet.acceptAddr")
-  var fut = acceptAddr(socket.fd.AsyncFD, flags)
+  var fut = acceptAddr(socket.fd.AsyncFD, flags, inheritable)
   fut.callback =
     proc (future: Future[tuple[address: string, client: AsyncFD]]) =
       assert future.finished
@@ -458,7 +478,7 @@ proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn}):
       else:
         let resultTup = (future.read.address,
                          newAsyncSocket(future.read.client, socket.domain,
-                         socket.sockType, socket.protocol, socket.isBuffered))
+                         socket.sockType, socket.protocol, socket.isBuffered, inheritable))
         retFuture.complete(resultTup)
   return retFuture
 
@@ -466,6 +486,8 @@ proc accept*(socket: AsyncSocket,
     flags = {SocketFlag.SafeDisconn}): owned(Future[AsyncSocket]) =
   ## Accepts a new connection. Returns a future containing the client socket
   ## corresponding to that connection.
+  ## If ``inheritable`` is false (the default), the resulting client socket will
+  ## not be inheritable by child processes.
   ## The future will complete when the connection is successfully accepted.
   var retFut = newFuture[AsyncSocket]("asyncnet.accept")
   var fut = acceptAddr(socket, flags)