diff options
Diffstat (limited to 'lib/pure/asynchttpserver.nim')
-rw-r--r-- | lib/pure/asynchttpserver.nim | 268 |
1 files changed, 142 insertions, 126 deletions
diff --git a/lib/pure/asynchttpserver.nim b/lib/pure/asynchttpserver.nim index 6d4b85145..ba1615651 100644 --- a/lib/pure/asynchttpserver.nim +++ b/lib/pure/asynchttpserver.nim @@ -58,15 +58,18 @@ type socket: AsyncSocket reuseAddr: bool reusePort: bool + maxBody: int ## The maximum content-length that will be read for the body. {.deprecated: [TRequest: Request, PAsyncHttpServer: AsyncHttpServer, THttpCode: HttpCode, THttpVersion: HttpVersion].} -proc newAsyncHttpServer*(reuseAddr = true, reusePort = false): AsyncHttpServer = +proc newAsyncHttpServer*(reuseAddr = true, reusePort = false, + maxBody = 8388608): AsyncHttpServer = ## Creates a new ``AsyncHttpServer`` instance. new result result.reuseAddr = reuseAddr result.reusePort = reusePort + result.maxBody = maxBody proc addHeaders(msg: var string, headers: HttpHeaders) = for k, v in headers: @@ -122,144 +125,157 @@ proc parseProtocol(protocol: string): tuple[orig: string, major, minor: int] = raise newException(ValueError, "Invalid request protocol. Got: " & protocol) result.orig = protocol - i.inc protocol.parseInt(result.major, i) + i.inc protocol.parseSaturatedNatural(result.major, i) i.inc # Skip . - i.inc protocol.parseInt(result.minor, i) + i.inc protocol.parseSaturatedNatural(result.minor, i) proc sendStatus(client: AsyncSocket, status: string): Future[void] = client.send("HTTP/1.1 " & status & "\c\L\c\L") -proc processClient(client: AsyncSocket, address: string, - callback: proc (request: Request): +proc processRequest(server: AsyncHttpServer, req: FutureVar[Request], + client: AsyncSocket, + address: string, lineFut: FutureVar[string], + callback: proc (request: Request): Future[void] {.closure, gcsafe.}) {.async.} = - var request: Request - request.url = initUri() - request.headers = newHttpHeaders() - var lineFut = newFutureVar[string]("asynchttpserver.processClient") - lineFut.mget() = newStringOfCap(80) - var key, value = "" - while not client.isClosed: - # GET /path HTTP/1.1 - # Header: val - # \n - request.headers.clear() - request.body = "" - request.hostname.shallowCopy(address) - assert client != nil - request.client = client - - # We should skip at least one empty line before the request - # https://tools.ietf.org/html/rfc7230#section-3.5 - for i in 0..1: - lineFut.mget().setLen(0) - lineFut.clean() - await client.recvLineInto(lineFut, maxLength=maxLine) # TODO: Timeouts. - - if lineFut.mget == "": - client.close() - return + # Alias `request` to `req.mget()` so we don't have to write `mget` everywhere. + template request(): Request = + req.mget() + + # GET /path HTTP/1.1 + # Header: val + # \n + request.headers.clear() + request.body = "" + request.hostname.shallowCopy(address) + assert client != nil + request.client = client + + # We should skip at least one empty line before the request + # https://tools.ietf.org/html/rfc7230#section-3.5 + for i in 0..1: + lineFut.mget().setLen(0) + lineFut.clean() + await client.recvLineInto(lineFut, maxLength=maxLine) # TODO: Timeouts. + + if lineFut.mget == "": + client.close() + return - if lineFut.mget.len > maxLine: - await request.respondError(Http413) - client.close() + if lineFut.mget.len > maxLine: + await request.respondError(Http413) + client.close() + return + if lineFut.mget != "\c\L": + break + + # First line - GET /path HTTP/1.1 + var i = 0 + for linePart in lineFut.mget.split(' '): + case i + of 0: + try: + # TODO: this is likely slow. + request.reqMethod = parseEnum[HttpMethod]("http" & linePart) + except ValueError: + asyncCheck request.respondError(Http400) return - if lineFut.mget != "\c\L": - break - - # First line - GET /path HTTP/1.1 - var i = 0 - for linePart in lineFut.mget.split(' '): - case i - of 0: - try: - # TODO: this is likely slow. - request.reqMethod = parseEnum[HttpMethod]("http" & linePart) - except ValueError: - asyncCheck request.respondError(Http400) - continue - of 1: - try: - parseUri(linePart, request.url) - except ValueError: - asyncCheck request.respondError(Http400) - continue - of 2: - try: - request.protocol = parseProtocol(linePart) - except ValueError: - asyncCheck request.respondError(Http400) - continue - else: - await request.respondError(Http400) - continue - inc i - - # Headers - while true: - i = 0 - lineFut.mget.setLen(0) - lineFut.clean() - await client.recvLineInto(lineFut, maxLength=maxLine) - - if lineFut.mget == "": - client.close(); return - if lineFut.mget.len > maxLine: - await request.respondError(Http413) - client.close(); return - if lineFut.mget == "\c\L": break - let (key, value) = parseHeader(lineFut.mget) - request.headers[key] = value - # Ensure the client isn't trying to DoS us. - if request.headers.len > headerLimit: - await client.sendStatus("400 Bad Request") - request.client.close() + of 1: + try: + parseUri(linePart, request.url) + except ValueError: + asyncCheck request.respondError(Http400) return + of 2: + try: + request.protocol = parseProtocol(linePart) + except ValueError: + asyncCheck request.respondError(Http400) + return + else: + await request.respondError(Http400) + return + inc i - if request.reqMethod == HttpPost: - # Check for Expect header - if request.headers.hasKey("Expect"): - if "100-continue" in request.headers["Expect"]: - await client.sendStatus("100 Continue") - else: - await client.sendStatus("417 Expectation Failed") - - # Read the body - # - Check for Content-length header - if request.headers.hasKey("Content-Length"): - var contentLength = 0 - if parseInt(request.headers["Content-Length"], - contentLength) == 0: - await request.respond(Http400, "Bad Request. Invalid Content-Length.") - continue - else: - request.body = await client.recv(contentLength) - if request.body.len != contentLength: - await request.respond(Http400, "Bad Request. Content-Length does not match actual.") - continue - elif request.reqMethod == HttpPost: - await request.respond(Http411, "Content-Length required.") - continue - - # Call the user's callback. - await callback(request) - - if "upgrade" in request.headers.getOrDefault("connection"): + # Headers + while true: + i = 0 + lineFut.mget.setLen(0) + lineFut.clean() + await client.recvLineInto(lineFut, maxLength=maxLine) + + if lineFut.mget == "": + client.close(); return + if lineFut.mget.len > maxLine: + await request.respondError(Http413) + client.close(); return + if lineFut.mget == "\c\L": break + let (key, value) = parseHeader(lineFut.mget) + request.headers[key] = value + # Ensure the client isn't trying to DoS us. + if request.headers.len > headerLimit: + await client.sendStatus("400 Bad Request") + request.client.close() return - # Persistent connections - if (request.protocol == HttpVer11 and - request.headers.getOrDefault("connection").normalize != "close") or - (request.protocol == HttpVer10 and - request.headers.getOrDefault("connection").normalize == "keep-alive"): - # In HTTP 1.1 we assume that connection is persistent. Unless connection - # header states otherwise. - # In HTTP 1.0 we assume that the connection should not be persistent. - # Unless the connection header states otherwise. - discard + if request.reqMethod == HttpPost: + # Check for Expect header + if request.headers.hasKey("Expect"): + if "100-continue" in request.headers["Expect"]: + await client.sendStatus("100 Continue") + else: + await client.sendStatus("417 Expectation Failed") + + # Read the body + # - Check for Content-length header + if request.headers.hasKey("Content-Length"): + var contentLength = 0 + if parseSaturatedNatural(request.headers["Content-Length"], contentLength) == 0: + await request.respond(Http400, "Bad Request. Invalid Content-Length.") + return else: - request.client.close() - break + if contentLength > server.maxBody: + await request.respondError(Http413) + return + request.body = await client.recv(contentLength) + if request.body.len != contentLength: + await request.respond(Http400, "Bad Request. Content-Length does not match actual.") + return + elif request.reqMethod == HttpPost: + await request.respond(Http411, "Content-Length required.") + return + + # Call the user's callback. + await callback(request) + + if "upgrade" in request.headers.getOrDefault("connection"): + return + + # Persistent connections + if (request.protocol == HttpVer11 and + cmpIgnoreCase(request.headers.getOrDefault("connection"), "close") != 0) or + (request.protocol == HttpVer10 and + cmpIgnoreCase(request.headers.getOrDefault("connection"), "keep-alive") == 0): + # In HTTP 1.1 we assume that connection is persistent. Unless connection + # header states otherwise. + # In HTTP 1.0 we assume that the connection should not be persistent. + # Unless the connection header states otherwise. + discard + else: + request.client.close() + return + +proc processClient(server: AsyncHttpServer, client: AsyncSocket, address: string, + callback: proc (request: Request): + Future[void] {.closure, gcsafe.}) {.async.} = + var request = newFutureVar[Request]("asynchttpserver.processClient") + request.mget().url = initUri() + request.mget().headers = newHttpHeaders() + var lineFut = newFutureVar[string]("asynchttpserver.processClient") + lineFut.mget() = newStringOfCap(80) + + while not client.isClosed: + await processRequest(server, request, client, address, lineFut, callback) proc serve*(server: AsyncHttpServer, port: Port, callback: proc (request: Request): Future[void] {.closure,gcsafe.}, @@ -280,7 +296,7 @@ proc serve*(server: AsyncHttpServer, port: Port, # TODO: Causes compiler crash. #var (address, client) = await server.socket.acceptAddr() var fut = await server.socket.acceptAddr() - asyncCheck processClient(fut.client, fut.address, callback) + asyncCheck processClient(server, fut.client, fut.address, callback) #echo(f.isNil) #echo(f.repr) |