summary refs log tree commit diff stats
path: root/lib/pure/asynchttpserver.nim
diff options
context:
space:
mode:
authorDaniil Yarancev <21169548+Yardanico@users.noreply.github.com>2018-01-07 21:02:00 +0300
committerGitHub <noreply@github.com>2018-01-07 21:02:00 +0300
commitfb44c522e6173528efa8035ecc459c84887d0167 (patch)
treea2f5e98606be265981a5f72748896967033e23d7 /lib/pure/asynchttpserver.nim
parentccf99fa5ce4fe992fb80dc89271faa51456c3fa5 (diff)
parente23ea64c41e101d4e1d933f0b015f51cc6c2f7de (diff)
downloadNim-fb44c522e6173528efa8035ecc459c84887d0167.tar.gz
Merge pull request #1 from nim-lang/devel
upstream
Diffstat (limited to 'lib/pure/asynchttpserver.nim')
-rw-r--r--lib/pure/asynchttpserver.nim268
1 files changed, 142 insertions, 126 deletions
diff --git a/lib/pure/asynchttpserver.nim b/lib/pure/asynchttpserver.nim
index 6d4b85145..ba1615651 100644
--- a/lib/pure/asynchttpserver.nim
+++ b/lib/pure/asynchttpserver.nim
@@ -58,15 +58,18 @@ type
     socket: AsyncSocket
     reuseAddr: bool
     reusePort: bool
+    maxBody: int ## The maximum content-length that will be read for the body.
 
 {.deprecated: [TRequest: Request, PAsyncHttpServer: AsyncHttpServer,
   THttpCode: HttpCode, THttpVersion: HttpVersion].}
 
-proc newAsyncHttpServer*(reuseAddr = true, reusePort = false): AsyncHttpServer =
+proc newAsyncHttpServer*(reuseAddr = true, reusePort = false,
+                         maxBody = 8388608): AsyncHttpServer =
   ## Creates a new ``AsyncHttpServer`` instance.
   new result
   result.reuseAddr = reuseAddr
   result.reusePort = reusePort
+  result.maxBody = maxBody
 
 proc addHeaders(msg: var string, headers: HttpHeaders) =
   for k, v in headers:
@@ -122,144 +125,157 @@ proc parseProtocol(protocol: string): tuple[orig: string, major, minor: int] =
     raise newException(ValueError, "Invalid request protocol. Got: " &
         protocol)
   result.orig = protocol
-  i.inc protocol.parseInt(result.major, i)
+  i.inc protocol.parseSaturatedNatural(result.major, i)
   i.inc # Skip .
-  i.inc protocol.parseInt(result.minor, i)
+  i.inc protocol.parseSaturatedNatural(result.minor, i)
 
 proc sendStatus(client: AsyncSocket, status: string): Future[void] =
   client.send("HTTP/1.1 " & status & "\c\L\c\L")
 
