summary refs log tree commit diff stats
path: root/lib/pure/httpclient.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/httpclient.nim')
-rw-r--r--lib/pure/httpclient.nim191
1 files changed, 139 insertions, 52 deletions
diff --git a/lib/pure/httpclient.nim b/lib/pure/httpclient.nim
index 1ded540ec..e88847004 100644
--- a/lib/pure/httpclient.nim
+++ b/lib/pure/httpclient.nim
@@ -84,6 +84,9 @@
 ## .. code-block:: Nim
 ##   client.onProgressChanged = nil
 ##
+## **Warning:** The ``total`` reported by httpclient may be 0 in some cases.
+##
+##
 ## SSL/TLS support
 ## ===============
 ## This requires the OpenSSL library, fortunately it's widely used and installed
@@ -117,20 +120,28 @@
 ## only basic authentication is supported at the moment.
 
 import net, strutils, uri, parseutils, strtabs, base64, os, mimetypes,
-  math, random, httpcore, times, tables
-import asyncnet, asyncdispatch
+  math, random, httpcore, times, tables, streams
+import asyncnet, asyncdispatch, asyncfile
 import nativesockets
 
 export httpcore except parseHeader # TODO: The ``except`` doesn't work
 
 type
-  Response* = object
+  Response* = ref object
+    version*: string
+    status*: string
+    headers*: HttpHeaders
+    body: string
+    bodyStream*: Stream
+
+  AsyncResponse* = ref object
     version*: string
     status*: string
     headers*: HttpHeaders
-    body*: string
+    body: string
+    bodyStream*: FutureStream[string]
 
-proc code*(response: Response): HttpCode
+proc code*(response: Response | AsyncResponse): HttpCode
            {.raises: [ValueError, OverflowError].} =
   ## Retrieves the specified response's ``HttpCode``.
   ##
@@ -138,6 +149,27 @@ proc code*(response: Response): HttpCode
   ## corresponding ``HttpCode``.
   return response.status[0 .. 2].parseInt.HttpCode
 
+proc body*(response: Response): string =
+  ## Retrieves the specified response's body.
+  ##
+  ## The response's body stream is read synchronously.
+  if response.body.isNil():
+    response.body = response.bodyStream.readAll()
+  return response.body
+
+proc `body=`*(response: Response, value: string) {.deprecated.} =
+  ## Setter for backward compatibility.
+  ##
+  ## **This is deprecated and should not be used**.
+  response.body = value
+
+proc body*(response: AsyncResponse): Future[string] {.async.} =
+  ## Reads the response's body and caches it. The read is performed only
+  ## once.
+  if response.body.isNil:
+    response.body = await readAll(response.bodyStream)
+  return response.body
+
 type
   Proxy* = ref object
     url*: Uri
