summary refs log tree commit diff stats
path: root/lib
diff options
context:
space:
mode:
authorDominik Picheta <dominikpicheta@gmail.com>2016-09-24 22:36:19 +0200
committerDominik Picheta <dominikpicheta@gmail.com>2016-09-24 22:58:10 +0200
commitcff6ec2155bc1556d891118e3915ad72a5d3ee93 (patch)
treee99907024fb52ebbb5f00239e1b5890244623c59 /lib
parent9b810b17a20877d6229f1e9b069ec0478d234892 (diff)
downloadNim-cff6ec2155bc1556d891118e3915ad72a5d3ee93.tar.gz
Implements onProgressChanged callback for httpclient.
Diffstat (limited to 'lib')
-rw-r--r--lib/pure/httpclient.nim62
1 files changed, 51 insertions, 11 deletions
diff --git a/lib/pure/httpclient.nim b/lib/pure/httpclient.nim
index 9ac693581..0f6ee5a09 100644
--- a/lib/pure/httpclient.nim
+++ b/lib/pure/httpclient.nim
@@ -80,7 +80,7 @@
 ## currently only basic authentication is supported.
 
 import net, strutils, uri, parseutils, strtabs, base64, os, mimetypes,
-  math, random, httpcore
+  math, random, httpcore, times
 import asyncnet, asyncdispatch
 import nativesockets
 
@@ -669,17 +669,30 @@ proc generateHeaders(requestUrl: Uri, httpMethod: string,
   add(result, "\c\L")
 
 type
+  ProgressChangedProc*[ReturnType] =
+    proc (total, progress, speed: BiggestInt):
+      ReturnType {.closure, gcsafe.}
+
   HttpClientBase*[SocketType] = ref object
     socket: SocketType
     connected: bool
     currentURL: Uri ## Where we are currently connected.
-    headers*: HttpHeaders
+    headers*: HttpHeaders ## Headers to send in requests.
     maxRedirects: int
     userAgent: string
     timeout: int ## Only used for blocking HttpClient for now.
     proxy: Proxy
+    ## ``nil`` or the callback to call when request progress changes.
+    when SocketType is Socket:
+      onProgressChanged*: ProgressChangedProc[void]
+    else:
+      onProgressChanged*: ProgressChangedProc[Future[void]]
     when defined(ssl):
       sslContext: net.SslContext
+    contentTotal: BiggestInt
+    contentProgress: BiggestInt
+    oneSecondProgress: BiggestInt
+    lastProgressReport: float
 
 type
   HttpClient* = HttpClientBase[Socket]
@@ -708,6 +721,7 @@ proc newHttpClient*(userAgent = defUserAgent,
   result.maxRedirects = maxRedirects
   result.proxy = proxy
   result.timeout = timeout
+  result.onProgressChanged = nil
   when defined(ssl):
     result.sslContext = sslContext
 
@@ -737,6 +751,7 @@ proc newAsyncHttpClient*(userAgent = defUserAgent,
   result.maxRedirects = maxRedirects
   result.proxy = proxy
   result.timeout = -1 # TODO
+  result.onProgressChanged = nil
   when defined(ssl):
     result.sslContext = sslContext
 
@@ -746,19 +761,37 @@ proc close*(client: HttpClient | AsyncHttpClient) =
     client.socket.close()
     client.connected = false
 
-proc recvFull(socket: Socket | AsyncSocket,
+proc reportProgress(client: HttpClient | AsyncHttpClient,
+                    progress: BiggestInt) {.multisync.} =
+  client.contentProgress += progress
+  client.oneSecondProgress += progress
+  if epochTime() - client.lastProgressReport >= 1.0:
+    if not client.onProgressChanged.isNil:
+      await client.onProgressChanged(client.contentTotal,
+                                     client.contentProgress,
+                                     client.oneSecondProgress)
+      client.oneSecondProgress = 0
+      client.lastProgressReport = epochTime()
+
+proc recvFull(client: HttpClient | AsyncHttpClient,
               size: int, timeout: int): Future[string] {.multisync.} =
   ## Ensures that all the data requested is read and returned.
   result = ""
   while true:
     if size == result.len: break
-    when socket is Socket:
-      let data = socket.recv(size - result.len, timeout)
+
+    let remainingSize = size - result.len
+    let sizeToRecv = min(remainingSize, net.BufferSize)
+
+    when client.socket is Socket:
+      let data = client.socket.recv(sizeToRecv, timeout)
     else:
-      let data = await socket.recv(size - result.len)
+      let data = await client.socket.recv(sizeToRecv)
     if data == "": break # We've been disconnected.
     result.add data
 
+    await reportProgress(client, data.len)
+
 proc parseChunks(client: HttpClient | AsyncHttpClient): Future[string]
                  {.multisync.} =
   result = ""
@@ -786,10 +819,10 @@ proc parseChunks(client: HttpClient | AsyncHttpClient): Future[string]
         httpError("Invalid chunk size: " & chunkSizeStr)
       inc(i)
     if chunkSize <= 0:
-      discard await recvFull(client.socket, 2, client.timeout) # Skip \c\L
+      discard await recvFull(client, 2, client.timeout) # Skip \c\L
       break
-    result.add await recvFull(client.socket, chunkSize, client.timeout)
-    discard await recvFull(client.socket, 2, client.timeout) # Skip \c\L
+    result.add await recvFull(client, chunkSize, client.timeout)
+    discard await recvFull(client, 2, client.timeout) # Skip \c\L
     # Trailer headers will only be sent if the request specifies that we want
     # them: http://tools.ietf.org/html/rfc2616#section-3.6.1
 
@@ -797,6 +830,12 @@ proc parseBody(client: HttpClient | AsyncHttpClient,
                headers: HttpHeaders,
                httpVersion: string): Future[string] {.multisync.} =
   result = ""
+  # Reset progress from previous requests.
+  client.contentTotal = 0
+  client.contentProgress = 0
+  client.oneSecondProgress = 0
+  client.lastProgressReport = 0
+
   if headers.getOrDefault"Transfer-Encoding" == "chunked":
     result = await parseChunks(client)
   else:
@@ -805,8 +844,9 @@ proc parseBody(client: HttpClient | AsyncHttpClient,
     var contentLengthHeader = headers.getOrDefault"Content-Length"
     if contentLengthHeader != "":
       var length = contentLengthHeader.parseint()
+      client.contentTotal = length
       if length > 0:
-        result = await client.socket.recvFull(length, client.timeout)
+        result = await client.recvFull(length, client.timeout)
         if result == "":
           httpError("Got disconnected while trying to read body.")
         if result.len != length:
@@ -820,7 +860,7 @@ proc parseBody(client: HttpClient | AsyncHttpClient,
       if headers.getOrDefault"Connection" == "close" or httpVersion == "1.0":
         var buf = ""
         while true:
-          buf = await client.socket.recvFull(4000, client.timeout)
+          buf = await client.recvFull(4000, client.timeout)
           if buf == "": break
           result.add(buf)