diff options
-rw-r--r-- | lib/pure/asyncdispatch.nim | 43 | ||||
-rw-r--r-- | lib/pure/asyncnet.nim | 37 | ||||
-rw-r--r-- | lib/pure/net.nim | 71 | ||||
-rw-r--r-- | lib/pure/rawsockets.nim | 5 |
4 files changed, 108 insertions, 48 deletions
diff --git a/lib/pure/asyncdispatch.nim b/lib/pure/asyncdispatch.nim index d93afce6c..208e83872 100644 --- a/lib/pure/asyncdispatch.nim +++ b/lib/pure/asyncdispatch.nim @@ -11,8 +11,9 @@ include "system/inclrtl" import os, oids, tables, strutils, macros -import rawsockets -export TPort +import rawsockets, net + +export TPort, TSocketFlags #{.injectStmt: newGcInvariant().} @@ -353,7 +354,7 @@ when defined(windows) or defined(nimdoc): return retFuture proc recv*(socket: TAsyncFD, size: int, - flags: int = 0): PFuture[string] = + flags = {TSocketFlags.SafeDisconn}): PFuture[string] = ## Reads **up to** ``size`` bytes from ``socket``. Returned future will ## complete once all the data requested is read, a part of the data has been ## read, or the socket has disconnected in which case the future will @@ -373,7 +374,7 @@ when defined(windows) or defined(nimdoc): dataBuf.len = size var bytesReceived: DWord - var flagsio = flags.DWord + var flagsio = flags.toOSFlags().DWord var ol = PCustomOverlapped() GC_ref(ol) ol.data = TCompletionData(sock: socket, cb: @@ -403,7 +404,10 @@ when defined(windows) or defined(nimdoc): dealloc dataBuf.buf dataBuf.buf = nil GC_unref(ol) - retFuture.fail(newException(EOS, osErrorMsg(err))) + if flags.isDisconnectionError(err): + retFuture.complete("") + else: + retFuture.fail(newException(EOS, osErrorMsg(err))) elif ret == 0 and bytesReceived == 0 and dataBuf.buf[0] == '\0': # We have to ensure that the buffer is empty because WSARecv will tell # us immediatelly when it was disconnected, even when there is still @@ -434,7 +438,8 @@ when defined(windows) or defined(nimdoc): # free ``ol``. return retFuture - proc send*(socket: TAsyncFD, data: string): PFuture[void] = + proc send*(socket: TAsyncFD, data: string, + flags = {TSocketFlags.SafeDisconn}): PFuture[void] = ## Sends ``data`` to ``socket``. The returned future will complete once all ## data has been sent. verifyPresence(socket) @@ -444,7 +449,7 @@ when defined(windows) or defined(nimdoc): dataBuf.buf = data # since this is not used in a callback, this is fine dataBuf.len = data.len - var bytesReceived, flags: DWord + var bytesReceived, lowFlags: DWord var ol = PCustomOverlapped() GC_ref(ol) ol.data = TCompletionData(sock: socket, cb: @@ -457,12 +462,15 @@ when defined(windows) or defined(nimdoc): ) let ret = WSASend(socket.TSocketHandle, addr dataBuf, 1, addr bytesReceived, - flags, cast[POverlapped](ol), nil) + lowFlags, cast[POverlapped](ol), nil) if ret == -1: let err = osLastError() if err.int32 != ERROR_IO_PENDING: - retFuture.fail(newException(EOS, osErrorMsg(err))) GC_unref(ol) + if flags.isDisconnectionError(err): + retFuture.complete() + else: + retFuture.fail(newException(EOS, osErrorMsg(err))) else: retFuture.complete() # We don't deallocate ``ol`` here because even though this completed @@ -706,7 +714,7 @@ else: return retFuture proc recv*(socket: TAsyncFD, size: int, - flags: int = 0): PFuture[string] = + flags = {TSocketFlags.SafeDisconn}): PFuture[string] = var retFuture = newFuture[string]() var readBuffer = newString(size) @@ -719,7 +727,10 @@ else: if res < 0: let lastError = osLastError() if lastError.int32 notin {EINTR, EWOULDBLOCK, EAGAIN}: - retFuture.fail(newException(EOS, osErrorMsg(lastError))) + if flags.isDisconnectionError(lastError): + retFuture.complete("") + else: + retFuture.fail(newException(EOS, osErrorMsg(lastError))) else: result = false # We still want this callback to be called. elif res == 0: @@ -733,7 +744,8 @@ else: addRead(socket, cb) return retFuture - proc send*(socket: TAsyncFD, data: string): PFuture[void] = + proc send*(socket: TAsyncFD, data: string, + flags = {TSocketFlags.SafeDisconn}): PFuture[void] = var retFuture = newFuture[void]() var written = 0 @@ -747,7 +759,10 @@ else: if res < 0: let lastError = osLastError() if lastError.int32 notin {EINTR, EWOULDBLOCK, EAGAIN}: - retFuture.fail(newException(EOS, osErrorMsg(lastError))) + if flags.isDisconnectionError(lastError): + retFuture.complete("") + else: + retFuture.fail(newException(EOS, osErrorMsg(lastError))) else: result = false # We still want this callback to be called. else: @@ -1065,7 +1080,7 @@ proc recvLine*(socket: TAsyncFD): PFuture[string] {.async.} = if c.len == 0: return "" if c == "\r": - c = await recv(socket, 1, MSG_PEEK) + c = await recv(socket, 1, {TSocketFlags.SafeDisconn, TSocketFlags.Peek}) if c.len > 0 and c == "\L": discard await recv(socket, 1) addNLIfEmpty() diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim index fb9f1a26b..374ac77e3 100644 --- a/lib/pure/asyncnet.nim +++ b/lib/pure/asyncnet.nim @@ -80,7 +80,8 @@ proc connect*(socket: PAsyncSocket, address: string, port: TPort, ## or an error occurs. result = connect(socket.fd.TAsyncFD, address, port, af) -proc readIntoBuf(socket: PAsyncSocket, flags: int): PFuture[int] {.async.} = +proc readIntoBuf(socket: PAsyncSocket, + flags: set[TSocketFlags]): PFuture[int] {.async.} = var data = await recv(socket.fd.TAsyncFD, BufferSize, flags) if data.len != 0: copyMem(addr socket.buffer[0], addr data[0], data.len) @@ -89,7 +90,7 @@ proc readIntoBuf(socket: PAsyncSocket, flags: int): PFuture[int] {.async.} = result = data.len proc recv*(socket: PAsyncSocket, size: int, - flags: int = 0): PFuture[string] {.async.} = + flags = {TSocketFlags.SafeDisconn}): PFuture[string] {.async.} = ## Reads ``size`` bytes from ``socket``. Returned future will complete once ## all of the requested data is read. If socket is disconnected during the ## recv operation then the future may complete with only a part of the @@ -100,7 +101,7 @@ proc recv*(socket: PAsyncSocket, size: int, let originalBufPos = socket.currPos if socket.bufLen == 0: - let res = await socket.readIntoBuf(flags and (not MSG_PEEK)) + let res = await socket.readIntoBuf(flags - {TSocketFlags.Peek}) if res == 0: result.setLen(0) return @@ -108,10 +109,10 @@ proc recv*(socket: PAsyncSocket, size: int, var read = 0 while read < size: if socket.currPos >= socket.bufLen: - if (flags and MSG_PEEK) == MSG_PEEK: + if TSocketFlags.Peek in flags: # We don't want to get another buffer if we're peeking. break - let res = await socket.readIntoBuf(flags and (not MSG_PEEK)) + let res = await socket.readIntoBuf(flags - {TSocketFlags.Peek}) if res == 0: break @@ -120,18 +121,19 @@ proc recv*(socket: PAsyncSocket, size: int, read.inc(chunk) socket.currPos.inc(chunk) - if (flags and MSG_PEEK) == MSG_PEEK: + if TSocketFlags.Peek in flags: # Restore old buffer cursor position. socket.currPos = originalBufPos result.setLen(read) else: result = await recv(socket.fd.TAsyncFD, size, flags) -proc send*(socket: PAsyncSocket, data: string): PFuture[void] = +proc send*(socket: PAsyncSocket, data: string, + flags = {TSocketFlags.SafeDisconn}): PFuture[void] = ## Sends ``data`` to ``socket``. The returned future will complete once all ## data has been sent. assert socket != nil - result = send(socket.fd.TAsyncFD, data) + result = send(socket.fd.TAsyncFD, data, flags) proc acceptAddr*(socket: PAsyncSocket): PFuture[tuple[address: string, client: PAsyncSocket]] = @@ -166,7 +168,8 @@ proc accept*(socket: PAsyncSocket): PFuture[PAsyncSocket] = retFut.complete(future.read.client) return retFut -proc recvLine*(socket: PAsyncSocket): PFuture[string] {.async.} = +proc recvLine*(socket: PAsyncSocket, + flags = {TSocketFlags.SafeDisconn}): PFuture[string] {.async.} = ## Reads a line of data from ``socket``. Returned future will complete once ## a full line is read or an error occurs. ## @@ -179,21 +182,23 @@ proc recvLine*(socket: PAsyncSocket): PFuture[string] {.async.} = ## If the socket is disconnected in the middle of a line (before ``\r\L`` ## is read) then line will be set to ``""``. ## The partial line **will be lost**. + ## + ## **Warning**: The ``Peek`` flag is not yet implemented. template addNLIfEmpty(): stmt = if result.len == 0: result.add("\c\L") - + assert TSocketFlags.Peek notin flags ## TODO: if socket.isBuffered: result = "" if socket.bufLen == 0: - let res = await socket.readIntoBuf(0) + let res = await socket.readIntoBuf(flags) if res == 0: return var lastR = false while true: if socket.currPos >= socket.bufLen: - let res = await socket.readIntoBuf(0) + let res = await socket.readIntoBuf(flags) if res == 0: result = "" break @@ -214,18 +219,16 @@ proc recvLine*(socket: PAsyncSocket): PFuture[string] {.async.} = result.add socket.buffer[socket.currPos] socket.currPos.inc() else: - - result = "" var c = "" while true: - c = await recv(socket, 1) + c = await recv(socket, 1, flags) if c.len == 0: return "" if c == "\r": - c = await recv(socket, 1, MSG_PEEK) + c = await recv(socket, 1, flags + {TSocketFlags.Peek}) if c.len > 0 and c == "\L": - let dummy = await recv(socket, 1) + let dummy = await recv(socket, 1, flags) assert dummy == "\L" addNLIfEmpty() return diff --git a/lib/pure/net.nim b/lib/pure/net.nim index e34c88327..ddc2bbe2d 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -350,6 +350,30 @@ type ETimeout* = object of ESynch + TSocketFlags* {.pure.} = enum + Peek, + SafeDisconn ## Ensures disconnection exceptions (ECONNRESET, EPIPE etc) are not thrown. + +proc isDisconnectionError*(flags: set[TSocketFlags], + lastError: TOSErrorCode): bool = + ## Determines whether ``lastError`` is a disconnection error. Only does this + ## if flags contains ``SafeDisconn``. + when useWinVersion: + TSocketFlags.SafeDisconn in flags and + lastError.int32 in {WSAECONNRESET, WSAECONNABORTED, WSAENETRESET, + WSAEDISCON} + else: + TSocketFlags.SafeDisconn in flags and + lastError.int32 in {ECONNRESET, EPIPE, ENETRESET} + +proc toOSFlags*(socketFlags: set[TSocketFlags]): cint = + ## Converts the flags into the underlying OS representation. + for f in socketFlags: + case f + of TSocketFlags.Peek: + result = result or MSG_PEEK + of TSocketFlags.SafeDisconn: continue + proc createSocket(fd: TSocketHandle, isBuff: bool): PSocket = assert fd != osInvalidSocket new(result) @@ -470,7 +494,8 @@ when defined(ssl): if SSLSetFd(socket.sslHandle, socket.fd) != 1: SSLError() -proc socketError*(socket: PSocket, err: int = -1, async = false) = +proc socketError*(socket: PSocket, err: int = -1, async = false, + lastError = (-1).TOSErrorCode) = ## Raises an EOS error based on the error code returned by ``SSLGetError`` ## (for SSL sockets) and ``osLastError`` otherwise. ## @@ -500,17 +525,17 @@ proc socketError*(socket: PSocket, err: int = -1, async = false) = else: SSLError("Unknown Error") if err == -1 and not (when defined(ssl): socket.isSSL else: false): - let lastError = osLastError() + let lastE = if lastError.int == -1: osLastError() else: lastError if async: when useWinVersion: - if lastError.int32 == WSAEWOULDBLOCK: + if lastE.int32 == WSAEWOULDBLOCK: return - else: osError(lastError) + else: osError(lastE) else: - if lastError.int32 == EAGAIN or lastError.int32 == EWOULDBLOCK: + if lastE.int32 == EAGAIN or lastE.int32 == EWOULDBLOCK: return - else: osError(lastError) - else: osError(lastError) + else: osError(lastE) + else: osError(lastE) proc listen*(socket: PSocket, backlog = SOMAXCONN) {.tags: [FReadIO].} = ## Marks ``socket`` as accepting connections. @@ -881,7 +906,8 @@ proc recv*(socket: PSocket, data: pointer, size: int, timeout: int): int {. result = read -proc recv*(socket: PSocket, data: var string, size: int, timeout = -1): int = +proc recv*(socket: PSocket, data: var string, size: int, timeout = -1, + flags = {TSocketFlags.SafeDisconn}): int = ## Higher-level version of ``recv``. ## ## When 0 is returned the socket's connection has been closed. @@ -893,11 +919,15 @@ proc recv*(socket: PSocket, data: var string, size: int, timeout = -1): int = ## within the time specified an ETimeout exception will be raised. ## ## **Note**: ``data`` must be initialised. + ## + ## **Warning**: Only the ``SafeDisconn`` flag is currently supported. data.setLen(size) result = recv(socket, cstring(data), size, timeout) if result < 0: data.setLen(0) - socket.socketError(result) + let lastError = osLastError() + if flags.isDisconnectionError(lastError): return + socket.socketError(result, lastError = lastError) data.setLen(result) proc peekChar(socket: PSocket, c: var char): int {.tags: [FReadIO].} = @@ -920,7 +950,8 @@ proc peekChar(socket: PSocket, c: var char): int {.tags: [FReadIO].} = return result = recv(socket.fd, addr(c), 1, MSG_PEEK) -proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1) {. +proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1, + flags = {TSocketFlags.SafeDisconn}) {. tags: [FReadIO, FTime].} = ## Reads a line of data from ``socket``. ## @@ -934,11 +965,18 @@ proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1) {. ## ## A timeout can be specified in miliseconds, if data is not received within ## the specified time an ETimeout exception will be raised. + ## + ## **Warning**: Only the ``SafeDisconn`` flag is currently supported. template addNLIfEmpty(): stmt = if line.len == 0: line.add("\c\L") + template raiseSockError(): stmt {.dirty, immediate.} = + let lastError = osLastError() + if flags.isDisconnectionError(lastError): setLen(line.string, 0); return + socket.socketError(n, lastError = lastError) + var waited = 0.0 setLen(line.string, 0) @@ -946,14 +984,14 @@ proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1) {. var c: char discard waitFor(socket, waited, timeout, 1, "readLine") var n = recv(socket, addr(c), 1) - if n < 0: socket.socketError() - elif n == 0: return + if n < 0: raiseSockError() + elif n == 0: setLen(line.string, 0); return if c == '\r': discard waitFor(socket, waited, timeout, 1, "readLine") n = peekChar(socket, c) if n > 0 and c == '\L': discard recv(socket, addr(c), 1) - elif n <= 0: socket.socketError() + elif n <= 0: raiseSockError() addNLIfEmpty() return elif c == '\L': @@ -1021,11 +1059,14 @@ proc send*(socket: PSocket, data: pointer, size: int): int {. const MSG_NOSIGNAL = 0 result = send(socket.fd, data, size, int32(MSG_NOSIGNAL)) -proc send*(socket: PSocket, data: string) {.tags: [FWriteIO].} = +proc send*(socket: PSocket, data: string, + flags = {TSocketFlags.SafeDisconn}) {.tags: [FWriteIO].} = ## sends data to a socket. let sent = send(socket, cstring(data), data.len) if sent < 0: - socketError(socket) + let lastError = osLastError() + if flags.isDisconnectionError(lastError): return + socketError(socket, lastError = lastError) if sent != data.len: raise newException(EOS, "Could not send all data.") diff --git a/lib/pure/rawsockets.nim b/lib/pure/rawsockets.nim index 94189fd89..d96741846 100644 --- a/lib/pure/rawsockets.nim +++ b/lib/pure/rawsockets.nim @@ -21,11 +21,12 @@ const useWinVersion = defined(Windows) or defined(nimdoc) when useWinVersion: import winlean - export WSAEWOULDBLOCK + export WSAEWOULDBLOCK, WSAECONNRESET, WSAECONNABORTED, WSAENETRESET, + WSAEDISCON else: import posix export fcntl, F_GETFL, O_NONBLOCK, F_SETFL, EAGAIN, EWOULDBLOCK, MSG_NOSIGNAL, - EINTR, EINPROGRESS + EINTR, EINPROGRESS, ECONNRESET, EPIPE, ENETRESET export TSocketHandle, TSockaddr_in, TAddrinfo, INADDR_ANY, TSockAddr, TSockLen, inet_ntoa, recv, `==`, connect, send, accept, recvfrom, sendto |