summary refs log tree commit diff stats
path: root/lib/pure/asyncnet.nim
diff options
context:
space:
mode:
authoralaviss <leorize+oss@disroot.org>2020-05-20 07:42:55 +0000
committerGitHub <noreply@github.com>2020-05-20 09:42:55 +0200
commit4ae341353de5c58dc339e47b0eec2bbb4649dc10 (patch)
tree70933c5a55b47bbce8a5044ff034536d19aa4aa5 /lib/pure/asyncnet.nim
parent1450924b1e68ad3cd2dc8db2c54f9741315ca212 (diff)
downloadNim-4ae341353de5c58dc339e47b0eec2bbb4649dc10.tar.gz
asyncdispatch, asyncnet: add inheritance control (#14362)
* asyncdispatch, asyncnet: add inheritance control

* asyncnet, asyncdispatch: cleanup
Diffstat (limited to 'lib/pure/asyncnet.nim')
-rw-r--r--lib/pure/asyncnet.nim44
1 files changed, 33 insertions, 11 deletions
diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim
index 93334d0e2..7ccb469cf 100644
--- a/lib/pure/asyncnet.nim
+++ b/lib/pure/asyncnet.nim
@@ -129,12 +129,17 @@ type
   AsyncSocket* = ref AsyncSocketDesc
 
 proc newAsyncSocket*(fd: AsyncFD, domain: Domain = AF_INET,
-    sockType: SockType = SOCK_STREAM,
-    protocol: Protocol = IPPROTO_TCP, buffered = true): owned(AsyncSocket) =
+                     sockType: SockType = SOCK_STREAM,
+                     protocol: Protocol = IPPROTO_TCP,
+                     buffered = true,
+                     inheritable = defined(nimInheritHandles)): owned(AsyncSocket) =
   ## Creates a new ``AsyncSocket`` based on the supplied params.
   ##
   ## The supplied ``fd``'s non-blocking state will be enabled implicitly.
   ##
+  ## If ``inheritable`` is false (the default), the supplied ``fd`` will not
+  ## be inheritable by child processes.
+  ##
   ## **Note**: This procedure will **NOT** register ``fd`` with the global
   ## async dispatcher. You need to do this manually. If you have used
   ## ``newAsyncNativeSocket`` to create ``fd`` then it's already registered.
@@ -142,6 +147,8 @@ proc newAsyncSocket*(fd: AsyncFD, domain: Domain = AF_INET,
   new(result)
   result.fd = fd.SocketHandle
   fd.SocketHandle.setBlocking(false)
+  if not fd.SocketHandle.setInheritable(inheritable):
+    raiseOSError(osLastError())
   result.isBuffered = buffered
   result.domain = domain
   result.sockType = sockType
@@ -150,15 +157,19 @@ proc newAsyncSocket*(fd: AsyncFD, domain: Domain = AF_INET,
     result.currPos = 0
 
 proc newAsyncSocket*(domain: Domain = AF_INET, sockType: SockType = SOCK_STREAM,
-    protocol: Protocol = IPPROTO_TCP, buffered = true): owned(AsyncSocket) =
+                     protocol: Protocol = IPPROTO_TCP, buffered = true,
+                     inheritable = defined(nimInheritHandles)): owned(AsyncSocket) =
   ## Creates a new asynchronous socket.
   ##
   ## This procedure will also create a brand new file descriptor for
   ## this socket.
-  let fd = createAsyncNativeSocket(domain, sockType, protocol)
+  ##
+  ## If ``inheritable`` is false (the default), the new file descriptor will not
+  ## be inheritable by child processes.
+  let fd = createAsyncNativeSocket(domain, sockType, protocol, inheritable)
   if fd.SocketHandle == osInvalidSocket:
     raiseOSError(osLastError())
-  result = newAsyncSocket(fd, domain, sockType, protocol, buffered)
+  result = newAsyncSocket(fd, domain, sockType, protocol, buffered, inheritable)
 
 proc getLocalAddr*(socket: AsyncSocket): (string, Port) =
   ## Get the socket's local address and port number.
@@ -173,16 +184,20 @@ proc getPeerAddr*(socket: AsyncSocket): (string, Port) =
   getPeerAddr(socket.fd, socket.domain)
 
 proc newAsyncSocket*(domain, sockType, protocol: cint,
-    buffered = true): owned(AsyncSocket) =
+                     buffered = true,
+                     inheritable = defined(nimInheritHandles)): owned(AsyncSocket) =
   ## Creates a new asynchronous socket.
   ##
   ## This procedure will also create a brand new file descriptor for
   ## this socket.
-  let fd = createAsyncNativeSocket(domain, sockType, protocol)
+  ##
+  ## If ``inheritable`` is false (the default), the new file descriptor will not
+  ## be inheritable by child processes.
+  let fd = createAsyncNativeSocket(domain, sockType, protocol, inheritable)
   if fd.SocketHandle == osInvalidSocket:
     raiseOSError(osLastError())
   result = newAsyncSocket(fd, Domain(domain), SockType(sockType),
-                          Protocol(protocol), buffered)
+                          Protocol(protocol), buffered, inheritable)
 
 when defineSsl:
   proc getSslError(handle: SslPtr, err: cint): cint =
@@ -443,13 +458,18 @@ proc send*(socket: AsyncSocket, data: string,
   else:
     await send(socket.fd.AsyncFD, data, flags)
 
-proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn}):
+proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn},
+                 inheritable = defined(nimInheritHandles)):
       owned(Future[tuple[address: string, client: AsyncSocket]]) =
   ## Accepts a new connection. Returns a future containing the client socket
   ## corresponding to that connection and the remote address of the client.
+  ##
+  ## If ``inheritable`` is false (the default), the resulting client socket will
+  ## not be inheritable by child processes.
+  ##
   ## The future will complete when the connection is successfully accepted.
   var retFuture = newFuture[tuple[address: string, client: AsyncSocket]]("asyncnet.acceptAddr")
-  var fut = acceptAddr(socket.fd.AsyncFD, flags)
+  var fut = acceptAddr(socket.fd.AsyncFD, flags, inheritable)
   fut.callback =
     proc (future: Future[tuple[address: string, client: AsyncFD]]) =
       assert future.finished
@@ -458,7 +478,7 @@ proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn}):
       else:
         let resultTup = (future.read.address,
                          newAsyncSocket(future.read.client, socket.domain,
-                         socket.sockType, socket.protocol, socket.isBuffered))
+                         socket.sockType, socket.protocol, socket.isBuffered, inheritable))
         retFuture.complete(resultTup)
   return retFuture
 
@@ -466,6 +486,8 @@ proc accept*(socket: AsyncSocket,
     flags = {SocketFlag.SafeDisconn}): owned(Future[AsyncSocket]) =
   ## Accepts a new connection. Returns a future containing the client socket
   ## corresponding to that connection.
+  ## If ``inheritable`` is false (the default), the resulting client socket will
+  ## not be inheritable by child processes.
   ## The future will complete when the connection is successfully accepted.
   var retFut = newFuture[AsyncSocket]("asyncnet.accept")
   var fut = acceptAddr(socket, flags)