summary refs log tree commit diff stats
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/pure/asyncnet.nim2
-rw-r--r--lib/pure/net.nim35
2 files changed, 26 insertions, 11 deletions
diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim
index 5be457d2a..93399bb40 100644
--- a/lib/pure/asyncnet.nim
+++ b/lib/pure/asyncnet.nim
@@ -277,6 +277,7 @@ template readInto(buf: pointer, size: int, socket: AsyncSocket,
                   flags: set[SocketFlag]): int =
   ## Reads **up to** ``size`` bytes from ``socket`` into ``buf``. Note that
   ## this is a template and not a proc.
+  assert(not socket.closed, "Cannot `recv` on a closed socket")
   var res = 0
   if socket.isSsl:
     when defineSsl:
@@ -403,6 +404,7 @@ proc send*(socket: AsyncSocket, buf: pointer, size: int,
   ## Sends ``size`` bytes from ``buf`` to ``socket``. The returned future will complete once all
   ## data has been sent.
   assert socket != nil
+  assert(not socket.closed, "Cannot `send` on a closed socket")
   if socket.isSsl:
     when defineSsl:
       sslLoop(socket, flags,
diff --git a/lib/pure/net.nim b/lib/pure/net.nim
index 9215a1812..f348b7c51 100644
--- a/lib/pure/net.nim
+++ b/lib/pure/net.nim
@@ -868,6 +868,7 @@ proc close*(socket: Socket) =
         socket.sslHandle = nil
 
     socket.fd.close()
+    socket.fd = osInvalidSocket
 
 when defined(posix):
   from posix import TCP_NODELAY
@@ -1005,15 +1006,25 @@ proc select(readfd: Socket, timeout = 500): int =
   var fds = @[readfd.fd]
   result = select(fds, timeout)
 
-proc readIntoBuf(socket: Socket, flags: int32): int =
+proc isClosed(socket: Socket): bool =
+  socket.fd == osInvalidSocket
+
+proc uniRecv(socket: Socket, buffer: pointer, size, flags: cint): int =
+  ## Handles SSL and non-ssl recv in a nice package.
+  ##
+  ## In particular handles the case where socket has been closed properly
+  ## for both SSL and non-ssl.
   result = 0
+  assert(not socket.isClosed, "Cannot `recv` on a closed socket")
   when defineSsl:
-    if socket.isSSL:
-      result = SSLRead(socket.sslHandle, addr(socket.buffer), int(socket.buffer.high))
-    else:
-      result = recv(socket.fd, addr(socket.buffer), cint(socket.buffer.high), flags)
-  else:
-    result = recv(socket.fd, addr(socket.buffer), cint(socket.buffer.high), flags)
+    if socket.isSsl:
+      return SSLRead(socket.sslHandle, buffer, size)
+
+  return recv(socket.fd, buffer, size, flags)
+
+proc readIntoBuf(socket: Socket, flags: int32): int =
+  result = 0
+  result = uniRecv(socket, addr(socket.buffer), socket.buffer.high, flags)
   if result < 0:
     # Save it in case it gets reset (the Nim codegen occasionally may call
     # Win API functions which reset it).
@@ -1059,16 +1070,16 @@ proc recv*(socket: Socket, data: pointer, size: int): int {.tags: [ReadIOEffect]
   else:
     when defineSsl:
       if socket.isSSL:
-        if socket.sslHasPeekChar:
+        if socket.sslHasPeekChar: # TODO: Merge this peek char mess into uniRecv
           copyMem(data, addr(socket.sslPeekChar), 1)
           socket.sslHasPeekChar = false
           if size-1 > 0:
             var d = cast[cstring](data)
-            result = SSLRead(socket.sslHandle, addr(d[1]), size-1) + 1
+            result = uniRecv(socket, addr(d[1]), cint(size-1), 0'i32) + 1
           else:
             result = 1
         else:
-          result = SSLRead(socket.sslHandle, data, size)
+          result = uniRecv(socket, data, size.cint, 0'i32)
       else:
         result = recv(socket.fd, data, size.cint, 0'i32)
     else:
@@ -1186,7 +1197,7 @@ proc peekChar(socket: Socket, c: var char): int {.tags: [ReadIOEffect].} =
     when defineSsl:
       if socket.isSSL:
         if not socket.sslHasPeekChar:
-          result = SSLRead(socket.sslHandle, addr(socket.sslPeekChar), 1)
+          result = uniRecv(socket, addr(socket.sslPeekChar), 1, 0'i32)
           socket.sslHasPeekChar = true
 
         c = socket.sslPeekChar
@@ -1320,6 +1331,7 @@ proc send*(socket: Socket, data: pointer, size: int): int {.
   ##
   ## **Note**: This is a low-level version of ``send``. You likely should use
   ## the version below.
+  assert(not socket.isClosed, "Cannot `send` on a closed socket")
   when defineSsl:
     if socket.isSSL:
       return SSLWrite(socket.sslHandle, cast[cstring](data), size)
@@ -1364,6 +1376,7 @@ proc sendTo*(socket: Socket, address: string, port: Port, data: pointer,
   ## which is defined below.
   ##
   ## **Note:** This proc is not available for SSL sockets.
+  assert(not socket.isClosed, "Cannot `sendTo` on a closed socket")
   var aiList = getAddrInfo(address, port, af)
 
   # try all possibilities: