diff options
author | alaviss <leorize+oss@disroot.org> | 2020-05-20 07:42:55 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-05-20 09:42:55 +0200 |
commit | 4ae341353de5c58dc339e47b0eec2bbb4649dc10 (patch) | |
tree | 70933c5a55b47bbce8a5044ff034536d19aa4aa5 /lib | |
parent | 1450924b1e68ad3cd2dc8db2c54f9741315ca212 (diff) | |
download | Nim-4ae341353de5c58dc339e47b0eec2bbb4649dc10.tar.gz |
asyncdispatch, asyncnet: add inheritance control (#14362)
* asyncdispatch, asyncnet: add inheritance control * asyncnet, asyncdispatch: cleanup
Diffstat (limited to 'lib')
-rw-r--r-- | lib/pure/asyncdispatch.nim | 71 | ||||
-rw-r--r-- | lib/pure/asyncnet.nim | 44 |
2 files changed, 90 insertions, 25 deletions
diff --git a/lib/pure/asyncdispatch.nim b/lib/pure/asyncdispatch.nim index 91778a5bc..0800cb638 100644 --- a/lib/pure/asyncdispatch.nim +++ b/lib/pure/asyncdispatch.nim @@ -228,6 +228,16 @@ proc initCallSoonProc = if asyncfutures.getCallSoonProc().isNil: asyncfutures.setCallSoonProc(callSoon) +template implementSetInheritable() {.dirty.} = + when declared(setInheritable): + proc setInheritable*(fd: AsyncFD, inheritable: bool): bool = + ## Control whether a file handle can be inherited by child processes. + ## Returns ``true`` on success. + ## + ## This procedure is not guaranteed to be available for all platforms. + ## Test for availability with `declared()`_. + fd.FileHandle.setInheritable(inheritable) + when defined(windows) or defined(nimdoc): import winlean, sets, hashes type @@ -695,7 +705,8 @@ when defined(windows) or defined(nimdoc): retFuture.complete(bytesReceived) return retFuture - proc acceptAddr*(socket: AsyncFD, flags = {SocketFlag.SafeDisconn}): + proc acceptAddr*(socket: AsyncFD, flags = {SocketFlag.SafeDisconn}, + inheritable = defined(nimInheritHandles)): owned(Future[tuple[address: string, client: AsyncFD]]) = ## Accepts a new connection. Returns a future containing the client socket ## corresponding to that connection and the remote address of the client. @@ -704,6 +715,9 @@ when defined(windows) or defined(nimdoc): ## The resulting client socket is automatically registered to the ## dispatcher. ## + ## If ``inheritable`` is false (the default), the resulting client socket will + ## not be inheritable by child processes. + ## ## The ``accept`` call may result in an error if the connecting socket ## disconnects during the duration of the ``accept``. If the ``SafeDisconn`` ## flag is specified then this error will not be raised and instead @@ -711,7 +725,7 @@ when defined(windows) or defined(nimdoc): verifyPresence(socket) var retFuture = newFuture[tuple[address: string, client: AsyncFD]]("acceptAddr") - var clientSock = createNativeSocket() + var clientSock = createNativeSocket(inheritable = inheritable) if clientSock == osInvalidSocket: raiseOSError(osLastError()) const lpOutputLen = 1024 @@ -788,6 +802,8 @@ when defined(windows) or defined(nimdoc): return retFuture + implementSetInheritable() + proc closeSocket*(socket: AsyncFD) = ## Closes a socket and ensures that it is unregistered. socket.SocketHandle.close() @@ -1090,6 +1106,9 @@ else: import selectors from posix import EINTR, EAGAIN, EINPROGRESS, EWOULDBLOCK, MSG_PEEK, MSG_NOSIGNAL + when declared(posix.accept4): + from posix import accept4, SOCK_CLOEXEC + const InitCallbackListSize = 4 # initial size of callbacks sequence, # associated with file/socket descriptor. @@ -1263,6 +1282,8 @@ else: # descriptor was unregistered in callback via `unregister()`. discard + implementSetInheritable() + proc closeSocket*(sock: AsyncFD) = let selector = getGlobalDispatcher().selector if sock.SocketHandle notin selector: @@ -1484,7 +1505,8 @@ else: addRead(socket, cb) return retFuture - proc acceptAddr*(socket: AsyncFD, flags = {SocketFlag.SafeDisconn}): + proc acceptAddr*(socket: AsyncFD, flags = {SocketFlag.SafeDisconn}, + inheritable = defined(nimInheritHandles)): owned(Future[tuple[address: string, client: AsyncFD]]) = var retFuture = newFuture[tuple[address: string, client: AsyncFD]]("acceptAddr") @@ -1492,8 +1514,21 @@ else: result = true var sockAddress: Sockaddr_storage var addrLen = sizeof(sockAddress).SockLen - var client = accept(sock.SocketHandle, - cast[ptr SockAddr](addr(sockAddress)), addr(addrLen)) + var client = + when declared(accept4): + accept4(sock.SocketHandle, cast[ptr SockAddr](addr(sockAddress)), + addr(addrLen), if inheritable: 0 else: SOCK_CLOEXEC) + else: + accept(sock.SocketHandle, cast[ptr SockAddr](addr(sockAddress)), + addr(addrLen)) + when declared(setInheritable) and not declared(accept4): + if client != osInvalidSocket and not setInheritable(client, inheritable): + # Set failure first because close() itself can fail, + # altering osLastError(). + retFuture.fail(newOSError(osLastError())) + close client + return false + if client == osInvalidSocket: let lastError = osLastError() assert lastError.int32 != EWOULDBLOCK and lastError.int32 != EAGAIN @@ -1578,8 +1613,9 @@ proc poll*(timeout = 500) = ## `epoll`:idx: or `kqueue`:idx: primitive only once. discard runOnce(timeout) -template createAsyncNativeSocketImpl(domain, sockType, protocol) = - let handle = createNativeSocket(domain, sockType, protocol) +template createAsyncNativeSocketImpl(domain, sockType, protocol: untyped, + inheritable = defined(nimInheritHandles)) = + let handle = createNativeSocket(domain, sockType, protocol, inheritable) if handle == osInvalidSocket: return osInvalidSocket.AsyncFD handle.setBlocking(false) @@ -1589,13 +1625,15 @@ template createAsyncNativeSocketImpl(domain, sockType, protocol) = register(result) proc createAsyncNativeSocket*(domain: cint, sockType: cint, - protocol: cint): AsyncFD = - createAsyncNativeSocketImpl(domain, sockType, protocol) + protocol: cint, + inheritable = defined(nimInheritHandles)): AsyncFD = + createAsyncNativeSocketImpl(domain, sockType, protocol, inheritable) proc createAsyncNativeSocket*(domain: Domain = Domain.AF_INET, - sockType: SockType = SOCK_STREAM, - protocol: Protocol = IPPROTO_TCP): AsyncFD = - createAsyncNativeSocketImpl(domain, sockType, protocol) + sockType: SockType = SOCK_STREAM, + protocol: Protocol = IPPROTO_TCP, + inheritable = defined(nimInheritHandles)): AsyncFD = + createAsyncNativeSocketImpl(domain, sockType, protocol, inheritable) proc newAsyncNativeSocket*(domain: cint, sockType: cint, protocol: cint): AsyncFD {.deprecated: "use createAsyncNativeSocket instead".} = @@ -1824,12 +1862,17 @@ proc withTimeout*[T](fut: Future[T], timeout: int): owned(Future[bool]) = return retFuture proc accept*(socket: AsyncFD, - flags = {SocketFlag.SafeDisconn}): owned(Future[AsyncFD]) = + flags = {SocketFlag.SafeDisconn}, + inheritable = defined(nimInheritHandles)): owned(Future[AsyncFD]) = ## Accepts a new connection. Returns a future containing the client socket ## corresponding to that connection. + ## + ## If ``inheritable`` is false (the default), the resulting client socket + ## will not be inheritable by child processes. + ## ## The future will complete when the connection is successfully accepted. var retFut = newFuture[AsyncFD]("accept") - var fut = acceptAddr(socket, flags) + var fut = acceptAddr(socket, flags, inheritable) fut.callback = proc (future: Future[tuple[address: string, client: AsyncFD]]) = assert future.finished diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim index 93334d0e2..7ccb469cf 100644 --- a/lib/pure/asyncnet.nim +++ b/lib/pure/asyncnet.nim @@ -129,12 +129,17 @@ type AsyncSocket* = ref AsyncSocketDesc proc newAsyncSocket*(fd: AsyncFD, domain: Domain = AF_INET, - sockType: SockType = SOCK_STREAM, - protocol: Protocol = IPPROTO_TCP, buffered = true): owned(AsyncSocket) = + sockType: SockType = SOCK_STREAM, + protocol: Protocol = IPPROTO_TCP, + buffered = true, + inheritable = defined(nimInheritHandles)): owned(AsyncSocket) = ## Creates a new ``AsyncSocket`` based on the supplied params. ## ## The supplied ``fd``'s non-blocking state will be enabled implicitly. ## + ## If ``inheritable`` is false (the default), the supplied ``fd`` will not + ## be inheritable by child processes. + ## ## **Note**: This procedure will **NOT** register ``fd`` with the global ## async dispatcher. You need to do this manually. If you have used ## ``newAsyncNativeSocket`` to create ``fd`` then it's already registered. @@ -142,6 +147,8 @@ proc newAsyncSocket*(fd: AsyncFD, domain: Domain = AF_INET, new(result) result.fd = fd.SocketHandle fd.SocketHandle.setBlocking(false) + if not fd.SocketHandle.setInheritable(inheritable): + raiseOSError(osLastError()) result.isBuffered = buffered result.domain = domain result.sockType = sockType @@ -150,15 +157,19 @@ proc newAsyncSocket*(fd: AsyncFD, domain: Domain = AF_INET, result.currPos = 0 proc newAsyncSocket*(domain: Domain = AF_INET, sockType: SockType = SOCK_STREAM, - protocol: Protocol = IPPROTO_TCP, buffered = true): owned(AsyncSocket) = + protocol: Protocol = IPPROTO_TCP, buffered = true, + inheritable = defined(nimInheritHandles)): owned(AsyncSocket) = ## Creates a new asynchronous socket. ## ## This procedure will also create a brand new file descriptor for ## this socket. - let fd = createAsyncNativeSocket(domain, sockType, protocol) + ## + ## If ``inheritable`` is false (the default), the new file descriptor will not + ## be inheritable by child processes. + let fd = createAsyncNativeSocket(domain, sockType, protocol, inheritable) if fd.SocketHandle == osInvalidSocket: raiseOSError(osLastError()) - result = newAsyncSocket(fd, domain, sockType, protocol, buffered) + result = newAsyncSocket(fd, domain, sockType, protocol, buffered, inheritable) proc getLocalAddr*(socket: AsyncSocket): (string, Port) = ## Get the socket's local address and port number. @@ -173,16 +184,20 @@ proc getPeerAddr*(socket: AsyncSocket): (string, Port) = getPeerAddr(socket.fd, socket.domain) proc newAsyncSocket*(domain, sockType, protocol: cint, - buffered = true): owned(AsyncSocket) = + buffered = true, + inheritable = defined(nimInheritHandles)): owned(AsyncSocket) = ## Creates a new asynchronous socket. ## ## This procedure will also create a brand new file descriptor for ## this socket. - let fd = createAsyncNativeSocket(domain, sockType, protocol) + ## + ## If ``inheritable`` is false (the default), the new file descriptor will not + ## be inheritable by child processes. + let fd = createAsyncNativeSocket(domain, sockType, protocol, inheritable) if fd.SocketHandle == osInvalidSocket: raiseOSError(osLastError()) result = newAsyncSocket(fd, Domain(domain), SockType(sockType), - Protocol(protocol), buffered) + Protocol(protocol), buffered, inheritable) when defineSsl: proc getSslError(handle: SslPtr, err: cint): cint = @@ -443,13 +458,18 @@ proc send*(socket: AsyncSocket, data: string, else: await send(socket.fd.AsyncFD, data, flags) -proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn}): +proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn}, + inheritable = defined(nimInheritHandles)): owned(Future[tuple[address: string, client: AsyncSocket]]) = ## Accepts a new connection. Returns a future containing the client socket ## corresponding to that connection and the remote address of the client. + ## + ## If ``inheritable`` is false (the default), the resulting client socket will + ## not be inheritable by child processes. + ## ## The future will complete when the connection is successfully accepted. var retFuture = newFuture[tuple[address: string, client: AsyncSocket]]("asyncnet.acceptAddr") - var fut = acceptAddr(socket.fd.AsyncFD, flags) + var fut = acceptAddr(socket.fd.AsyncFD, flags, inheritable) fut.callback = proc (future: Future[tuple[address: string, client: AsyncFD]]) = assert future.finished @@ -458,7 +478,7 @@ proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn}): else: let resultTup = (future.read.address, newAsyncSocket(future.read.client, socket.domain, - socket.sockType, socket.protocol, socket.isBuffered)) + socket.sockType, socket.protocol, socket.isBuffered, inheritable)) retFuture.complete(resultTup) return retFuture @@ -466,6 +486,8 @@ proc accept*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn}): owned(Future[AsyncSocket]) = ## Accepts a new connection. Returns a future containing the client socket ## corresponding to that connection. + ## If ``inheritable`` is false (the default), the resulting client socket will + ## not be inheritable by child processes. ## The future will complete when the connection is successfully accepted. var retFut = newFuture[AsyncSocket]("asyncnet.accept") var fut = acceptAddr(socket, flags) |