diff options
-rw-r--r-- | lib/pure/asyncnet.nim | 2 | ||||
-rw-r--r-- | lib/pure/net.nim | 35 |
2 files changed, 26 insertions, 11 deletions
diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim index 5be457d2a..93399bb40 100644 --- a/lib/pure/asyncnet.nim +++ b/lib/pure/asyncnet.nim @@ -277,6 +277,7 @@ template readInto(buf: pointer, size: int, socket: AsyncSocket, flags: set[SocketFlag]): int = ## Reads **up to** ``size`` bytes from ``socket`` into ``buf``. Note that ## this is a template and not a proc. + assert(not socket.closed, "Cannot `recv` on a closed socket") var res = 0 if socket.isSsl: when defineSsl: @@ -403,6 +404,7 @@ proc send*(socket: AsyncSocket, buf: pointer, size: int, ## Sends ``size`` bytes from ``buf`` to ``socket``. The returned future will complete once all ## data has been sent. assert socket != nil + assert(not socket.closed, "Cannot `send` on a closed socket") if socket.isSsl: when defineSsl: sslLoop(socket, flags, diff --git a/lib/pure/net.nim b/lib/pure/net.nim index 9215a1812..f348b7c51 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -868,6 +868,7 @@ proc close*(socket: Socket) = socket.sslHandle = nil socket.fd.close() + socket.fd = osInvalidSocket when defined(posix): from posix import TCP_NODELAY @@ -1005,15 +1006,25 @@ proc select(readfd: Socket, timeout = 500): int = var fds = @[readfd.fd] result = select(fds, timeout) -proc readIntoBuf(socket: Socket, flags: int32): int = +proc isClosed(socket: Socket): bool = + socket.fd == osInvalidSocket + +proc uniRecv(socket: Socket, buffer: pointer, size, flags: cint): int = + ## Handles SSL and non-ssl recv in a nice package. + ## + ## In particular handles the case where socket has been closed properly + ## for both SSL and non-ssl. result = 0 + assert(not socket.isClosed, "Cannot `recv` on a closed socket") when defineSsl: - if socket.isSSL: - result = SSLRead(socket.sslHandle, addr(socket.buffer), int(socket.buffer.high)) - else: - result = recv(socket.fd, addr(socket.buffer), cint(socket.buffer.high), flags) - else: - result = recv(socket.fd, addr(socket.buffer), cint(socket.buffer.high), flags) + if socket.isSsl: + return SSLRead(socket.sslHandle, buffer, size) + + return recv(socket.fd, buffer, size, flags) + +proc readIntoBuf(socket: Socket, flags: int32): int = + result = 0 + result = uniRecv(socket, addr(socket.buffer), socket.buffer.high, flags) if result < 0: # Save it in case it gets reset (the Nim codegen occasionally may call # Win API functions which reset it). @@ -1059,16 +1070,16 @@ proc recv*(socket: Socket, data: pointer, size: int): int {.tags: [ReadIOEffect] else: when defineSsl: if socket.isSSL: - if socket.sslHasPeekChar: + if socket.sslHasPeekChar: # TODO: Merge this peek char mess into uniRecv copyMem(data, addr(socket.sslPeekChar), 1) socket.sslHasPeekChar = false if size-1 > 0: var d = cast[cstring](data) - result = SSLRead(socket.sslHandle, addr(d[1]), size-1) + 1 + result = uniRecv(socket, addr(d[1]), cint(size-1), 0'i32) + 1 else: result = 1 else: - result = SSLRead(socket.sslHandle, data, size) + result = uniRecv(socket, data, size.cint, 0'i32) else: result = recv(socket.fd, data, size.cint, 0'i32) else: @@ -1186,7 +1197,7 @@ proc peekChar(socket: Socket, c: var char): int {.tags: [ReadIOEffect].} = when defineSsl: if socket.isSSL: if not socket.sslHasPeekChar: - result = SSLRead(socket.sslHandle, addr(socket.sslPeekChar), 1) + result = uniRecv(socket, addr(socket.sslPeekChar), 1, 0'i32) socket.sslHasPeekChar = true c = socket.sslPeekChar @@ -1320,6 +1331,7 @@ proc send*(socket: Socket, data: pointer, size: int): int {. ## ## **Note**: This is a low-level version of ``send``. You likely should use ## the version below. + assert(not socket.isClosed, "Cannot `send` on a closed socket") when defineSsl: if socket.isSSL: return SSLWrite(socket.sslHandle, cast[cstring](data), size) @@ -1364,6 +1376,7 @@ proc sendTo*(socket: Socket, address: string, port: Port, data: pointer, ## which is defined below. ## ## **Note:** This proc is not available for SSL sockets. + assert(not socket.isClosed, "Cannot `sendTo` on a closed socket") var aiList = getAddrInfo(address, port, af) # try all possibilities: |