summary refs log tree commit diff stats
path: root/lib/pure/asynchttpserver.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/asynchttpserver.nim')
-rw-r--r--lib/pure/asynchttpserver.nim92
1 files changed, 47 insertions, 45 deletions
diff --git a/lib/pure/asynchttpserver.nim b/lib/pure/asynchttpserver.nim
index 64242234c..279cedb5d 100644
--- a/lib/pure/asynchttpserver.nim
+++ b/lib/pure/asynchttpserver.nim
@@ -23,8 +23,7 @@
 ##    proc cb(req: Request) {.async.} =
 ##      await req.respond(Http200, "Hello World")
 ##
-##    asyncCheck server.serve(Port(8080), cb)
-##    runForever()
+##    waitFor server.serve(Port(8080), cb)
 
 import strtabs, asyncnet, asyncdispatch, parseutils, uri, strutils
 type
@@ -109,22 +108,19 @@ proc sendHeaders*(req: Request, headers: StringTableRef): Future[void] =
   addHeaders(msg, headers)
   return req.client.send(msg)
 
-proc respond*(req: Request, code: HttpCode,
-        content: string, headers = newStringTable()) {.async.} =
+proc respond*(req: Request, code: HttpCode, content: string,
+              headers: StringTableRef = nil): Future[void] =
   ## Responds to the request with the specified ``HttpCode``, headers and
   ## content.
   ##
   ## This procedure will **not** close the client socket.
-  var customHeaders = headers
-  customHeaders["Content-Length"] = $content.len
   var msg = "HTTP/1.1 " & $code & "\c\L"
-  msg.addHeaders(customHeaders)
-  await req.client.send(msg & "\c\L" & content)
 
-proc newRequest(): Request =
-  result.headers = newStringTable(modeCaseInsensitive)
-  result.hostname = ""
-  result.body = ""
+  if headers != nil:
+    msg.addHeaders(headers)
+  msg.add("Content-Length: " & $content.len & "\c\L\c\L")
+  msg.add(content)
+  result = req.client.send(msg)
 
 proc parseHeader(line: string): tuple[key, value: string] =
   var i = 0
@@ -149,59 +145,65 @@ proc sendStatus(client: AsyncSocket, status: string): Future[void] =
 proc processClient(client: AsyncSocket, address: string,
                    callback: proc (request: Request):
                       Future[void] {.closure, gcsafe.}) {.async.} =
+  var request: Request
+  request.url = initUri()
+  request.headers = newStringTable(modeCaseInsensitive)
+  var line = newStringOfCap(80)
+  var key, value = ""
+
   while not client.isClosed:
     # GET /path HTTP/1.1
     # Header: val
     # \n
-    var request = newRequest()
-    request.hostname = address
+    request.headers.clear(modeCaseInsensitive)
+    request.hostname.shallowCopy(address)
     assert client != nil
     request.client = client
 
     # First line - GET /path HTTP/1.1
-    let line = await client.recvLine() # TODO: Timeouts.
+    line.setLen(0)
+    await client.recvLineInto(addr line) # TODO: Timeouts.
     if line == "":
       client.close()
       return
-    let lineParts = line.split(' ')
-    if lineParts.len != 3:
-      await request.respond(Http400, "Invalid request. Got: " & line)
-      continue
 
-    let reqMethod = lineParts[0]
-    let path = lineParts[1]
-    let protocol = lineParts[2]
+    var i = 0
+    for linePart in line.split(' '):
+      case i
+      of 0: request.reqMethod.shallowCopy(linePart.normalize)
+      of 1: parseUri(linePart, request.url)
+      of 2:
+        try:
+          request.protocol = parseProtocol(linePart)
+        except ValueError:
+          asyncCheck request.respond(Http400,
+            "Invalid request protocol. Got: " & linePart)
+          continue
+      else:
+        await request.respond(Http400, "Invalid request. Got: " & line)
+        continue
+      inc i
 
     # Headers
-    var i = 0
     while true:
       i = 0
-      let headerLine = await client.recvLine()
-      if headerLine == "":
-        client.close(); return
-      if headerLine == "\c\L": break
-      # TODO: Compiler crash
-      #let (key, value) = parseHeader(headerLine)
-      let kv = parseHeader(headerLine)
-      request.headers[kv.key] = kv.value
+      line.setLen(0)
+      await client.recvLineInto(addr line)
 
-    request.reqMethod = reqMethod
-    request.url = parseUri(path)
-    try:
-      request.protocol = protocol.parseProtocol()
-    except ValueError:
-      asyncCheck request.respond(Http400, "Invalid request protocol. Got: " &
-          protocol)
-      continue
+      if line == "":
+        client.close(); return
+      if line == "\c\L": break
+      let (key, value) = parseHeader(line)
+      request.headers[key] = value
 
-    if reqMethod.normalize == "post":
+    if request.reqMethod == "post":
       # Check for Expect header
       if request.headers.hasKey("Expect"):
         if request.headers["Expect"].toLower == "100-continue":
           await client.sendStatus("100 Continue")
         else:
           await client.sendStatus("417 Expectation Failed")
-    
+
       # Read the body
       # - Check for Content-length header
       if request.headers.hasKey("Content-Length"):
@@ -215,11 +217,11 @@ proc processClient(client: AsyncSocket, address: string,
         await request.respond(Http400, "Bad Request. No Content-Length.")
         continue
 
-    case reqMethod.normalize
+    case request.reqMethod
     of "get", "post", "head", "put", "delete", "trace", "options", "connect", "patch":
       await callback(request)
     else:
-      await request.respond(Http400, "Invalid request method. Got: " & reqMethod)
+      await request.respond(Http400, "Invalid request method. Got: " & request.reqMethod)
 
     # Persistent connections
     if (request.protocol == HttpVer11 and
@@ -247,7 +249,7 @@ proc serve*(server: AsyncHttpServer, port: Port,
     server.socket.setSockOpt(OptReuseAddr, true)
   server.socket.bindAddr(port, address)
   server.socket.listen()
-  
+
   while true:
     # TODO: Causes compiler crash.
     #var (address, client) = await server.socket.acceptAddr()
@@ -260,7 +262,7 @@ proc close*(server: AsyncHttpServer) =
   ## Terminates the async http server instance.
   server.socket.close()
 
-when isMainModule:
+when not defined(testing) and isMainModule:
   proc main =
     var server = newAsyncHttpServer()
     proc cb(req: Request) {.async.} =