@@ -249,6 +281,7 @@ proc parseBody(s: Socket, headers: HttpHeaders, httpVersion: string, timeout: in
           result.add(buf)
 
 proc parseResponse(s: Socket, getBody: bool, timeout: int): Response =
+  new result
   var parsedStatus = false
   var linei = 0
   var fullyRead = false
@@ -604,7 +637,7 @@ proc post*(url: string, extraHeaders = "", body = "",
   ## **Deprecated since version 0.15.0**: use ``HttpClient.post`` instead.
   let (mpHeaders, mpBody) = format(multipart)
 
-  template withNewLine(x): expr =
+  template withNewLine(x): untyped =
     if x.len > 0 and not x.endsWith("\c\L"):
       x & "\c\L"
     else:
@@ -653,10 +686,13 @@ proc postContent*(url: string, extraHeaders = "", body = "",
 proc downloadFile*(url: string, outputFilename: string,
                    sslContext: SSLContext = defaultSSLContext,
                    timeout = -1, userAgent = defUserAgent,
-                   proxy: Proxy = nil) =
+                   proxy: Proxy = nil) {.deprecated.} =
   ## | Downloads ``url`` and saves it to ``outputFilename``
   ## | An optional timeout can be specified in milliseconds, if reading from the
   ## server takes longer than specified an ETimeout exception will be raised.
+  ##
+  ## **Deprecated since version 0.16.2**: use ``HttpClient.downloadFile``
+  ## instead.
   var f: File
   if open(f, outputFilename, fmWrite):
     f.write(getContent(url, sslContext = sslContext, timeout = timeout,
@@ -735,6 +771,11 @@ type
     contentProgress: BiggestInt
     oneSecondProgress: BiggestInt
     lastProgressReport: float
+    when SocketType is AsyncSocket:
+      bodyStream: FutureStream[string]
+    else:
+      bodyStream: Stream
+    getBody: bool ## When `false`, the body is never read in requestAux.
 
 type
   HttpClient* = HttpClientBase[Socket]
@@ -764,6 +805,8 @@ proc newHttpClient*(userAgent = defUserAgent,
   result.proxy = proxy
   result.timeout = timeout
   result.onProgressChanged = nil
+  result.bodyStream = newStringStream()
+  result.getBody = true
   when defined(ssl):
     result.sslContext = sslContext
 
@@ -794,6 +837,8 @@ proc newAsyncHttpClient*(userAgent = defUserAgent,
   result.proxy = proxy
   result.timeout = -1 # TODO
   result.onProgressChanged = nil
+  result.bodyStream = newFutureStream[string]("newAsyncHttpClient")
+  result.getBody = true
   when defined(ssl):
     result.sslContext = sslContext
 
@@ -815,14 +860,14 @@ proc reportProgress(client: HttpClient | AsyncHttpClient,
       client.oneSecondProgress = 0
       client.lastProgressReport = epochTime()
 
-proc recvFull(client: HttpClient | AsyncHttpClient,
-              size: int, timeout: int): Future[string] {.multisync.} =
+proc recvFull(client: HttpClient | AsyncHttpClient, size: int, timeout: int,
+              keep: bool): Future[int] {.multisync.} =
   ## Ensures that all the data requested is read and returned.
-  result = ""
+  var readLen = 0
   while true:
-    if size == result.len: break
+    if size == readLen: break
 
-    let remainingSize = size - result.len
+    let remainingSize = size - readLen
     let sizeToRecv = min(remainingSize, net.BufferSize)
 
     when client.socket is Socket:
@@ -830,13 +875,17 @@ proc recvFull(client: HttpClient | AsyncHttpClient,
     else:
       let data = await client.socket.recv(sizeToRecv)
     if data == "": break # We've been disconnected.
-    result.add data
+
+    readLen.inc(data.len)
+    if keep:
+      await client.bodyStream.write(data)
 
     await reportProgress(client, data.len)
 
-proc parseChunks(client: HttpClient | AsyncHttpClient): Future[string]
+  return readLen
+
+proc parseChunks(client: HttpClient | AsyncHttpClient): Future[void]
                  {.multisync.} =
-  result = ""
   while true:
     var chunkSize = 0
     var chunkSizeStr = await client.socket.recvLine()
@@ -861,25 +910,27 @@ proc parseChunks(client: HttpClient | AsyncHttpClient): Future[string]
         httpError("Invalid chunk size: " & chunkSizeStr)
       inc(i)
     if chunkSize <= 0:
-      discard await recvFull(client, 2, client.timeout) # Skip \c\L
+      discard await recvFull(client, 2, client.timeout, false) # Skip \c\L
       break
-    result.add await recvFull(client, chunkSize, client.timeout)
-    discard await recvFull(client, 2, client.timeout) # Skip \c\L
+    discard await recvFull(client, chunkSize, client.timeout, true)
+    discard await recvFull(client, 2, client.timeout, false) # Skip \c\L
     # Trailer headers will only be sent if the request specifies that we want
     # them: http://tools.ietf.org/html/rfc2616#section-3.6.1
 
 proc parseBody(client: HttpClient | AsyncHttpClient,
                headers: HttpHeaders,
-               httpVersion: string): Future[string] {.multisync.} =
-  result = ""
+               httpVersion: string): Future[void] {.multisync.} =
   # Reset progress from previous requests.
   client.contentTotal = 0
   client.contentProgress = 0
   client.oneSecondProgress = 0
   client.lastProgressReport = 0
 
+  when client is AsyncHttpClient:
+    assert(not client.bodyStream.finished)
+
   if headers.getOrDefault"Transfer-Encoding" == "chunked":
-    result = await parseChunks(client)
+    await parseChunks(client)
   else:
     # -REGION- Content-Length
     # (http://tools.ietf.org/html/rfc2616#section-4.4) NR.3
@@ -888,26 +939,31 @@ proc parseBody(client: HttpClient | AsyncHttpClient,
       var length = contentLengthHeader.parseint()
       client.contentTotal = length
       if length > 0:
-        result = await client.recvFull(length, client.timeout)
-        if result == "":
+        let recvLen = await client.recvFull(length, client.timeout, true)
+        if recvLen == 0:
           httpError("Got disconnected while trying to read body.")
-        if result.len != length:
+        if recvLen != length:
           httpError("Received length doesn't match expected length. Wanted " &
-                    $length & " got " & $result.len)
+                    $length & " got " & $recvLen)
     else:
       # (http://tools.ietf.org/html/rfc2616#section-4.4) NR.4 TODO
 
       # -REGION- Connection: Close
       # (http://tools.ietf.org/html/rfc2616#section-4.4) NR.5
       if headers.getOrDefault"Connection" == "close" or httpVersion == "1.0":
-        var buf = ""
         while true:
-          buf = await client.recvFull(4000, client.timeout)
-          if buf == "": break
-          result.add(buf)
+          let recvLen = await client.recvFull(4000, client.timeout, true)
+          if recvLen == 0: break
+
+  when client is AsyncHttpClient:
+    client.bodyStream.complete()
+  else:
+    client.bodyStream.setPosition(0)
 
 proc parseResponse(client: HttpClient | AsyncHttpClient,
-                   getBody: bool): Future[Response] {.multisync.} =
+                   getBody: bool): Future[Response | AsyncResponse]
+                   {.multisync.} =
+  new result
   var parsedStatus = false
   var linei = 0
   var fullyRead = false
@@ -955,10 +1011,14 @@ proc parseResponse(client: HttpClient | AsyncHttpClient,
 
   if not fullyRead:
     httpError("Connection was closed before full request has been made")
+
   if getBody:
-    result.body = await parseBody(client, result.headers, result.version)
-  else:
-    result.body = ""
+    when client is HttpClient:
+      client.bodyStream = newStringStream()
+    else:
+      client.bodyStream = newFutureStream[string]("parseResponse")
+    await parseBody(client, result.headers, result.version)
+    result.bodyStream = client.bodyStream
 
 proc newConnection(client: HttpClient | AsyncHttpClient,
                    url: Uri) {.multisync.} =
@@ -1006,8 +1066,9 @@ proc override(fallback, override: HttpHeaders): HttpHeaders =
     result[k] = vs
 
 proc requestAux(client: HttpClient | AsyncHttpClient, url: string,
-              httpMethod: string, body = "",
-              headers: HttpHeaders = nil): Future[Response] {.multisync.} =
+                httpMethod: string, body = "",
+                headers: HttpHeaders = nil): Future[Response | AsyncResponse]
+                {.multisync.} =
   # Helper that actually makes the request. Does not handle redirects.
   let connectionUrl =
     if client.proxy.isNil: parseUri(url) else: client.proxy.url
@@ -1047,16 +1108,17 @@ proc requestAux(client: HttpClient | AsyncHttpClient, url: string,
   if body != "":
     await client.socket.send(body)
 
-  result = await parseResponse(client,
-                               httpMethod.toLower() notin ["head", "connect"])
+  let getBody = httpMethod.toLowerAscii() notin ["head", "connect"] and
+                client.getBody
+  result = await parseResponse(client, getBody)
 
   # Restore the clients proxy in case it was overwritten.
   client.proxy = savedProxy
 
-
 proc request*(client: HttpClient | AsyncHttpClient, url: string,
               httpMethod: string, body = "",
-              headers: HttpHeaders = nil): Future[Response] {.multisync.} =
+              headers: HttpHeaders = nil): Future[Response | AsyncResponse]
+              {.multisync.} =
   ## Connects to the hostname specified by the URL and performs a request
   ## using the custom method string specified by ``httpMethod``.
   ##
@@ -1078,7 +1140,8 @@ proc request*(client: HttpClient | AsyncHttpClient, url: string,
 
 proc request*(client: HttpClient | AsyncHttpClient, url: string,
               httpMethod = HttpGET, body = "",
-              headers: HttpHeaders = nil): Future[Response] {.multisync.} =
+              headers: HttpHeaders = nil): Future[Response | AsyncResponse]
+              {.multisync.} =
   ## Connects to the hostname specified by the URL and performs a request
   ## using the method specified.
   ##
@@ -1088,11 +1151,10 @@ proc request*(client: HttpClient | AsyncHttpClient, url: string,
   ##
   ## When a request is made to a different hostname, the current connection will
   ## be closed.
-  result = await request(client, url, $httpMethod, body,
-                         headers = headers)
+  result = await request(client, url, $httpMethod, body, headers)
 
 proc get*(client: HttpClient | AsyncHttpClient,
-          url: string): Future[Response] {.multisync.} =
+          url: string): Future[Response | AsyncResponse] {.multisync.} =
   ## Connects to the hostname specified by the URL and performs a GET request.
   ##
   ## This procedure will follow redirects up to a maximum number of redirects
@@ -1112,17 +1174,18 @@ proc getContent*(client: HttpClient | AsyncHttpClient,
   if resp.code.is4xx or resp.code.is5xx:
     raise newException(HttpRequestError, resp.status)
   else:
-    return resp.body
+    return await resp.bodyStream.readAll()
 
 proc post*(client: HttpClient | AsyncHttpClient, url: string, body = "",
-           multipart: MultipartData = nil): Future[Response] {.multisync.} =
+           multipart: MultipartData = nil): Future[Response | AsyncResponse]
+           {.multisync.} =
   ## Connects to the hostname specified by the URL and performs a POST request.
   ##
   ## This procedure will follow redirects up to a maximum number of redirects
   ## specified in ``client.maxRedirects``.
   let (mpHeader, mpBody) = format(multipart)
-
-  template withNewLine(x): expr =
+  # TODO: Support FutureStream for `body` parameter.
+  template withNewLine(x): untyped =
     if x.len > 0 and not x.endsWith("\c\L"):
       x & "\c\L"
     else:
@@ -1134,16 +1197,14 @@ proc post*(client: HttpClient | AsyncHttpClient, url: string, body = "",
     headers["Content-Type"] = mpHeader.split(": ")[1]
   headers["Content-Length"] = $len(xb)
 
-  result = await client.requestAux(url, $HttpPOST, xb,
-                                headers = headers)
+  result = await client.requestAux(url, $HttpPOST, xb, headers)
   # Handle redirects.
   var lastURL = url
   for i in 1..client.maxRedirects:
     if result.status.redirection():
       let redirectTo = getNewLocation(lastURL, result.headers)
       var meth = if result.status != "307": HttpGet else: HttpPost
-      result = await client.requestAux(redirectTo, $meth, xb,
-                                    headers = headers)
+      result = await client.requestAux(redirectTo, $meth, xb, headers)
       lastURL = redirectTo
 
 proc postContent*(client: HttpClient | AsyncHttpClient, url: string,
@@ -1161,4 +1222,30 @@ proc postContent*(client: HttpClient | AsyncHttpClient, url: string,
   if resp.code.is4xx or resp.code.is5xx:
     raise newException(HttpRequestError, resp.status)
   else:
-    return resp.body
+    return await resp.bodyStream.readAll()
+
+proc downloadFile*(client: HttpClient | AsyncHttpClient,
+                   url: string, filename: string): Future[void] {.multisync.} =
+  ## Downloads ``url`` and saves it to ``filename``.
+  client.getBody = false
+  let resp = await client.get(url)
+
+  when client is HttpClient:
+    client.bodyStream = newFileStream(filename, fmWrite)
+    if client.bodyStream.isNil:
+      fileError("Unable to open file")
+    parseBody(client, resp.headers, resp.version)
+    client.bodyStream.close()
+  else:
+    client.bodyStream = newFutureStream[string]("downloadFile")
+    var file = openAsync(filename, fmWrite)
+    # Let `parseBody` write response data into client.bodyStream in the
+    # background.
+    asyncCheck parseBody(client, resp.headers, resp.version)
+    # The `writeFromStream` proc will complete once all the data in the
+    # `bodyStream` has been written to the file.
+    await file.writeFromStream(client.bodyStream)
+    file.close()
+
+  if resp.code.is4xx or resp.code.is5xx:
+    raise newException(HttpRequestError, resp.status)