about summary refs log tree commit diff stats
path: root/src/loader/loader.nim
diff options
context:
space:
mode:
Diffstat (limited to 'src/loader/loader.nim')
-rw-r--r--src/loader/loader.nim160
1 files changed, 73 insertions, 87 deletions
diff --git a/src/loader/loader.nim b/src/loader/loader.nim
index c84f247a..89d97cde 100644
--- a/src/loader/loader.nim
+++ b/src/loader/loader.nim
@@ -8,11 +8,17 @@
 #  S: output ID
 #  S: status code
 #  S: headers
+#  C: resume
 #  S: response body
 # else:
 #  S: error message
 #
 # The body is passed to the stream as-is, so effectively nothing can follow it.
+#
+# Note: if the consumer closes the request's body after headers have been
+# passed, it will *not* be cleaned up until a `resume' command is
+# received. (This allows for passing outputIds to the pager for later
+# addCacheFile commands there.)
 
 import std/deques
 import std/nativesockets
@@ -57,7 +63,7 @@ type
     process*: int
     clientPid*: int
     connecting*: Table[int, ConnectData]
-    ongoing*: Table[int, OngoingData]
+    ongoing*: Table[int, Response]
     unregistered*: seq[int]
     registerFun*: proc(fd: int)
     unregisterFun*: proc(fd: int)
@@ -71,11 +77,6 @@ type
     stream*: SocketStream
     request: Request
 
-  OngoingData* = object
-    buf: string
-    response*: Response
-    bodyRead: Promise[string]
-
   LoaderCommand = enum
     lcAddCacheFile
     lcAddClient
@@ -155,10 +156,12 @@ proc rejectHandle(handle: LoaderHandle; code: ConnectErrorCode; msg = "") =
   handle.sendResult(code, msg)
   handle.close()
 
-func findOutput(ctx: LoaderContext; id: int): OutputHandle =
+func findOutput(ctx: LoaderContext; id: int; client: ClientData): OutputHandle =
   assert id != -1
   for it in ctx.outputMap.values:
     if it.outputId == id:
+      # verify that it's safe to access this handle.
+      doAssert ctx.isPrivileged(client) or client.pid == it.ownerPid
       return it
   return nil
 
@@ -211,11 +214,8 @@ proc getOutputId(ctx: LoaderContext): int =
   result = ctx.outputNum
   inc ctx.outputNum
 
-proc redirectToFile(ctx: LoaderContext; output: OutputHandle;
-    targetPath: string): bool =
-  let ps = newPosixStream(targetPath, O_CREAT or O_WRONLY, 0o600)
-  if ps == nil:
-    return false
+proc redirectToStream(ctx: LoaderContext; output: OutputHandle;
+    ps: PosixStream): bool =
   if output.currentBuffer != nil:
     let n = ps.sendData(output.currentBuffer, output.currentBufferIdx)
     if unlikely(n < output.currentBuffer.len - output.currentBufferIdx):
