summary refs log tree commit diff stats
path: root/lib/pure
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure')
-rw-r--r--lib/pure/asyncdispatch.nim3
-rw-r--r--lib/pure/asynchttpserver.nim39
2 files changed, 34 insertions, 8 deletions
diff --git a/lib/pure/asyncdispatch.nim b/lib/pure/asyncdispatch.nim
index 6292bfc12..12329951c 100644
--- a/lib/pure/asyncdispatch.nim
+++ b/lib/pure/asyncdispatch.nim
@@ -779,6 +779,7 @@ proc accept*(socket: TAsyncFD): PFuture[TAsyncFD] =
 template createCb*(retFutureSym, iteratorNameSym,
                    name: expr): stmt {.immediate.} =
   var nameIterVar = iteratorNameSym
+  #{.push stackTrace: off.}
   proc cb {.closure,gcsafe.} =
     try:
       if not nameIterVar.finished:
@@ -791,7 +792,7 @@ template createCb*(retFutureSym, iteratorNameSym,
     except:
       retFutureSym.fail(getCurrentException())
   cb()
-
+  #{.pop.}
 proc generateExceptionCheck(futSym,
     exceptBranch, rootReceiver: PNimrodNode): PNimrodNode {.compileTime.} =
   if exceptBranch == nil:
diff --git a/lib/pure/asynchttpserver.nim b/lib/pure/asynchttpserver.nim
index 005c56ebc..1b47cf5f1 100644
--- a/lib/pure/asynchttpserver.nim
+++ b/lib/pure/asynchttpserver.nim
@@ -20,7 +20,7 @@ type
     protocol*: tuple[orig: string, major, minor: int]
     url*: TURL
     hostname*: string ## The hostname of the client that made the request.
-    body*: string # TODO
+    body*: string
 
   PAsyncHttpServer* = ref object
     socket: PAsyncSocket
@@ -78,7 +78,7 @@ proc parseHeader(line: string): tuple[key, value: string] =
   i += line.skipWhiteSpace(i)
   i += line.parseUntil(result.value, {'\c', '\L'}, i)
 
-proc parseProtocol(protocol: string):  tuple[orig: string, major, minor: int] =
+proc parseProtocol(protocol: string): tuple[orig: string, major, minor: int] =
   var i = protocol.skipIgnoreCase("HTTP/")
   if i != 5:
     raise newException(EInvalidValue, "Invalid request protocol. Got: " &
@@ -88,6 +88,9 @@ proc parseProtocol(protocol: string):  tuple[orig: string, major, minor: int] =
   i.inc # Skip .
   i.inc protocol.parseInt(result.minor, i)
 
+proc sendStatus(client: PAsyncSocket, status: string): PFuture[void] =
+  client.send("HTTP/1.1 " & status & "\c\L")
+
 proc processClient(client: PAsyncSocket, address: string,
                  callback: proc (request: TRequest): PFuture[void]) {.async.} =
   # GET /path HTTP/1.1
@@ -97,6 +100,7 @@ proc processClient(client: PAsyncSocket, address: string,
   request.hostname = address
   assert client != nil
   request.client = client
+  var runCallback = true
 
   # First line - GET /path HTTP/1.1
   let line = await client.recvLine() # TODO: Timeouts.
@@ -106,8 +110,7 @@ proc processClient(client: PAsyncSocket, address: string,
   let lineParts = line.split(' ')
   if lineParts.len != 3:
     request.respond(Http400, "Invalid request. Got: " & line)
-    client.close()
-    return
+    runCallback = false
 
   let reqMethod = lineParts[0]
   let path = lineParts[1]
@@ -132,13 +135,35 @@ proc processClient(client: PAsyncSocket, address: string,
     request.protocol = protocol.parseProtocol()
   except EInvalidValue:
     request.respond(Http400, "Invalid request protocol. Got: " & protocol)
-    return
+    runCallback = false
+
+  if reqMethod.normalize == "post":
+    # Check for Expect header
+    if request.headers.hasKey("Expect"):
+      if request.headers["Expect"].toLower == "100-continue":
+        await client.sendStatus("100 Continue")
+      else:
+        await client.sendStatus("417 Expectation Failed")
   
+    # Read the body
+    # - Check for Content-length header
+    if request.headers.hasKey("Content-Length"):
+      var contentLength = 0
+      if parseInt(request.headers["Content-Length"], contentLength) == 0:
+        await request.respond(Http400, "Bad Request. Invalid Content-Length.")
+      else:
+        request.body = await client.recv(contentLength)
+        assert request.body.len == contentLength
+    else:
+      await request.respond(Http400, "Bad Request. No Content-Length.")
+      runCallback = false
+
   case reqMethod.normalize
   of "get", "post", "head", "put", "delete", "trace", "options", "connect", "patch":
-    await callback(request)
+    if runCallback:
+      await callback(request)
   else:
-    request.respond(Http400, "Invalid request method. Got: " & reqMethod)
+    await request.respond(Http400, "Invalid request method. Got: " & reqMethod)
 
   # Persistent connections
   if (request.protocol == HttpVer11 and