diff options
Diffstat (limited to 'lib/pure/asynchttpserver.nim')
-rw-r--r-- | lib/pure/asynchttpserver.nim | 39 |
1 files changed, 32 insertions, 7 deletions
diff --git a/lib/pure/asynchttpserver.nim b/lib/pure/asynchttpserver.nim index 005c56ebc..1b47cf5f1 100644 --- a/lib/pure/asynchttpserver.nim +++ b/lib/pure/asynchttpserver.nim @@ -20,7 +20,7 @@ type protocol*: tuple[orig: string, major, minor: int] url*: TURL hostname*: string ## The hostname of the client that made the request. - body*: string # TODO + body*: string PAsyncHttpServer* = ref object socket: PAsyncSocket @@ -78,7 +78,7 @@ proc parseHeader(line: string): tuple[key, value: string] = i += line.skipWhiteSpace(i) i += line.parseUntil(result.value, {'\c', '\L'}, i) -proc parseProtocol(protocol: string): tuple[orig: string, major, minor: int] = +proc parseProtocol(protocol: string): tuple[orig: string, major, minor: int] = var i = protocol.skipIgnoreCase("HTTP/") if i != 5: raise newException(EInvalidValue, "Invalid request protocol. Got: " & @@ -88,6 +88,9 @@ proc parseProtocol(protocol: string): tuple[orig: string, major, minor: int] = i.inc # Skip . i.inc protocol.parseInt(result.minor, i) +proc sendStatus(client: PAsyncSocket, status: string): PFuture[void] = + client.send("HTTP/1.1 " & status & "\c\L") + proc processClient(client: PAsyncSocket, address: string, callback: proc (request: TRequest): PFuture[void]) {.async.} = # GET /path HTTP/1.1 @@ -97,6 +100,7 @@ proc processClient(client: PAsyncSocket, address: string, request.hostname = address assert client != nil request.client = client + var runCallback = true # First line - GET /path HTTP/1.1 let line = await client.recvLine() # TODO: Timeouts. @@ -106,8 +110,7 @@ proc processClient(client: PAsyncSocket, address: string, let lineParts = line.split(' ') if lineParts.len != 3: request.respond(Http400, "Invalid request. Got: " & line) - client.close() - return + runCallback = false let reqMethod = lineParts[0] let path = lineParts[1] @@ -132,13 +135,35 @@ proc processClient(client: PAsyncSocket, address: string, request.protocol = protocol.parseProtocol() except EInvalidValue: request.respond(Http400, "Invalid request protocol. Got: " & protocol) - return + runCallback = false + + if reqMethod.normalize == "post": + # Check for Expect header + if request.headers.hasKey("Expect"): + if request.headers["Expect"].toLower == "100-continue": + await client.sendStatus("100 Continue") + else: + await client.sendStatus("417 Expectation Failed") + # Read the body + # - Check for Content-length header + if request.headers.hasKey("Content-Length"): + var contentLength = 0 + if parseInt(request.headers["Content-Length"], contentLength) == 0: + await request.respond(Http400, "Bad Request. Invalid Content-Length.") + else: + request.body = await client.recv(contentLength) + assert request.body.len == contentLength + else: + await request.respond(Http400, "Bad Request. No Content-Length.") + runCallback = false + case reqMethod.normalize of "get", "post", "head", "put", "delete", "trace", "options", "connect", "patch": - await callback(request) + if runCallback: + await callback(request) else: - request.respond(Http400, "Invalid request method. Got: " & reqMethod) + await request.respond(Http400, "Invalid request method. Got: " & reqMethod) # Persistent connections if (request.protocol == HttpVer11 and |