@@ -226,7 +226,9 @@ proc redirectToFile(ctx: LoaderContext; output: OutputHandle;
     if unlikely(n < buffer.len):
       ps.sclose()
       return false
-  if output.parent != nil:
+  if output.istreamAtEnd:
+    ps.sclose()
+  elif output.parent != nil:
     output.parent.outputs.add(OutputHandle(
       parent: output.parent,
       ostream: ps,
@@ -235,6 +237,13 @@ proc redirectToFile(ctx: LoaderContext; output: OutputHandle;
     ))
   return true
 
+proc redirectToFile(ctx: LoaderContext; output: OutputHandle;
+    targetPath: string): bool =
+  let ps = newPosixStream(targetPath, O_CREAT or O_WRONLY, 0o600)
+  if ps == nil:
+    return false
+  return ctx.redirectToStream(output, ps)
+
 type AddCacheFileResult = tuple[outputId: int; cacheFile: string]
 
 proc addCacheFile(ctx: LoaderContext; client: ClientData; output: OutputHandle):
@@ -335,8 +344,7 @@ proc loadStreamRegular(ctx: LoaderContext; handle, cachedHandle: LoaderHandle) =
       output.ostream.sclose()
       output.ostream = nil
   handle.outputs.setLen(0)
-  handle.istream.sclose()
-  handle.istream = nil
+  handle.iclose()
 
 proc loadStream(ctx: LoaderContext; handle: LoaderHandle; request: Request) =
   ctx.passedFdMap.withValue(request.url.pathname, fdp):
@@ -406,10 +414,17 @@ proc loadResource(ctx: LoaderContext; client: ClientData; config: LoaderClientCo
           redo = true
           continue
     if request.url.scheme == "cgi-bin":
+      var ostream: PosixStream = nil
       handle.loadCGI(request, ctx.config.cgiDir, prevurl,
-        config.insecureSSLNoVerify)
+        config.insecureSSLNoVerify, ostream)
       if handle.istream != nil:
         ctx.addFd(handle)
+        if ostream != nil:
+          let output = ctx.findOutput(request.body.outputId, client)
+          if output != nil:
+            doAssert ctx.redirectToStream(output, ostream)
+          else:
+            ostream.sclose()
       else:
         handle.close()
     elif request.url.scheme == "stream":
@@ -451,8 +466,7 @@ proc setupRequestDefaults(request: Request; config: LoaderClientConfig) =
 
 proc load(ctx: LoaderContext; stream: SocketStream; request: Request;
     client: ClientData; config: LoaderClientConfig) =
-  let handle = newLoaderHandle(stream, ctx.getOutputId(), client.pid,
-    request.suspended)
+  let handle = newLoaderHandle(stream, ctx.getOutputId(), client.pid)
   when defined(debug):
     handle.url = request.url
     handle.output.url = request.url
@@ -514,9 +528,12 @@ proc addCacheFile(ctx: LoaderContext; stream: SocketStream;
     r: var BufferedReader) =
   var outputId: int
   var targetPid: int
+  var sourcePid: int
   r.sread(outputId)
   r.sread(targetPid)
-  let output = ctx.findOutput(outputId)
+  r.sread(sourcePid)
+  let sourceClient = ctx.clientData[sourcePid]
+  let output = ctx.findOutput(outputId, sourceClient)
   assert output != nil
   let targetClient = ctx.clientData[targetPid]
   let (id, file) = ctx.addCacheFile(targetClient, output)
@@ -531,7 +548,7 @@ proc redirectToFile(ctx: LoaderContext; stream: SocketStream;
   var targetPath: string
   r.sread(outputId)
   r.sread(targetPath)
-  let output = ctx.findOutput(outputId)
+  let output = ctx.findOutput(outputId, ctx.pagerClient)
   var success = false
   if output != nil:
     success = ctx.redirectToFile(output, targetPath)
@@ -583,9 +600,7 @@ proc tee(ctx: LoaderContext; stream: SocketStream; client: ClientData;
   var targetPid: int
   r.sread(sourceId)
   r.sread(targetPid)
-  let output = ctx.findOutput(sourceId)
-  # only allow tee'ing outputs owned by client
-  doAssert output.ownerPid == client.pid
+  let output = ctx.findOutput(sourceId, client)
   if output != nil:
     let id = ctx.getOutputId()
     output.tee(stream, id, targetPid)
@@ -602,7 +617,7 @@ proc suspend(ctx: LoaderContext; stream: SocketStream; client: ClientData;
   var ids: seq[int]
   r.sread(ids)
   for id in ids:
-    let output = ctx.findOutput(id)
+    let output = ctx.findOutput(id, client)
     if output != nil:
       output.suspended = true
       if output.registered:
@@ -615,7 +630,7 @@ proc resume(ctx: LoaderContext; stream: SocketStream; client: ClientData;
   var ids: seq[int]
   r.sread(ids)
   for id in ids:
-    let output = ctx.findOutput(id)
+    let output = ctx.findOutput(id, client)
     if output != nil:
       output.suspended = false
       assert not output.registered
@@ -793,10 +808,9 @@ proc finishCycle(ctx: LoaderContext; unregRead: var seq[LoaderHandle];
     if handle.istream != nil:
       ctx.selector.unregister(handle.istream.fd)
       ctx.handleMap.del(handle.istream.fd)
-      handle.istream.sclose()
-      handle.istream = nil
       if handle.parser != nil:
         handle.finishParse()
+      handle.iclose()
       for output in handle.outputs:
         output.istreamAtEnd = true
         if output.isEmpty:
@@ -816,10 +830,9 @@ proc finishCycle(ctx: LoaderContext; unregRead: var seq[LoaderHandle];
           # premature end of all output streams; kill istream too
           ctx.selector.unregister(handle.istream.fd)
           ctx.handleMap.del(handle.istream.fd)
-          handle.istream.sclose()
-          handle.istream = nil
           if handle.parser != nil:
             handle.finishParse()
+          handle.iclose()
 
 proc runFileLoader*(fd: cint; config: LoaderConfig) =
   var ctx = initLoaderContext(fd, config)
@@ -861,12 +874,7 @@ proc getRedirect*(response: Response; request: Request): Request =
             status == 302 and request.httpMethod == hmPost:
           return newRequest(url.get, hmGet)
         else:
-          return newRequest(
-            url.get,
-            request.httpMethod,
-            body = request.body,
-            multipart = request.multipart
-          )
+          return newRequest(url.get, request.httpMethod, body = request.body)
   return nil
 
 template withLoaderPacketWriter(stream: SocketStream; loader: FileLoader;
@@ -898,7 +906,6 @@ proc startRequest*(loader: FileLoader; request: Request;
     w.swrite(config)
   return stream
 
-#TODO: add init
 proc fetch*(loader: FileLoader; input: Request): FetchPromise =
   let stream = loader.startRequest(input)
   let fd = int(stream.fd)
@@ -913,10 +920,7 @@ proc fetch*(loader: FileLoader; input: Request): FetchPromise =
 
 proc reconnect*(loader: FileLoader; data: ConnectData) =
   data.stream.sclose()
-  let stream = loader.connect()
-  stream.withLoaderPacketWriter loader, w:
-    w.swrite(lcLoad)
-    w.swrite(data.request)
+  let stream = loader.startRequest(data.request)
   let fd = int(stream.fd)
   loader.registerFun(fd)
   loader.connecting[fd] = ConnectData(
@@ -925,18 +929,6 @@ proc reconnect*(loader: FileLoader; data: ConnectData) =
     stream: stream
   )
 
-proc switchStream*(data: var ConnectData; stream: SocketStream) =
-  data.stream = stream
-
-proc switchStream*(loader: FileLoader; data: var OngoingData;
-    stream: SocketStream) =
-  data.response.body = stream
-  let fd = int(stream.fd)
-  data.response.unregisterFun = proc() =
-    loader.ongoing.del(fd)
-    loader.unregistered.add(fd)
-    loader.unregisterFun(fd)
-
 proc suspend*(loader: FileLoader; fds: seq[int]) =
   let stream = loader.connect()
   stream.withLoaderPacketWriter loader, w:
@@ -944,13 +936,16 @@ proc suspend*(loader: FileLoader; fds: seq[int]) =
     w.swrite(fds)
   stream.sclose()
 
-proc resume*(loader: FileLoader; fds: seq[int]) =
+proc resume*(loader: FileLoader; fds: openArray[int]) =
   let stream = loader.connect()
   stream.withLoaderPacketWriter loader, w:
     w.swrite(lcResume)
     w.swrite(fds)
   stream.sclose()
 
+proc resume*(loader: FileLoader; fds: int) =
+  loader.resume([fds])
+
 proc tee*(loader: FileLoader; sourceId, targetPid: int): (SocketStream, int) =
   let stream = loader.connect()
   stream.withLoaderPacketWriter loader, w:
@@ -962,15 +957,20 @@ proc tee*(loader: FileLoader; sourceId, targetPid: int): (SocketStream, int) =
   r.sread(outputId)
   return (stream, outputId)
 
-proc addCacheFile*(loader: FileLoader; outputId, targetPid: int):
-    AddCacheFileResult =
+# sourcePid is the PID of the output's owner. This is used in pager for images,
+# so that we can be sure that a container only loads images on the page that
+# it owns.
+proc addCacheFile*(loader: FileLoader; outputId, targetPid: int;
+    sourcePid = -1): AddCacheFileResult =
   let stream = loader.connect()
   if stream == nil:
     return (-1, "")
+  let sourcePid = if sourcePid == -1: loader.clientPid else: sourcePid
   stream.withLoaderPacketWriter loader, w:
     w.swrite(lcAddCacheFile)
     w.swrite(outputId)
     w.swrite(targetPid)
+    w.swrite(sourcePid)
   var r = stream.initPacketReader()
   var outputId: int
   var cacheFile: string
@@ -990,18 +990,18 @@ proc redirectToFile*(loader: FileLoader; outputId: int; targetPath: string):
   var r = stream.initPacketReader()
   r.sread(result)
 
-const BufferSize = 4096
-
 proc onConnected*(loader: FileLoader; fd: int) =
   let connectData = loader.connecting[fd]
   let stream = connectData.stream
   let promise = connectData.promise
   let request = connectData.request
+  # delete before resolving the promise
+  loader.connecting.del(fd)
   var r = stream.initPacketReader()
   var res: int
   r.sread(res) # packet 1
-  let response = newResponse(res, request, stream)
   if res == 0:
+    let response = newResponse(res, request, stream)
     r.sread(response.outputId) # packet 1
     r = stream.initPacketReader()
     r.sread(response.status) # packet 2
@@ -1011,13 +1011,12 @@ proc onConnected*(loader: FileLoader; fd: int) =
     response.body = stream
     assert loader.unregisterFun != nil
     response.unregisterFun = proc() =
-      loader.ongoing.del(fd)
-      loader.unregistered.add(fd)
-      loader.unregisterFun(fd)
-    loader.ongoing[fd] = OngoingData(
-      response: response,
-      bodyRead: response.bodyRead
-    )
+      loader.ongoing.del(response.body.fd)
+      loader.unregistered.add(response.body.fd)
+      loader.unregisterFun(response.body.fd)
+    response.resumeFun = proc(outputId: int) =
+      loader.resume(outputId)
+    loader.ongoing[fd] = response
     stream.setBlocking(false)
     promise.resolve(JSResult[Response].ok(response))
   else:
@@ -1030,40 +1029,27 @@ proc onConnected*(loader: FileLoader; fd: int) =
     stream.sclose()
     let err = newTypeError("NetworkError when attempting to fetch resource")
     promise.resolve(JSResult[Response].err(err))
-  loader.connecting.del(fd)
 
 proc onRead*(loader: FileLoader; fd: int) =
-  loader.ongoing.withValue(fd, buffer):
-    let response = buffer[].response
-    while not response.body.isend:
-      let olen = buffer[].buf.len
-      try:
-        buffer[].buf.setLen(olen + BufferSize)
-        let n = response.body.recvData(addr buffer[].buf[olen], BufferSize)
-        buffer[].buf.setLen(olen + n)
-        if n == 0:
-          break
-      except ErrorAgain:
-        buffer[].buf.setLen(olen)
-        break
+  let response = loader.ongoing.getOrDefault(fd)
+  if response != nil:
+    response.onRead(response)
     if response.body.isend:
-      buffer[].bodyRead.resolve(buffer[].buf)
-      buffer[].bodyRead = nil
-      buffer[].buf = ""
+      response.bodyRead.resolve()
+      response.bodyRead = nil
       response.unregisterFun()
 
 proc onError*(loader: FileLoader; fd: int) =
-  loader.ongoing.withValue(fd, buffer):
-    let response = buffer[].response
+  let response = loader.ongoing.getOrDefault(fd)
+  if response != nil:
     when defined(debug):
       var lbuf {.noinit.}: array[BufferSize, char]
       if not response.body.isend:
         let n = response.body.recvData(addr lbuf[0], lbuf.len)
         assert n == 0
       assert response.body.isend
-    buffer[].bodyRead.resolve(buffer[].buf)
-    buffer[].bodyRead = nil
-    buffer[].buf = ""
+    response.bodyRead.resolve()
+    response.bodyRead = nil
     response.unregisterFun()
 
 # Note: this blocks until headers are received.