summary refs log tree commit diff stats
path: root/lib/pure/asynchttpserver.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/asynchttpserver.nim')
-rw-r--r--lib/pure/asynchttpserver.nim181
1 files changed, 109 insertions, 72 deletions
diff --git a/lib/pure/asynchttpserver.nim b/lib/pure/asynchttpserver.nim
index 2ebd7036d..ee6658fd1 100644
--- a/lib/pure/asynchttpserver.nim
+++ b/lib/pure/asynchttpserver.nim
@@ -14,12 +14,13 @@
 import strtabs, asyncnet, asyncdispatch, parseutils, parseurl, strutils
 type
   TRequest* = object
-    client: PAsyncSocket # TODO: Separate this into a Response object?
+    client*: PAsyncSocket # TODO: Separate this into a Response object?
     reqMethod*: string
     headers*: PStringTable
     protocol*: tuple[orig: string, major, minor: int]
     url*: TURL
     hostname*: string ## The hostname of the client that made the request.
+    body*: string
 
   PAsyncHttpServer* = ref object
     socket: PAsyncSocket
@@ -50,10 +51,15 @@ proc `==`*(protocol: tuple[orig: string, major, minor: int],
 proc newAsyncHttpServer*(): PAsyncHttpServer =
   new result
 
-proc sendHeaders*(req: TRequest, headers: PStringTable) {.async.} =
-  ## Sends the specified headers to the requesting client.
+proc addHeaders(msg: var string, headers: PStringTable) =
   for k, v in headers:
-    await req.client.send(k & ": " & v & "\c\L")
+    msg.add(k & ": " & v & "\c\L")
+
+proc sendHeaders*(req: TRequest, headers: PStringTable): PFuture[void] =
+  ## Sends the specified headers to the requesting client.
+  var msg = ""
+  addHeaders(msg, headers)
+  return req.client.send(msg)
 
 proc respond*(req: TRequest, code: THttpCode,
         content: string, headers: PStringTable = newStringTable()) {.async.} =
@@ -63,9 +69,9 @@ proc respond*(req: TRequest, code: THttpCode,
   ## This procedure will **not** close the client socket.
   var customHeaders = headers
   customHeaders["Content-Length"] = $content.len
-  await req.client.send("HTTP/1.1 " & $code & "\c\L")
-  await sendHeaders(req, headers)
-  await req.client.send("\c\L" & content)
+  var msg = "HTTP/1.1 " & $code & "\c\L"
+  msg.addHeaders(customHeaders)
+  await req.client.send(msg & "\c\L" & content)
 
 proc newRequest(): TRequest =
   result.headers = newStringTable(modeCaseInsensitive)
@@ -77,7 +83,7 @@ proc parseHeader(line: string): tuple[key, value: string] =
   i += line.skipWhiteSpace(i)
   i += line.parseUntil(result.value, {'\c', '\L'}, i)
 
-proc parseProtocol(protocol: string):  tuple[orig: string, major, minor: int] =
+proc parseProtocol(protocol: string): tuple[orig: string, major, minor: int] =
   var i = protocol.skipIgnoreCase("HTTP/")
   if i != 5:
     raise newException(EInvalidValue, "Invalid request protocol. Got: " &
@@ -87,70 +93,95 @@ proc parseProtocol(protocol: string):  tuple[orig: string, major, minor: int] =
   i.inc # Skip .
   i.inc protocol.parseInt(result.minor, i)
 
+proc sendStatus(client: PAsyncSocket, status: string): PFuture[void] =
+  client.send("HTTP/1.1 " & status & "\c\L")
+
 proc processClient(client: PAsyncSocket, address: string,
                  callback: proc (request: TRequest): PFuture[void]) {.async.} =
-  # GET /path HTTP/1.1
-  # Header: val
-  # \n
-  var request = newRequest()
-  request.hostname = address
-  assert client != nil
-  request.client = client
-
-  # First line - GET /path HTTP/1.1
-  let line = await client.recvLine() # TODO: Timeouts.
-  if line == "":
-    client.close()
-    return
-  let lineParts = line.split(' ')
-  if lineParts.len != 3:
-    request.respond(Http400, "Invalid request. Got: " & line)
-    client.close()
-    return
-
-  let reqMethod = lineParts[0]
-  let path = lineParts[1]
-  let protocol = lineParts[2]
-
-  # Headers
-  var i = 0
   while true:
-    i = 0
-    let headerLine = await client.recvLine()
-    if headerLine == "":
-      client.close(); return
-    if headerLine == "\c\L": break
-    # TODO: Compiler crash
-    #let (key, value) = parseHeader(headerLine)
-    let kv = parseHeader(headerLine)
-    request.headers[kv.key] = kv.value
-
-  request.reqMethod = reqMethod
-  request.url = parseUrl(path)
-  try:
-    request.protocol = protocol.parseProtocol()
-  except EInvalidValue:
-    request.respond(Http400, "Invalid request protocol. Got: " & protocol)
-    return
-  
-  case reqMethod.normalize
-  of "get", "post", "head", "put", "delete", "trace", "options", "connect", "patch":
-    await callback(request)
-  else:
-    request.respond(Http400, "Invalid request method. Got: " & reqMethod)
-
-  # Persistent connections
-  if (request.protocol == HttpVer11 and
-      request.headers["connection"].normalize != "close") or
-     (request.protocol == HttpVer10 and
-      request.headers["connection"].normalize == "keep-alive"):
-    # In HTTP 1.1 we assume that connection is persistent. Unless connection
-    # header states otherwise.
-    # In HTTP 1.0 we assume that the connection should not be persistent.
-    # Unless the connection header states otherwise.
-    await processClient(client, address, callback)
-  else:
-    request.client.close()
+    # GET /path HTTP/1.1
+    # Header: val
+    # \n
+    var request = newRequest()
+    request.hostname = address
+    assert client != nil
+    request.client = client
+
+    # First line - GET /path HTTP/1.1
+    let line = await client.recvLine() # TODO: Timeouts.
+    if line == "":
+      client.close()
+      return
+    let lineParts = line.split(' ')
+    if lineParts.len != 3:
+      await request.respond(Http400, "Invalid request. Got: " & line)
+      continue
+
+    let reqMethod = lineParts[0]
+    let path = lineParts[1]
+    let protocol = lineParts[2]
+
+    # Headers
+    var i = 0
+    while true:
+      i = 0
+      let headerLine = await client.recvLine()
+      if headerLine == "":
+        client.close(); return
+      if headerLine == "\c\L": break
+      # TODO: Compiler crash
+      #let (key, value) = parseHeader(headerLine)
+      let kv = parseHeader(headerLine)
+      request.headers[kv.key] = kv.value
+
+    request.reqMethod = reqMethod
+    request.url = parseUrl(path)
+    try:
+      request.protocol = protocol.parseProtocol()
+    except EInvalidValue:
+      asyncCheck request.respond(Http400, "Invalid request protocol. Got: " &
+          protocol)
+      continue
+
+    if reqMethod.normalize == "post":
+      # Check for Expect header
+      if request.headers.hasKey("Expect"):
+        if request.headers["Expect"].toLower == "100-continue":
+          await client.sendStatus("100 Continue")
+        else:
+          await client.sendStatus("417 Expectation Failed")
+    
+      # Read the body
+      # - Check for Content-length header
+      if request.headers.hasKey("Content-Length"):
+        var contentLength = 0
+        if parseInt(request.headers["Content-Length"], contentLength) == 0:
+          await request.respond(Http400, "Bad Request. Invalid Content-Length.")
+        else:
+          request.body = await client.recv(contentLength)
+          assert request.body.len == contentLength
+      else:
+        await request.respond(Http400, "Bad Request. No Content-Length.")
+        continue
+
+    case reqMethod.normalize
+    of "get", "post", "head", "put", "delete", "trace", "options", "connect", "patch":
+      await callback(request)
+    else:
+      await request.respond(Http400, "Invalid request method. Got: " & reqMethod)
+
+    # Persistent connections
+    if (request.protocol == HttpVer11 and
+        request.headers["connection"].normalize != "close") or
+       (request.protocol == HttpVer10 and
+        request.headers["connection"].normalize == "keep-alive"):
+      # In HTTP 1.1 we assume that connection is persistent. Unless connection
+      # header states otherwise.
+      # In HTTP 1.0 we assume that the connection should not be persistent.
+      # Unless the connection header states otherwise.
+    else:
+      request.client.close()
+      break
 
 proc serve*(server: PAsyncHttpServer, port: TPort,
             callback: proc (request: TRequest): PFuture[void],
@@ -167,14 +198,20 @@ proc serve*(server: PAsyncHttpServer, port: TPort,
     # TODO: Causes compiler crash.
     #var (address, client) = await server.socket.acceptAddr()
     var fut = await server.socket.acceptAddr()
-    processClient(fut.client, fut.address, callback)
+    asyncCheck processClient(fut.client, fut.address, callback)
+
+proc close*(server: PAsyncHttpServer) =
+  ## Terminates the async http server instance.
+  server.socket.close()
 
 when isMainModule:
   var server = newAsyncHttpServer()
   proc cb(req: TRequest) {.async.} =
     #echo(req.reqMethod, " ", req.url)
     #echo(req.headers)
-    await req.respond(Http200, "Hello World")
+    let headers = {"Date": "Tue, 29 Apr 2014 23:40:08 GMT",
+        "Content-type": "text/plain; charset=utf-8"}
+    await req.respond(Http200, "Hello World", headers.newStringTable())
 
-  server.serve(TPort(5555), cb)
+  asyncCheck server.serve(TPort(5555), cb)
   runForever()