summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorDominik Picheta <dominikpicheta@gmail.com>2016-09-18 19:08:12 +0200
committerDominik Picheta <dominikpicheta@gmail.com>2016-09-18 19:08:12 +0200
commit0c99523ad314433538df44c734cdf987c6c4337e (patch)
tree6cbd47d424b6dda0e55c32fe75cad8fbec73a1c3
parent3c47e70d53231802c0e0d03ae64371a48655103c (diff)
downloadNim-0c99523ad314433538df44c734cdf987c6c4337e.tar.gz
Implements timeouts for synchronous HttpClient.
-rw-r--r--lib/pure/httpclient.nim31
-rw-r--r--tests/stdlib/thttpclient.nim13
2 files changed, 35 insertions, 9 deletions
diff --git a/lib/pure/httpclient.nim b/lib/pure/httpclient.nim
index 5f8b77c5b..ae9378331 100644
--- a/lib/pure/httpclient.nim
+++ b/lib/pure/httpclient.nim
@@ -639,6 +639,7 @@ type
     headers*: HttpHeaders
     maxRedirects: int
     userAgent: string
+    timeout: int ## Only used for blocking HttpClient for now.
     when defined(ssl):
       sslContext: net.SslContext
 
@@ -646,7 +647,8 @@ type
   HttpClient* = HttpClientBase[Socket]
 
 proc newHttpClient*(userAgent = defUserAgent,
-    maxRedirects = 5, sslContext = defaultSslContext): HttpClient =
+    maxRedirects = 5, sslContext = defaultSslContext,
+    timeout = -1): HttpClient =
   ## Creates a new HttpClient instance.
   ##
   ## ``userAgent`` specifies the user agent that will be used when making
@@ -656,10 +658,14 @@ proc newHttpClient*(userAgent = defUserAgent,
   ## default is 5.
   ##
   ## ``sslContext`` specifies the SSL context to use for HTTPS requests.
+  ##
+  ## ``timeout`` specifies the number of miliseconds to allow before a
+  ## ``TimeoutError`` is raised.
   new result
   result.headers = newHttpHeaders()
   result.userAgent = userAgent
   result.maxRedirects = maxRedirects
+  result.timeout = timeout
   when defined(ssl):
     result.sslContext = sslContext
 
@@ -683,6 +689,7 @@ proc newAsyncHttpClient*(userAgent = defUserAgent,
   result.headers = newHttpHeaders()
   result.userAgent = userAgent
   result.maxRedirects = maxRedirects
+  result.timeout = -1 # TODO
   when defined(ssl):
     result.sslContext = sslContext
 
@@ -693,12 +700,15 @@ proc close*(client: HttpClient | AsyncHttpClient) =
     client.connected = false
 
 proc recvFull(socket: Socket | AsyncSocket,
-              size: int): Future[string] {.multisync.} =
+              size: int, timeout: int): Future[string] {.multisync.} =
   ## Ensures that all the data requested is read and returned.
   result = ""
   while true:
     if size == result.len: break
-    let data = await socket.recv(size - result.len)
+    when socket is Socket:
+      let data = socket.recv(size - result.len, timeout)
+    else:
+      let data = await socket.recv(size - result.len)
     if data == "": break # We've been disconnected.
     result.add data
 
@@ -729,10 +739,10 @@ proc parseChunks(client: HttpClient | AsyncHttpClient): Future[string]
         httpError("Invalid chunk size: " & chunkSizeStr)
       inc(i)
     if chunkSize <= 0:
-      discard await recvFull(client.socket, 2) # Skip \c\L
+      discard await recvFull(client.socket, 2, client.timeout) # Skip \c\L
       break
-    result.add await recvFull(client.socket, chunkSize)
-    discard await recvFull(client.socket, 2) # Skip \c\L
+    result.add await recvFull(client.socket, chunkSize, client.timeout)
+    discard await recvFull(client.socket, 2, client.timeout) # Skip \c\L
     # Trailer headers will only be sent if the request specifies that we want
     # them: http://tools.ietf.org/html/rfc2616#section-3.6.1
 
@@ -749,7 +759,7 @@ proc parseBody(client: HttpClient | AsyncHttpClient,
     if contentLengthHeader != "":
       var length = contentLengthHeader.parseint()
       if length > 0:
-        result = await client.socket.recvFull(length)
+        result = await client.socket.recvFull(length, client.timeout)
         if result == "":
           httpError("Got disconnected while trying to read body.")
         if result.len != length:
@@ -763,7 +773,7 @@ proc parseBody(client: HttpClient | AsyncHttpClient,
       if headers.getOrDefault"Connection" == "close" or httpVersion == "1.0":
         var buf = ""
         while true:
-          buf = await client.socket.recvFull(4000)
+          buf = await client.socket.recvFull(4000, client.timeout)
           if buf == "": break
           result.add(buf)
 
@@ -776,7 +786,10 @@ proc parseResponse(client: HttpClient | AsyncHttpClient,
   result.headers = newHttpHeaders()
   while true:
     linei = 0
-    line = await client.socket.recvLine()
+    when client is HttpClient:
+      line = await client.socket.recvLine(client.timeout)
+    else:
+      line = await client.socket.recvLine()
     if line == "": break # We've been disconnected.
     if line == "\c\L":
       fullyRead = true
diff --git a/tests/stdlib/thttpclient.nim b/tests/stdlib/thttpclient.nim
index ced39d9c9..b5daa963a 100644
--- a/tests/stdlib/thttpclient.nim
+++ b/tests/stdlib/thttpclient.nim
@@ -1,4 +1,5 @@
 import strutils
+from net import TimeoutError
 
 import httpclient, asyncdispatch
 
@@ -30,6 +31,18 @@ proc syncTest() =
   resp = client.request("https://google.com/")
   doAssert(resp.code.is2xx or resp.code.is3xx)
 
+  client.close()
+
+  # Timeout test.
+  client = newHttpClient(timeout = 1)
+  try:
+    resp = client.request("http://example.com/")
+    doAssert false, "TimeoutError should have been raised."
+  except TimeoutError:
+    discard
+  except:
+    doAssert false, "TimeoutError should have been raised."
+
 syncTest()
 
 waitFor(asyncTest())