about summary refs log tree commit diff stats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/io/posixstream.nim3
-rw-r--r--src/io/socketstream.nim22
-rw-r--r--src/loader/cgi.nim42
-rw-r--r--src/loader/loader.nim23
-rw-r--r--src/loader/loaderhandle.nim63
5 files changed, 71 insertions, 82 deletions
diff --git a/src/io/posixstream.nim b/src/io/posixstream.nim
index dbc48ece..73a957f8 100644
--- a/src/io/posixstream.nim
+++ b/src/io/posixstream.nim
@@ -14,6 +14,7 @@ type
   ErrorInterrupted* = object of IOError
   ErrorInvalid* = object of IOError
   ErrorConnectionReset* = object of IOError
+  ErrorBrokenPipe* = object of IOError
 
 proc raisePosixIOError*() =
   # In the nim stdlib, these are only constants on linux amd64, so we
@@ -30,6 +31,8 @@ proc raisePosixIOError*() =
     raise newException(ErrorInvalid, "invalid")
   elif errno == ECONNRESET:
     raise newException(ErrorConnectionReset, "connection reset by peer")
+  elif errno == EPIPE:
+    raise newException(ErrorBrokenPipe, "broken pipe")
   else:
     raise newException(IOError, $strerror(errno))
 
diff --git a/src/io/socketstream.nim b/src/io/socketstream.nim
index 31e4a3b7..9426f7a7 100644
--- a/src/io/socketstream.nim
+++ b/src/io/socketstream.nim
@@ -47,7 +47,7 @@ proc sockWriteData(s: Stream, buffer: pointer, len: int) =
   while i < len:
     let n = SocketStream(s).source.send(addr buffer[i], len - i)
     if n < 0:
-      raise newException(IOError, $strerror(errno))
+      raisePosixIOError()
     i += n
 
 proc sockAtEnd(s: Stream): bool =
@@ -64,7 +64,7 @@ proc sendfd(sock: SocketHandle, fd: cint): int {.importc.}
 proc sendFileHandle*(s: SocketStream, fd: FileHandle) =
   assert not s.source.hasDataBuffered
   let n = sendfd(s.source.getFd(), cint(fd))
-  if n < -1:
+  if n < 0:
     raisePosixIOError()
   assert n == 1 # we send a single nul byte as buf
 
