diff options
author | bptato <nincsnevem662@gmail.com> | 2022-11-24 20:03:21 +0100 |
---|---|---|
committer | bptato <nincsnevem662@gmail.com> | 2022-11-24 20:03:21 +0100 |
commit | 896489a6c500e28f13d0237ab691622cb5c5114f (patch) | |
tree | 91b92da01bc126c2489a3dd083df5f9de06927c6 /src/ips | |
parent | ee930b0f5a587768d340c4204cf1f2e9fb818c89 (diff) | |
download | chawan-896489a6c500e28f13d0237ab691622cb5c5114f.tar.gz |
Avoid forking child processes from the main process
Caveat: this breaks piped streams.
Diffstat (limited to 'src/ips')
-rw-r--r-- | src/ips/forkserver.nim | 152 | ||||
-rw-r--r-- | src/ips/serialize.nim | 32 | ||||
-rw-r--r-- | src/ips/serversocket.nim | 5 | ||||
-rw-r--r-- | src/ips/socketstream.nim | 16 |
4 files changed, 193 insertions, 12 deletions
diff --git a/src/ips/forkserver.nim b/src/ips/forkserver.nim new file mode 100644 index 00000000..3e93402d --- /dev/null +++ b/src/ips/forkserver.nim @@ -0,0 +1,152 @@ +import streams + +when defined(posix): + import posix + +import buffer/buffer +import config/bufferconfig +import io/loader +import io/request +import io/window +import ips/serialize +import types/buffersource + +type + ForkCommand* = enum + FORK_BUFFER, FORK_LOADER, REMOVE_CHILD + + ForkServer* = ref object + process*: Pid + istream*: Stream + ostream*: Stream + + ForkServerContext = object + istream: Stream + ostream: Stream + children: seq[(Pid, Pid)] + +proc newFileLoader*(forkserver: ForkServer, defaultHeaders: HeaderList = DefaultHeaders): FileLoader = + forkserver.ostream.swrite(FORK_LOADER) + forkserver.ostream.swrite(defaultHeaders) + forkserver.ostream.flush() + forkserver.istream.sread(result) + +proc removeChild*(forkserver: Forkserver, pid: Pid) = + forkserver.ostream.swrite(REMOVE_CHILD) + forkserver.ostream.flush() + +proc forkLoader(ctx: var ForkServerContext, defaultHeaders: HeaderList): FileLoader = + var pipefd: array[2, cint] + if pipe(pipefd) == -1: + raise newException(Defect, "Failed to open pipe.") + let pid = fork() + if pid == 0: + # child process + for i in 0 ..< ctx.children.len: ctx.children[i] = (Pid(0), Pid(0)) + ctx.children.setLen(0) + zeroMem(addr ctx, sizeof(ctx)) + discard close(pipefd[0]) # close read + runFileLoader(pipefd[1], defaultHeaders) + assert false + let readfd = pipefd[0] # get read + discard close(pipefd[1]) # close write + var readf: File + if not open(readf, FileHandle(readfd), fmRead): + raise newException(Defect, "Failed to open output handle.") + assert readf.readChar() == char(0u8) + close(readf) + discard close(pipefd[0]) + return FileLoader(process: pid) + +proc forkBuffer(ctx: var ForkServerContext): Pid = + var source: BufferSource + var config: BufferConfig + var attrs: WindowAttributes + var mainproc: Pid + ctx.istream.sread(source) + ctx.istream.sread(config) + ctx.istream.sread(attrs) + ctx.istream.sread(mainproc) + let loader = ctx.forkLoader(DefaultHeaders) #TODO make this configurable + let pid = fork() + if pid == 0: + for i in 0 ..< ctx.children.len: ctx.children[i] = (Pid(0), Pid(0)) + ctx.children.setLen(0) + zeroMem(addr ctx, sizeof(ctx)) + launchBuffer(config, source, attrs, loader, mainproc) + assert false + ctx.children.add((pid, loader.process)) + return pid + +proc runForkServer() = + var ctx: ForkServerContext + ctx.istream = newFileStream(stdin) + ctx.ostream = newFileStream(stdout) + while true: + try: + var cmd: ForkCommand + ctx.istream.sread(cmd) + case cmd + of REMOVE_CHILD: + var pid: Pid + ctx.istream.sread(pid) + for i in 0 .. ctx.children.high: + if ctx.children[i][0] == pid: + ctx.children.del(i) + break + of FORK_BUFFER: + ctx.ostream.swrite(ctx.forkBuffer()) + of FORK_LOADER: + var defaultHeaders: HeaderList + ctx.istream.sread(defaultHeaders) + let loader = ctx.forkLoader(defaultHeaders) + ctx.ostream.swrite(loader) + ctx.children.add((loader.process, Pid(-1))) + ctx.ostream.flush() + except IOError: + # EOF + break + ctx.istream.close() + ctx.ostream.close() + # Clean up when the main process crashed. + for childpair in ctx.children: + let a = childpair[0] + let b = childpair[1] + discard kill(cint(a), cint(SIGTERM)) + if b != -1: + discard kill(cint(b), cint(SIGTERM)) + quit(0) + +proc newForkServer*(): ForkServer = + new(result) + var pipefd_in: array[2, cint] + var pipefd_out: array[2, cint] + if pipe(pipefd_in) == -1: + raise newException(Defect, "Failed to open input pipe.") + if pipe(pipefd_out) == -1: + raise newException(Defect, "Failed to open output pipe.") + let pid = fork() + if pid == -1: + raise newException(Defect, "Failed to fork the fork process.") + elif pid == 0: + # child process + let readfd = pipefd_in[0] + discard close(pipefd_in[1]) # close write + let writefd = pipefd_out[1] + discard close(pipefd_out[0]) # close read + discard dup2(readfd, stdin.getFileHandle()) + discard dup2(writefd, stdout.getFileHandle()) + discard close(pipefd_in[0]) + discard close(pipefd_out[1]) + runForkServer() + assert false + else: + discard close(pipefd_in[0]) # close read + discard close(pipefd_out[1]) # close write + var readf, writef: File + if not open(writef, pipefd_in[1], fmWrite): + raise newException(Defect, "Failed to open output handle") + if not open(readf, pipefd_out[0], fmRead): + raise newException(Defect, "Failed to open input handle") + result.ostream = newFileStream(writef) + result.istream = newFileStream(readf) diff --git a/src/ips/serialize.nim b/src/ips/serialize.nim index 2dde0649..82003715 100644 --- a/src/ips/serialize.nim +++ b/src/ips/serialize.nim @@ -7,6 +7,7 @@ import tables import buffer/cell import io/request import js/regex +import types/buffersource import types/color import types/url @@ -94,6 +95,15 @@ proc swrite*(stream: Stream, regex: Regex) = stream.writeData(regex.bytecode, regex.plen) stream.swrite(regex.buf) +proc swrite*(stream: Stream, source: BufferSource) = + stream.swrite(source.t) + case source.t + of CLONE: stream.swrite(source.clonepid) + of LOAD_REQUEST: stream.swrite(source.request) + of LOAD_PIPE: stream.swrite(source.fd) + stream.swrite(source.location) + stream.swrite(source.contenttype) + template sread*[T](stream: Stream, o: T) = stream.read(o) @@ -169,6 +179,10 @@ proc sread*(stream: Stream, req: var RequestObj) = stream.sread(req.body) stream.sread(req.multipart) +proc read*(stream: Stream, req: var Request) = + new(req) + stream.sread(req[]) + proc sread*(stream: Stream, color: var CellColor) = var rgb: bool stream.sread(rgb) @@ -202,6 +216,18 @@ proc sread*(stream: Stream, regex: var Regex) = if l != regex.plen: `=destroy`(regex) -proc readRequest*(stream: Stream): Request = - new(result) - stream.sread(result[]) +proc sread*(stream: Stream, source: var BufferSource) = + var t: BufferSourceType + stream.sread(t) + case t + of CLONE: + source = BufferSource(t: CLONE) + stream.sread(source.clonepid) + of LOAD_REQUEST: + source = BufferSource(t: LOAD_REQUEST) + stream.sread(source.request) + of LOAD_PIPE: + source = BufferSource(t: LOAD_PIPE) + stream.sread(source.fd) + stream.sread(source.location) + stream.sread(source.contenttype) diff --git a/src/ips/serversocket.nim b/src/ips/serversocket.nim index aa260907..8ed79ab3 100644 --- a/src/ips/serversocket.nim +++ b/src/ips/serversocket.nim @@ -12,9 +12,9 @@ const SocketPathPrefix = SocketDirectory & "cha_sock_" func getSocketPath*(pid: Pid): string = SocketPathPrefix & $pid -proc initServerSocket*(pid: Pid): ServerSocket = +proc initServerSocket*(buffered = true): ServerSocket = createDir(SocketDirectory) - result.sock = newSocket(Domain.AF_UNIX, SockType.SOCK_STREAM, Protocol.IPPROTO_IP) + result.sock = newSocket(Domain.AF_UNIX, SockType.SOCK_STREAM, Protocol.IPPROTO_IP, buffered) result.path = getSocketPath(getpid()) discard unlink(cstring(result.path)) bindUnix(result.sock, result.path) @@ -23,4 +23,3 @@ proc initServerSocket*(pid: Pid): ServerSocket = proc close*(ssock: ServerSocket) = close(ssock.sock) discard unlink(cstring(ssock.path)) - diff --git a/src/ips/socketstream.nim b/src/ips/socketstream.nim index f25f72c1..efc226bd 100644 --- a/src/ips/socketstream.nim +++ b/src/ips/socketstream.nim @@ -1,5 +1,6 @@ import nativesockets import net +import os import streams when defined(posix): @@ -15,7 +16,7 @@ proc sockReadData(s: Stream, buffer: pointer, len: int): int = let s = SocketStream(s) result = s.source.recv(buffer, len) if result < 0: - raise newException(Defect, "Failed to read data") + raise newException(Defect, "Failed to read data (code " & $osLastError() & ")") elif result < len: s.isend = true @@ -38,17 +39,20 @@ func newSocketStream*(): SocketStream = result.atEndImpl = sockAtEnd result.closeImpl = sockClose -proc connectSocketStream*(path: string): SocketStream = +proc connectSocketStream*(path: string, buffered = true): SocketStream = result = newSocketStream() - let sock = newSocket(Domain.AF_UNIX, SockType.SOCK_STREAM, Protocol.IPPROTO_IP) + let sock = newSocket(Domain.AF_UNIX, SockType.SOCK_STREAM, Protocol.IPPROTO_IP, buffered) connectUnix(sock, path) result.source = sock -proc connectSocketStream*(pid: Pid): SocketStream = - connectSocketStream(getSocketPath(pid)) +proc connectSocketStream*(pid: Pid, buffered = true): SocketStream = + try: + connectSocketStream(getSocketPath(pid), buffered) + except OSError: + return nil proc acceptSocketStream*(ssock: ServerSocket): SocketStream = result = newSocketStream() var sock: Socket - ssock.sock.accept(sock) + ssock.sock.accept(sock, inheritable = true) result.source = sock |