-proc processClient(client: AsyncSocket, address: string,
-                   callback: proc (request: Request):
+proc processRequest(server: AsyncHttpServer, req: FutureVar[Request],
+                    client: AsyncSocket,
+                    address: string, lineFut: FutureVar[string],
+                    callback: proc (request: Request):
                       Future[void] {.closure, gcsafe.}) {.async.} =
-  var request: Request
-  request.url = initUri()
-  request.headers = newHttpHeaders()
-  var lineFut = newFutureVar[string]("asynchttpserver.processClient")
-  lineFut.mget() = newStringOfCap(80)
-  var key, value = ""
 
-  while not client.isClosed:
-    # GET /path HTTP/1.1
-    # Header: val
-    # \n
-    request.headers.clear()
-    request.body = ""
-    request.hostname.shallowCopy(address)
-    assert client != nil
-    request.client = client
-
-    # We should skip at least one empty line before the request
-    # https://tools.ietf.org/html/rfc7230#section-3.5
-    for i in 0..1:
-      lineFut.mget().setLen(0)
-      lineFut.clean()
-      await client.recvLineInto(lineFut, maxLength=maxLine) # TODO: Timeouts.
-
-      if lineFut.mget == "":
-        client.close()
-        return
+  # Alias `request` to `req.mget()` so we don't have to write `mget` everywhere.
+  template request(): Request =
+    req.mget()
+
+  # GET /path HTTP/1.1
+  # Header: val
+  # \n
+  request.headers.clear()
+  request.body = ""
+  request.hostname.shallowCopy(address)
+  assert client != nil
+  request.client = client
+
+  # We should skip at least one empty line before the request
+  # https://tools.ietf.org/html/rfc7230#section-3.5
+  for i in 0..1:
+    lineFut.mget().setLen(0)
+    lineFut.clean()
+    await client.recvLineInto(lineFut, maxLength=maxLine) # TODO: Timeouts.
+
+    if lineFut.mget == "":
+      client.close()
+      return
 
-      if lineFut.mget.len > maxLine:
-        await request.respondError(Http413)
-        client.close()
+    if lineFut.mget.len > maxLine:
+      await request.respondError(Http413)
+      client.close()
+      return
+    if lineFut.mget != "\c\L":
+      break
+
+  # First line - GET /path HTTP/1.1
+  var i = 0
+  for linePart in lineFut.mget.split(' '):
+    case i
+    of 0:
+      try:
+        # TODO: this is likely slow.
+        request.reqMethod = parseEnum[HttpMethod]("http" & linePart)
+      except ValueError:
+        asyncCheck request.respondError(Http400)
         return
-      if lineFut.mget != "\c\L":
-        break
-
-    # First line - GET /path HTTP/1.1
-    var i = 0
-    for linePart in lineFut.mget.split(' '):
-      case i
-      of 0:
-        try:
-          # TODO: this is likely slow.
-          request.reqMethod = parseEnum[HttpMethod]("http" & linePart)
-        except ValueError:
-          asyncCheck request.respondError(Http400)
-          continue
-      of 1: 
-        try:
-          parseUri(linePart, request.url)
-        except ValueError:
-          asyncCheck request.respondError(Http400) 
-          continue
-      of 2:
-        try:
-          request.protocol = parseProtocol(linePart)
-        except ValueError:
-          asyncCheck request.respondError(Http400)
-          continue
-      else:
-        await request.respondError(Http400)
-        continue
-      inc i
-
-    # Headers
-    while true:
-      i = 0
-      lineFut.mget.setLen(0)
-      lineFut.clean()
-      await client.recvLineInto(lineFut, maxLength=maxLine)
-
-      if lineFut.mget == "":
-        client.close(); return
-      if lineFut.mget.len > maxLine:
-        await request.respondError(Http413)
-        client.close(); return
-      if lineFut.mget == "\c\L": break
-      let (key, value) = parseHeader(lineFut.mget)
-      request.headers[key] = value
-      # Ensure the client isn't trying to DoS us.
-      if request.headers.len > headerLimit:
-        await client.sendStatus("400 Bad Request")
-        request.client.close()
+    of 1:
+      try:
+        parseUri(linePart, request.url)
+      except ValueError:
+        asyncCheck request.respondError(Http400)
         return
+    of 2:
+      try:
+        request.protocol = parseProtocol(linePart)
+      except ValueError:
+        asyncCheck request.respondError(Http400)
+        return
+    else:
+      await request.respondError(Http400)
+      return
+    inc i
 
-    if request.reqMethod == HttpPost:
-      # Check for Expect header
-      if request.headers.hasKey("Expect"):
-        if "100-continue" in request.headers["Expect"]:
-          await client.sendStatus("100 Continue")
-        else:
-          await client.sendStatus("417 Expectation Failed")
-
-    # Read the body
-    # - Check for Content-length header
-    if request.headers.hasKey("Content-Length"):
-      var contentLength = 0
-      if parseInt(request.headers["Content-Length"],
-                  contentLength) == 0:
-        await request.respond(Http400, "Bad Request. Invalid Content-Length.")
-        continue
-      else:
-        request.body = await client.recv(contentLength)
-        if request.body.len != contentLength:
-          await request.respond(Http400, "Bad Request. Content-Length does not match actual.")
-          continue
-    elif request.reqMethod == HttpPost:
-      await request.respond(Http411, "Content-Length required.")
-      continue
-
-    # Call the user's callback.
-    await callback(request)
-
-    if "upgrade" in request.headers.getOrDefault("connection"):
+  # Headers
+  while true:
+    i = 0
+    lineFut.mget.setLen(0)
+    lineFut.clean()
+    await client.recvLineInto(lineFut, maxLength=maxLine)
+
+    if lineFut.mget == "":
+      client.close(); return
+    if lineFut.mget.len > maxLine:
+      await request.respondError(Http413)
+      client.close(); return
+    if lineFut.mget == "\c\L": break
+    let (key, value) = parseHeader(lineFut.mget)
+    request.headers[key] = value
+    # Ensure the client isn't trying to DoS us.
+    if request.headers.len > headerLimit:
+      await client.sendStatus("400 Bad Request")
+      request.client.close()
       return
 
-    # Persistent connections
-    if (request.protocol == HttpVer11 and
-        request.headers.getOrDefault("connection").normalize != "close") or
-       (request.protocol == HttpVer10 and
-        request.headers.getOrDefault("connection").normalize == "keep-alive"):
-      # In HTTP 1.1 we assume that connection is persistent. Unless connection
-      # header states otherwise.
-      # In HTTP 1.0 we assume that the connection should not be persistent.
-      # Unless the connection header states otherwise.
-      discard
+  if request.reqMethod == HttpPost:
+    # Check for Expect header
+    if request.headers.hasKey("Expect"):
+      if "100-continue" in request.headers["Expect"]:
+        await client.sendStatus("100 Continue")
+      else:
+        await client.sendStatus("417 Expectation Failed")
+
+  # Read the body
+  # - Check for Content-length header
+  if request.headers.hasKey("Content-Length"):
+    var contentLength = 0
+    if parseSaturatedNatural(request.headers["Content-Length"], contentLength) == 0:
+      await request.respond(Http400, "Bad Request. Invalid Content-Length.")
+      return
     else:
-      request.client.close()
-      break
+      if contentLength > server.maxBody:
+        await request.respondError(Http413)
+        return
+      request.body = await client.recv(contentLength)
+      if request.body.len != contentLength:
+        await request.respond(Http400, "Bad Request. Content-Length does not match actual.")
+        return
+  elif request.reqMethod == HttpPost:
+    await request.respond(Http411, "Content-Length required.")
+    return
+
+  # Call the user's callback.
+  await callback(request)
+
+  if "upgrade" in request.headers.getOrDefault("connection"):
+    return
+
+  # Persistent connections
+  if (request.protocol == HttpVer11 and
+      cmpIgnoreCase(request.headers.getOrDefault("connection"), "close") != 0) or
+     (request.protocol == HttpVer10 and
+      cmpIgnoreCase(request.headers.getOrDefault("connection"), "keep-alive") == 0):
+    # In HTTP 1.1 we assume that connection is persistent. Unless connection
+    # header states otherwise.
+    # In HTTP 1.0 we assume that the connection should not be persistent.
+    # Unless the connection header states otherwise.
+    discard
+  else:
+    request.client.close()
+    return
+
+proc processClient(server: AsyncHttpServer, client: AsyncSocket, address: string,
+                   callback: proc (request: Request):
+                      Future[void] {.closure, gcsafe.}) {.async.} =
+  var request = newFutureVar[Request]("asynchttpserver.processClient")
+  request.mget().url = initUri()
+  request.mget().headers = newHttpHeaders()
+  var lineFut = newFutureVar[string]("asynchttpserver.processClient")
+  lineFut.mget() = newStringOfCap(80)
+
+  while not client.isClosed:
+    await processRequest(server, request, client, address, lineFut, callback)
 
 proc serve*(server: AsyncHttpServer, port: Port,
             callback: proc (request: Request): Future[void] {.closure,gcsafe.},
@@ -280,7 +296,7 @@ proc serve*(server: AsyncHttpServer, port: Port,
     # TODO: Causes compiler crash.
     #var (address, client) = await server.socket.acceptAddr()
     var fut = await server.socket.acceptAddr()
-    asyncCheck processClient(fut.client, fut.address, callback)
+    asyncCheck processClient(server, fut.client, fut.address, callback)
     #echo(f.isNil)
     #echo(f.repr)