summary refs log tree commit diff stats
path: root/lib/pure/asynchttpserver.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/asynchttpserver.nim')
-rw-r--r--lib/pure/asynchttpserver.nim133
1 files changed, 44 insertions, 89 deletions
diff --git a/lib/pure/asynchttpserver.nim b/lib/pure/asynchttpserver.nim
index 590b52c1a..6a7326e83 100644
--- a/lib/pure/asynchttpserver.nim
+++ b/lib/pure/asynchttpserver.nim
@@ -25,12 +25,21 @@
 ##
 ##    waitFor server.serve(Port(8080), cb)
 
-import strtabs, asyncnet, asyncdispatch, parseutils, uri, strutils
+import tables, asyncnet, asyncdispatch, parseutils, uri, strutils
+import httpcore
+
+export httpcore except parseHeader
+
+# TODO: If it turns out that the decisions that asynchttpserver makes
+# explicitly, about whether to close the client sockets or upgrade them are
+# wrong, then add a return value which determines what to do for the callback.
+# Also, maybe move `client` out of `Request` object and into the args for
+# the proc.
 type
   Request* = object
     client*: AsyncSocket # TODO: Separate this into a Response object?
     reqMethod*: string
-    headers*: StringTableRef
+    headers*: HttpHeaders
     protocol*: tuple[orig: string, major, minor: int]
     url*: Uri
     hostname*: string ## The hostname of the client that made the request.
@@ -39,83 +48,29 @@ type
   AsyncHttpServer* = ref object
     socket: AsyncSocket
     reuseAddr: bool
-
-  HttpCode* = enum
-    Http100 = "100 Continue",
-    Http101 = "101 Switching Protocols",
-    Http200 = "200 OK",
-    Http201 = "201 Created",
-    Http202 = "202 Accepted",
-    Http204 = "204 No Content",
-    Http205 = "205 Reset Content",
-    Http206 = "206 Partial Content",
-    Http300 = "300 Multiple Choices",
-    Http301 = "301 Moved Permanently",
-    Http302 = "302 Found",
-    Http303 = "303 See Other",
-    Http304 = "304 Not Modified",
-    Http305 = "305 Use Proxy",
-    Http307 = "307 Temporary Redirect",
-    Http400 = "400 Bad Request",
-    Http401 = "401 Unauthorized",
-    Http403 = "403 Forbidden",
-    Http404 = "404 Not Found",
-    Http405 = "405 Method Not Allowed",
-    Http406 = "406 Not Acceptable",
-    Http407 = "407 Proxy Authentication Required",
-    Http408 = "408 Request Timeout",
-    Http409 = "409 Conflict",
-    Http410 = "410 Gone",
-    Http411 = "411 Length Required",
-    Http412 = "412 Precondition Failed",
-    Http413 = "413 Request Entity Too Large",
-    Http414 = "414 Request-URI Too Long",
-    Http415 = "415 Unsupported Media Type",
-    Http416 = "416 Requested Range Not Satisfiable",
-    Http417 = "417 Expectation Failed",
-    Http418 = "418 I'm a teapot",
-    Http500 = "500 Internal Server Error",
-    Http501 = "501 Not Implemented",
-    Http502 = "502 Bad Gateway",
-    Http503 = "503 Service Unavailable",
-    Http504 = "504 Gateway Timeout",
-    Http505 = "505 HTTP Version Not Supported"
-
-  HttpVersion* = enum
-    HttpVer11,
-    HttpVer10
+    reusePort: bool
 
 {.deprecated: [TRequest: Request, PAsyncHttpServer: AsyncHttpServer,
   THttpCode: HttpCode, THttpVersion: HttpVersion].}
 
-proc `==`*(protocol: tuple[orig: string, major, minor: int],
-           ver: HttpVersion): bool =
-  let major =
-    case ver
-    of HttpVer11, HttpVer10: 1
-  let minor =
-    case ver
-    of HttpVer11: 1
-    of HttpVer10: 0
-  result = protocol.major == major and protocol.minor == minor
-
-proc newAsyncHttpServer*(reuseAddr = true): AsyncHttpServer =
+proc newAsyncHttpServer*(reuseAddr = true, reusePort = false): AsyncHttpServer =
   ## Creates a new ``AsyncHttpServer`` instance.
   new result
   result.reuseAddr = reuseAddr
+  result.reusePort = reusePort
 
-proc addHeaders(msg: var string, headers: StringTableRef) =
+proc addHeaders(msg: var string, headers: HttpHeaders) =
   for k, v in headers:
     msg.add(k & ": " & v & "\c\L")
 
-proc sendHeaders*(req: Request, headers: StringTableRef): Future[void] =
+proc sendHeaders*(req: Request, headers: HttpHeaders): Future[void] =
   ## Sends the specified headers to the requesting client.
   var msg = ""
   addHeaders(msg, headers)
   return req.client.send(msg)
 
 proc respond*(req: Request, code: HttpCode, content: string,
-              headers: StringTableRef = nil): Future[void] =
+              headers: HttpHeaders = nil): Future[void] =
   ## Responds to the request with the specified ``HttpCode``, headers and
   ## content.
   ##
@@ -128,16 +83,6 @@ proc respond*(req: Request, code: HttpCode, content: string,
   msg.add(content)
   result = req.client.send(msg)
 
-proc parseHeader(line: string): tuple[key, value: string] =
-  var i = 0
-  i = line.parseUntil(result.key, ':')
-  inc(i) # skip :
-  if i < len(line):
-    i += line.skipWhiteSpace(i)
-    i += line.parseUntil(result.value, {'\c', '\L'}, i)
-  else:
-    result.value = ""
-
 proc parseProtocol(protocol: string): tuple[orig: string, major, minor: int] =
   var i = protocol.skipIgnoreCase("HTTP/")
   if i != 5:
@@ -156,7 +101,7 @@ proc processClient(client: AsyncSocket, address: string,
                       Future[void] {.closure, gcsafe.}) {.async.} =
   var request: Request
   request.url = initUri()
-  request.headers = newStringTable(modeCaseInsensitive)
+  request.headers = newHttpHeaders()
   var lineFut = newFutureVar[string]("asynchttpserver.processClient")
   lineFut.mget() = newStringOfCap(80)
   var key, value = ""
@@ -165,7 +110,7 @@ proc processClient(client: AsyncSocket, address: string,
     # GET /path HTTP/1.1
     # Header: val
     # \n
-    request.headers.clear(modeCaseInsensitive)
+    request.headers.clear()
     request.body = ""
     request.hostname.shallowCopy(address)
     assert client != nil
@@ -208,29 +153,34 @@ proc processClient(client: AsyncSocket, address: string,
       if lineFut.mget == "\c\L": break
       let (key, value) = parseHeader(lineFut.mget)
       request.headers[key] = value
+      # Ensure the client isn't trying to DoS us.
+      if request.headers.len > headerLimit:
+        await client.sendStatus("400 Bad Request")
+        request.client.close()
+        return
 
     if request.reqMethod == "post":
       # Check for Expect header
       if request.headers.hasKey("Expect"):
-        if request.headers.getOrDefault("Expect").toLower == "100-continue":
+        if "100-continue" in request.headers["Expect"]:
           await client.sendStatus("100 Continue")
         else:
           await client.sendStatus("417 Expectation Failed")
 
-      # Read the body
-      # - Check for Content-length header
-      if request.headers.hasKey("Content-Length"):
-        var contentLength = 0
-        if parseInt(request.headers.getOrDefault("Content-Length"),
-                    contentLength) == 0:
-          await request.respond(Http400, "Bad Request. Invalid Content-Length.")
-          continue
-        else:
-          request.body = await client.recv(contentLength)
-          assert request.body.len == contentLength
-      else:
-        await request.respond(Http400, "Bad Request. No Content-Length.")
+    # Read the body
+    # - Check for Content-length header
+    if request.headers.hasKey("Content-Length"):
+      var contentLength = 0
+      if parseInt(request.headers["Content-Length"],
+                  contentLength) == 0:
+        await request.respond(Http400, "Bad Request. Invalid Content-Length.")
         continue
+      else:
+        request.body = await client.recv(contentLength)
+        assert request.body.len == contentLength
+    elif request.reqMethod == "post":
+      await request.respond(Http400, "Bad Request. No Content-Length.")
+      continue
 
     case request.reqMethod
     of "get", "post", "head", "put", "delete", "trace", "options",
@@ -240,6 +190,9 @@ proc processClient(client: AsyncSocket, address: string,
       await request.respond(Http400, "Invalid request method. Got: " &
         request.reqMethod)
 
+    if "upgrade" in request.headers.getOrDefault("connection"):
+      return
+
     # Persistent connections
     if (request.protocol == HttpVer11 and
         request.headers.getOrDefault("connection").normalize != "close") or
@@ -264,6 +217,8 @@ proc serve*(server: AsyncHttpServer, port: Port,
   server.socket = newAsyncSocket()
   if server.reuseAddr:
     server.socket.setSockOpt(OptReuseAddr, true)
+  if server.reusePort:
+    server.socket.setSockOpt(OptReusePort, true)
   server.socket.bindAddr(port, address)
   server.socket.listen()
 
@@ -287,7 +242,7 @@ when not defined(testing) and isMainModule:
       #echo(req.headers)
       let headers = {"Date": "Tue, 29 Apr 2014 23:40:08 GMT",
           "Content-type": "text/plain; charset=utf-8"}
-      await req.respond(Http200, "Hello World", headers.newStringTable())
+      await req.respond(Http200, "Hello World", headers.newHttpHeaders())
 
     asyncCheck server.serve(Port(5555), cb)
     runForever()