summary refs log tree commit diff stats
path: root/lib
diff options
context:
space:
mode:
authorDominik Picheta <dominikpicheta@googlemail.com>2012-12-22 23:03:28 +0000
committerDominik Picheta <dominikpicheta@googlemail.com>2012-12-22 23:03:28 +0000
commit6cb8edfce94c9236dfd8a30380d7474a8ede6f87 (patch)
tree1824f9a08690210eda433c01f98201a003c63b6d /lib
parentb6c8e16b0f23b09fc4c35e3c2542b75c32152e62 (diff)
downloadNim-6cb8edfce94c9236dfd8a30380d7474a8ede6f87.tar.gz
recvLine now works with unbuffered ssl sockets.
Added higher level recv functions.
Diffstat (limited to 'lib')
-rwxr-xr-xlib/pure/sockets.nim77
1 files changed, 44 insertions, 33 deletions
diff --git a/lib/pure/sockets.nim b/lib/pure/sockets.nim
index 5f7ba6ac2..5ec831bcc 100755
--- a/lib/pure/sockets.nim
+++ b/lib/pure/sockets.nim
@@ -62,6 +62,8 @@ type
         sslHandle: PSSL
         sslContext: PSSLContext
         sslNoHandshake: bool # True if needs handshake.
+        sslHasPeekChar: bool
+        sslPeekChar: char
       of false: nil
   
   TSocket* = ref TSocketImpl
@@ -291,6 +293,7 @@ when defined(ssl):
     socket.sslContext = ctx
     socket.sslHandle = SSLNew(PSSLCTX(socket.sslContext))
     socket.sslNoHandshake = false
+    socket.sslHasPeekChar = false
     if socket.sslHandle == nil:
       SSLError()
     
@@ -849,11 +852,8 @@ proc checkBuffer(readfds: var seq[TSocket]): int =
   var res: seq[TSocket] = @[]
   result = 0
   for s in readfds:
-    if s.isBuffered:
-      if s.bufLen <= 0 or s.currPos == s.bufLen:
-        res.add(s)
-      else:
-        inc(result)
+    if hasDataBuffered(s):
+      inc(result)
     else:
       res.add(s)
   readfds = res
@@ -975,42 +975,46 @@ template retRead(flags, readBytes: int) =
 
 proc recv*(socket: TSocket, data: pointer, size: int): int {.tags: [FReadIO].} =
   ## receives data from a socket
+  if size == 0: return
   if socket.isBuffered:
     if socket.bufLen == 0:
       retRead(0'i32, 0)
     
-    when true:
-      var read = 0
-      while read < size:
-        if socket.currPos >= socket.bufLen:
-          retRead(0'i32, read)
-      
-        let chunk = min(socket.bufLen-socket.currPos, size-read)
-        var d = cast[cstring](data)
-        copyMem(addr(d[read]), addr(socket.buffer[socket.currPos]), chunk)
-        read.inc(chunk)
-        socket.currPos.inc(chunk)
-    else:
-      var read = 0
-      while read < size:
-        if socket.currPos >= socket.bufLen:
-          retRead(0'i32, read)
-      
-        var d = cast[cstring](data)
-        d[read] = socket.buffer[socket.currPos]
-        read.inc(1)
-        socket.currPos.inc(1)
+    var read = 0
+    while read < size:
+      if socket.currPos >= socket.bufLen:
+        retRead(0'i32, read)
     
+      let chunk = min(socket.bufLen-socket.currPos, size-read)
+      var d = cast[cstring](data)
+      copyMem(addr(d[read]), addr(socket.buffer[socket.currPos]), chunk)
+      read.inc(chunk)
+      socket.currPos.inc(chunk)
+
     result = read
   else:
     when defined(ssl):
       if socket.isSSL:
-        result = SSLRead(socket.sslHandle, data, size)
+        if socket.sslHasPeekChar:
+          copyMem(data, addr(socket.sslPeekChar), 1)
+          socket.sslHasPeekChar = false
+          if size-1 > 0:
+            var d = cast[cstring](data)
+            result = SSLRead(socket.sslHandle, addr(d[1]), size-1) + 1
+          else:
+            result = 1
+        else:
+          result = SSLRead(socket.sslHandle, data, size)
       else:
         result = recv(socket.fd, data, size.cint, 0'i32)
     else:
       result = recv(socket.fd, data, size.cint, 0'i32)
 
+proc recv*(socket: TSocket, data: var string, size: int): int =
+  ## higher-level version of the above
+  data.setLen(size)
+  result = recv(socket, cstring(data), size)
+
 proc waitFor(socket: TSocket, waited: var float, timeout: int): int {.
   tags: [FTime].} =
   ## returns the number of characters available to be read. In unbuffered
@@ -1045,6 +1049,11 @@ proc recv*(socket: TSocket, data: pointer, size: int, timeout: int): int {.
   
   result = read
 
+proc recv*(socket: TSocket, data: var string, size: int, timeout: int): int =
+  # higher-level version of the above
+  data.setLen(size)
+  result = recv(socket, cstring(data), size, timeout)
+
 proc peekChar(socket: TSocket, c: var char): int {.tags: [FReadIO].} =
   if socket.isBuffered:
     result = 1
@@ -1057,8 +1066,12 @@ proc peekChar(socket: TSocket, c: var char): int {.tags: [FReadIO].} =
   else:
     when defined(ssl):
       if socket.isSSL:
-        raise newException(ESSL, "Sorry, you cannot use recvLine on an unbuffered SSL socket.")
-  
+        if not socket.sslHasPeekChar:
+          result = SSLRead(socket.sslHandle, addr(socket.sslPeekChar), 1)
+          socket.sslHasPeekChar = true
+        
+        c = socket.sslPeekChar
+        return
     result = recv(socket.fd, addr(c), 1, MSG_PEEK)
 
 proc recvLine*(socket: TSocket, line: var TaintedString): bool {.
@@ -1068,13 +1081,11 @@ proc recvLine*(socket: TSocket, line: var TaintedString): bool {.
   ## will be set to it.
   ## 
   ## ``True`` is returned if data is available. ``False`` usually suggests an
-  ## error, EOS exceptions are not raised in favour of this.
+  ## error, EOS exceptions are not raised and ``False`` is simply returned
+  ## instead.
   ## 
   ## If the socket is disconnected, ``line`` will be set to ``""`` and ``True``
   ## will be returned.
-  ##
-  ## **Warning:** Using this function on a unbuffered ssl socket will result
-  ## in an error.
   template addNLIfEmpty(): stmt =
     if line.len == 0:
       line.add("\c\L")