@@ -82,7 +82,8 @@ proc recvFileHandle*(s: SocketStream): FileHandle =
 func newSocketStream*(): SocketStream =
   return SocketStream(
     readDataImpl: cast[proc (s: Stream, buffer: pointer, bufLen: int): int
-        {.nimcall, raises: [Defect, IOError, OSError], tags: [ReadIOEffect], gcsafe.}
+        {.nimcall, raises: [Defect, IOError, OSError], tags: [ReadIOEffect],
+        gcsafe.}
     ](sockReadData), # ... ???
     writeDataImpl: sockWriteData,
     atEndImpl: sockAtEnd,
@@ -94,19 +95,24 @@ proc setBlocking*(ss: SocketStream, blocking: bool) =
 
 # see serversocket.nim for an explanation
 {.compile: "connect_unix.c".}
-proc connect_unix_from_c(fd: cint, path: cstring, pathlen: cint): cint {.importc.}
+proc connect_unix_from_c(fd: cint, path: cstring, pathlen: cint): cint
+  {.importc.}
 
-proc connectSocketStream*(path: string, buffered = true, blocking = true): SocketStream =
+proc connectSocketStream*(path: string, buffered = true, blocking = true):
+    SocketStream =
   result = newSocketStream()
   result.blk = blocking
-  let sock = newSocket(Domain.AF_UNIX, SockType.SOCK_STREAM, Protocol.IPPROTO_IP, buffered)
+  let sock = newSocket(Domain.AF_UNIX, SockType.SOCK_STREAM,
+    Protocol.IPPROTO_IP, buffered)
   if not blocking:
     sock.getFd().setBlocking(false)
-  if connect_unix_from_c(cint(sock.getFd()), cstring(path), cint(path.len)) != 0:
+  if connect_unix_from_c(cint(sock.getFd()), cstring(path),
+      cint(path.len)) != 0:
     raiseOSError(osLastError())
   result.source = sock
 
-proc connectSocketStream*(pid: Pid, buffered = true, blocking = true): SocketStream =
+proc connectSocketStream*(pid: Pid, buffered = true, blocking = true):
+    SocketStream =
   try:
     connectSocketStream(getSocketPath(pid), buffered, blocking)
   except OSError:
diff --git a/src/loader/cgi.nim b/src/loader/cgi.nim
index 2a2451d2..64afc39f 100644
--- a/src/loader/cgi.nim
+++ b/src/loader/cgi.nim
@@ -70,21 +70,21 @@ proc handleFirstLine(handle: LoaderHandle, line: string, headers: Headers,
   let k = line.until(':')
   if k.len == line.len:
     # invalid
-    discard handle.sendResult(ERROR_CGI_MALFORMED_HEADER)
+    handle.sendResult(ERROR_CGI_MALFORMED_HEADER)
     return RESULT_ERROR
   let v = line.substr(k.len + 1).strip()
   if k.equalsIgnoreCase("Status"):
-    discard handle.sendResult(0) # success
+    handle.sendResult(0) # success
     status = parseInt32(v).get(0)
     return RESULT_CONTROL_CONTINUE
   if k.equalsIgnoreCase("Cha-Control"):
     if v.startsWithIgnoreCase("Connected"):
-      discard handle.sendResult(0) # success
+      handle.sendResult(0) # success
       return RESULT_CONTROL_CONTINUE
     elif v.startsWithIgnoreCase("ConnectionError"):
       let errs = v.split(' ')
       if errs.len <= 1:
-        discard handle.sendResult(ERROR_CGI_INVALID_CHA_CONTROL)
+        handle.sendResult(ERROR_CGI_INVALID_CHA_CONTROL)
       else:
         let fb = int32(ERROR_CGI_INVALID_CHA_CONTROL)
         let code = int(parseInt32(errs[1]).get(fb))
@@ -94,13 +94,13 @@ proc handleFirstLine(handle: LoaderHandle, line: string, headers: Headers,
           for i in 3 ..< errs.len:
             message &= ' '
             message &= errs[i]
-        discard handle.sendResult(code, message)
+        handle.sendResult(code, message)
       return RESULT_ERROR
     elif v.startsWithIgnoreCase("ControlDone"):
       return RESULT_CONTROL_DONE
-    discard handle.sendResult(ERROR_CGI_INVALID_CHA_CONTROL)
+    handle.sendResult(ERROR_CGI_INVALID_CHA_CONTROL)
     return RESULT_ERROR
-  discard handle.sendResult(0) # success
+  handle.sendResult(0) # success
   headers.add(k, v)
   return RESULT_CONTROL_DONE
 
@@ -132,11 +132,8 @@ proc handleLine(handle: LoaderHandle, line: string, headers: Headers) =
 
 proc loadCGI*(handle: LoaderHandle, request: Request, cgiDir: seq[string],
     libexecPath: string, prevURL: URL) =
-  template t(body: untyped) =
-    if not body:
-      return
   if cgiDir.len == 0:
-    discard handle.sendResult(ERROR_NO_CGI_DIR)
+    handle.sendResult(ERROR_NO_CGI_DIR)
     return
   var path = percentDecode(request.url.pathname)
   if path.startsWith("/cgi-bin/"):
@@ -144,7 +141,7 @@ proc loadCGI*(handle: LoaderHandle, request: Request, cgiDir: seq[string],
   elif path.startsWith("/$LIB/"):
     path.delete(0 .. "/$LIB/".high)
   if path == "" or request.url.hostname != "":
-    discard handle.sendResult(ERROR_INVALID_CGI_PATH)
+    handle.sendResult(ERROR_INVALID_CGI_PATH)
     return
   var basename: string
   var pathInfo: string
@@ -163,7 +160,7 @@ proc loadCGI*(handle: LoaderHandle, request: Request, cgiDir: seq[string],
         requestURI = cmd / pathInfo & request.url.search
         break
     if cmd == "":
-      discard handle.sendResult(ERROR_INVALID_CGI_PATH)
+      handle.sendResult(ERROR_INVALID_CGI_PATH)
       return
   else:
     basename = path.until('/')
@@ -175,20 +172,21 @@ proc loadCGI*(handle: LoaderHandle, request: Request, cgiDir: seq[string],
       if fileExists(cmd):
         break
   if not fileExists(cmd):
-    discard handle.sendResult(ERROR_CGI_FILE_NOT_FOUND)
+    handle.sendResult(ERROR_CGI_FILE_NOT_FOUND)
+    return
   if basename in ["", ".", ".."] or basename.startsWith("~"):
-    discard handle.sendResult(ERROR_INVALID_CGI_PATH)
+    handle.sendResult(ERROR_INVALID_CGI_PATH)
     return
   var pipefd: array[0..1, cint] # child -> parent
   if pipe(pipefd) == -1:
-    discard handle.sendResult(ERROR_FAIL_SETUP_CGI)
+    handle.sendResult(ERROR_FAIL_SETUP_CGI)
     return
   # Pipe the request body as stdin for POST.
   var pipefd_read: array[0..1, cint] # parent -> child
   let needsPipe = request.body.isSome or request.multipart.isSome
   if needsPipe:
     if pipe(pipefd_read) == -1:
-      discard handle.sendResult(ERROR_FAIL_SETUP_CGI)
+      handle.sendResult(ERROR_FAIL_SETUP_CGI)
       return
   var contentLen = 0
   if request.body.isSome:
@@ -197,7 +195,7 @@ proc loadCGI*(handle: LoaderHandle, request: Request, cgiDir: seq[string],
     contentLen = request.multipart.get.calcLength()
   let pid = fork()
   if pid == -1:
-    t handle.sendResult(ERROR_FAIL_SETUP_CGI)
+    handle.sendResult(ERROR_FAIL_SETUP_CGI)
   elif pid == 0:
     discard close(pipefd[0]) # close read
     discard dup2(pipefd[1], 1) # dup stdout
@@ -232,12 +230,12 @@ proc loadCGI*(handle: LoaderHandle, request: Request, cgiDir: seq[string],
     var status = 200
     if ps.atEnd:
       # no data?
-      discard handle.sendResult(ERROR_CGI_NO_DATA)
+      handle.sendResult(ERROR_CGI_NO_DATA)
       return
     let line = ps.readLine()
     if line == "": #\r\n
       # no headers, body comes immediately
-      t handle.sendResult(0) # success
+      handle.sendResult(0) # success
     else:
       var res = handle.handleFirstLine(line, headers, status)
       if res == RESULT_ERROR:
@@ -257,6 +255,6 @@ proc loadCGI*(handle: LoaderHandle, request: Request, cgiDir: seq[string],
           if line == "": #\r\n
             break
           handle.handleLine(line, headers)
-    t handle.sendStatus(status)
-    t handle.sendHeaders(headers)
+    handle.sendStatus(status)
+    handle.sendHeaders(headers)
     handle.istream = ps
diff --git a/src/loader/loader.nim b/src/loader/loader.nim
index 2fda709c..02585630 100644
--- a/src/loader/loader.nim
+++ b/src/loader/loader.nim
@@ -152,13 +152,13 @@ proc loadResource(ctx: LoaderContext, request: Request, handle: LoaderHandle) =
         inc tries
         redo = true
       of URI_RESULT_WRONG_URL:
-        discard handle.sendResult(ERROR_INVALID_URI_METHOD_ENTRY)
+        handle.sendResult(ERROR_INVALID_URI_METHOD_ENTRY)
         handle.close()
       of URI_RESULT_NOT_FOUND:
-        discard handle.sendResult(ERROR_UNKNOWN_SCHEME)
+        handle.sendResult(ERROR_UNKNOWN_SCHEME)
         handle.close()
   if tries >= MaxRewrites:
-    discard handle.sendResult(ERROR_TOO_MANY_REWRITES)
+    handle.sendResult(ERROR_TOO_MANY_REWRITES)
     handle.close()
 
 proc onLoad(ctx: LoaderContext, stream: SocketStream) =
@@ -166,7 +166,7 @@ proc onLoad(ctx: LoaderContext, stream: SocketStream) =
   stream.sread(request)
   let handle = newLoaderHandle(stream, request.canredir)
   if not ctx.config.filter.match(request.url):
-    discard handle.sendResult(ERROR_DISALLOWED_URL)
+    handle.sendResult(ERROR_DISALLOWED_URL)
     handle.close()
   else:
     for k, v in ctx.config.defaultheaders.table:
@@ -231,10 +231,8 @@ proc acceptConnection(ctx: LoaderContext) =
     of SET_REFERRER_POLICY:
       stream.sread(ctx.referrerpolicy)
       stream.close()
-  except IOError:
-    # End-of-file, broken pipe, or something else. For now we just
-    # ignore it and pray nothing breaks.
-    # (TODO: this is probably not a very good idea.)
+  except ErrorBrokenPipe:
+    # receiving end died while reading the file; give up.
     stream.close()
 
 proc exitLoader(ctx: LoaderContext) =
@@ -286,10 +284,11 @@ proc runFileLoader*(fd: cint, config: LoaderConfig) =
           while not handle.istream.atEnd:
             try:
               let n = handle.istream.readData(addr buffer[0], buffer.len)
-              if not handle.sendData(addr buffer[0], n):
-                unreg.add(event.fd)
-                break
-            except ErrorAgain, ErrorWouldBlock:
+              handle.sendData(addr buffer[0], n)
+            except ErrorAgain, ErrorWouldBlock: # retry later
+              break
+            except ErrorBrokenPipe: # receiver died; stop streaming
+              unreg.add(event.fd)
               break
       if Error in event.events:
         assert event.fd != ctx.fd
diff --git a/src/loader/loaderhandle.nim b/src/loader/loaderhandle.nim
index 1cddd5f8..7a9b3434 100644
--- a/src/loader/loaderhandle.nim
+++ b/src/loader/loaderhandle.nim
@@ -46,50 +46,33 @@ proc addOutputStream*(handle: LoaderHandle, stream: Stream) =
     let ms = newMultiStream(handle.ostream, stream)
     handle.ostream = ms
 
-proc sendResult*(handle: LoaderHandle, res: int, msg = ""): bool =
-  try:
-    handle.ostream.swrite(res)
-    if res == 0: # success
-      assert msg == ""
-    else: # error
-      handle.ostream.swrite(msg)
-    return true
-  except IOError: # broken pipe
-    return false
+proc sendResult*(handle: LoaderHandle, res: int, msg = "") =
+  handle.ostream.swrite(res)
+  if res == 0: # success
+    assert msg == ""
+  else: # error
+    handle.ostream.swrite(msg)
 
-proc sendStatus*(handle: LoaderHandle, status: int): bool =
-  try:
-    handle.ostream.swrite(status)
-    return true
-  except IOError: # broken pipe
-    return false
+proc sendStatus*(handle: LoaderHandle, status: int) =
+  handle.ostream.swrite(status)
 
-proc sendHeaders*(handle: LoaderHandle, headers: Headers): bool =
-  try:
-    handle.ostream.swrite(headers)
-    if handle.canredir:
-      var redir: bool
-      handle.ostream.sread(redir)
-      if redir:
-        let fd = SocketStream(handle.ostream).recvFileHandle()
-        handle.sostream = handle.ostream
-        let stream = newPosixStream(fd)
-        handle.ostream = stream
-    return true
-  except IOError: # broken pipe
-    return false
+proc sendHeaders*(handle: LoaderHandle, headers: Headers) =
+  handle.ostream.swrite(headers)
+  if handle.canredir:
+    var redir: bool
+    handle.ostream.sread(redir)
+    if redir:
+      let fd = SocketStream(handle.ostream).recvFileHandle()
+      handle.sostream = handle.ostream
+      let stream = newPosixStream(fd)
+      handle.ostream = stream
 
-proc sendData*(handle: LoaderHandle, p: pointer, nmemb: int): bool =
-  try:
-    handle.ostream.writeData(p, nmemb)
-    return true
-  except IOError: # broken pipe
-    return false
+proc sendData*(handle: LoaderHandle, p: pointer, nmemb: int) =
+  handle.ostream.writeData(p, nmemb)
 
-proc sendData*(handle: LoaderHandle, s: string): bool =
+proc sendData*(handle: LoaderHandle, s: string) =
   if s.len > 0:
-    return handle.sendData(unsafeAddr s[0], s.len)
-  return true
+    handle.sendData(unsafeAddr s[0], s.len)
 
 proc suspend*(handle: LoaderHandle) =
   handle.sostream_suspend = handle.ostream
@@ -99,7 +82,7 @@ proc resume*(handle: LoaderHandle) =
   let ss = handle.ostream
   handle.ostream = handle.sostream_suspend
   handle.sostream_suspend = nil
-  discard handle.sendData(ss.readAll())
+  handle.sendData(ss.readAll())
   ss.close()
 
 proc close*(handle: LoaderHandle) =