summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--lib/pure/asynchttpserver.nim27
-rw-r--r--tests/stdlib/tasynchttpserver_transferencoding.nim81
2 files changed, 105 insertions, 3 deletions
diff --git a/lib/pure/asynchttpserver.nim b/lib/pure/asynchttpserver.nim
index 86688d4b5..df4f97813 100644
--- a/lib/pure/asynchttpserver.nim
+++ b/lib/pure/asynchttpserver.nim
@@ -43,6 +43,7 @@ runnableExamples:
 
 import asyncnet, asyncdispatch, parseutils, uri, strutils
 import httpcore
+import std/private/since
 
 export httpcore except parseHeader
 
@@ -71,6 +72,22 @@ type
     maxBody: int ## The maximum content-length that will be read for the body.
     maxFDs: int
 
+func getSocket*(a: AsyncHttpServer): AsyncSocket {.since: (1, 5, 1).} =
+  ## Returns the ``AsyncHttpServer``s internal ``AsyncSocket`` instance.
+  ## 
+  ## Useful for identifying what port the AsyncHttpServer is bound to, if it
+  ## was chosen automatically.
+  runnableExamples:
+    from asyncdispatch import Port
+    from asyncnet import getFd
+    from nativesockets import getLocalAddr, AF_INET
+    let server = newAsyncHttpServer()
+    server.listen(Port(0)) # Socket is not bound until this point
+    let port = getLocalAddr(server.getSocket.getFd, AF_INET)[1]
+    doAssert uint16(port) > 0
+    server.close()
+  a.socket
+
 proc newAsyncHttpServer*(reuseAddr = true, reusePort = false,
                          maxBody = 8388608): AsyncHttpServer =
   ## Creates a new ``AsyncHttpServer`` instance.
@@ -300,9 +317,13 @@ proc processRequest(
           break
 
         # Read bytesToRead and add to body
-        # Note we add +2 because the line must be terminated by \r\n
-        let chunk = await client.recv(bytesToRead + 2)
-        request.body = request.body & chunk
+        let chunk = await client.recv(bytesToRead)
+        request.body.add(chunk)
+        # Skip \r\n (chunk terminating bytes per spec)
+        let separator = await client.recv(2)
+        if separator != "\r\n":
+          await request.respond(Http400, "Bad Request. Encoding separator must be \\r\\n")
+          return true
 
       inc sizeOrData
   elif request.reqMethod == HttpPost:
diff --git a/tests/stdlib/tasynchttpserver_transferencoding.nim b/tests/stdlib/tasynchttpserver_transferencoding.nim
new file mode 100644
index 000000000..34f3cef11
--- /dev/null
+++ b/tests/stdlib/tasynchttpserver_transferencoding.nim
@@ -0,0 +1,81 @@
+import httpclient, asynchttpserver, asyncdispatch, asyncfutures
+import net
+
+import std/asyncnet
+import std/nativesockets
+
+const postBegin = """
+POST / HTTP/1.1
+Transfer-Encoding:chunked
+
+"""
+
+template genTest(input, expected) =
+  var sanity = false
+  proc handler(request: Request) {.async.} =
+      doAssert(request.body == expected)
+      doAssert(request.headers.hasKey("Transfer-Encoding"))
+      doAssert(not request.headers.hasKey("Content-Length"))
+      sanity = true
+      await request.respond(Http200, "Good")
+
+  proc runSleepLoop(server: AsyncHttpServer) {.async.} = 
+    server.listen(Port(0))
+    proc wrapper() = 
+      waitFor server.acceptRequest(handler)
+    asyncdispatch.callSoon wrapper
+
+  let server = newAsyncHttpServer()
+  waitFor runSleepLoop(server)
+  let port = getLocalAddr(server.getSocket.getFd, AF_INET)[1]
+  let data = postBegin & input
+  var socket = newSocket()
+  socket.connect("127.0.0.1", port)
+  socket.send(data)
+  waitFor sleepAsync(10)
+  socket.close()
+  server.close()
+
+  # Verify we ran the handler and its asserts
+  doAssert(sanity)
+
+block:
+  const expected = "hello=world"
+  const input = ("b\r\n" &
+                 "hello=world\r\n" &
+                 "0\r\n" &
+                 "\r\n")
+  genTest(input, expected)
+block:
+  const expected = "hello encoding"
+  const input = ("e\r\n" &
+                 "hello encoding\r\n" &
+                 "0\r\n" &
+                 "\r\n")
+  genTest(input, expected)
+block:
+  # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Transfer-Encoding
+  const expected = "MozillaDeveloperNetwork"
+  const input = ("7\r\n" &
+                "Mozilla\r\n" &
+                "9\r\n" &
+                "Developer\r\n" &
+                "7\r\n" &
+                "Network\r\n" &
+                "0\r\n" &
+                "\r\n")
+  genTest(input, expected)
+block:
+  # https://en.wikipedia.org/wiki/Chunked_transfer_encoding#Example
+  const expected = "Wikipedia in \r\n\r\nchunks."
+  const input = ("4\r\n" &
+                "Wiki\r\n" &
+                "6\r\n" &
+                "pedia \r\n" &
+                "E\r\n" &
+                "in \r\n" &
+                "\r\n" &
+                "chunks.\r\n" &
+                "0\r\n" &
+                "\r\n")
+  genTest(input, expected)