diff options
-rw-r--r-- | lib/posix/posix.nim | 13 | ||||
-rw-r--r-- | lib/pure/asyncdispatch.nim | 81 | ||||
-rw-r--r-- | lib/pure/asynchttpserver.nim | 181 | ||||
-rw-r--r-- | lib/pure/asyncnet.nim | 97 | ||||
-rw-r--r-- | lib/pure/httpclient.nim | 2 | ||||
-rw-r--r-- | lib/pure/net.nim | 71 | ||||
-rw-r--r-- | lib/pure/rawsockets.nim | 5 | ||||
-rw-r--r-- | lib/pure/selectors.nim | 2 | ||||
-rw-r--r-- | tests/async/tasyncawait.nim | 6 | ||||
-rw-r--r-- | tests/async/tasyncdiscard.nim | 2 | ||||
-rw-r--r-- | tests/async/tnestedpfuturetypeparam.nim | 2 |
11 files changed, 294 insertions, 168 deletions
diff --git a/lib/posix/posix.nim b/lib/posix/posix.nim index cdca826ca..8e66336c2 100644 --- a/lib/posix/posix.nim +++ b/lib/posix/posix.nim @@ -1578,8 +1578,17 @@ var ## Terminates a record (if supported by the protocol). MSG_OOB* {.importc, header: "<sys/socket.h>".}: cint ## Out-of-band data. - MSG_NOSIGNAL* {.importc, header: "<sys/socket.h>".}: cint - ## No SIGPIPE generated when an attempt to send is made on a stream-oriented socket that is no longer connected. + +when defined(macosx): + var + MSG_HAVEMORE* {.importc, header: "<sys/socket.h>".}: cint + MSG_NOSIGNAL* = MSG_HAVEMORE +else: + var + MSG_NOSIGNAL* {.importc, header: "<sys/socket.h>".}: cint + ## No SIGPIPE generated when an attempt to send is made on a stream-oriented socket that is no longer connected. + +var MSG_PEEK* {.importc, header: "<sys/socket.h>".}: cint ## Leave received data in queue. MSG_TRUNC* {.importc, header: "<sys/socket.h>".}: cint diff --git a/lib/pure/asyncdispatch.nim b/lib/pure/asyncdispatch.nim index 12329951c..6ace947d3 100644 --- a/lib/pure/asyncdispatch.nim +++ b/lib/pure/asyncdispatch.nim @@ -11,8 +11,9 @@ include "system/inclrtl" import os, oids, tables, strutils, macros -import rawsockets -export TPort +import rawsockets, net + +export TPort, TSocketFlags #{.injectStmt: newGcInvariant().} @@ -126,6 +127,15 @@ proc failed*(future: PFutureBase): bool = ## Determines whether ``future`` completed with an error. future.error != nil +proc asyncCheck*[T](future: PFuture[T]) = + ## Sets a callback on ``future`` which raises an exception if the future + ## finished with an error. + ## + ## This should be used instead of ``discard`` to discard void futures. + future.callback = + proc () = + if future.failed: raise future.error + when defined(windows) or defined(nimdoc): import winlean, sets, hashes type @@ -344,7 +354,7 @@ when defined(windows) or defined(nimdoc): return retFuture proc recv*(socket: TAsyncFD, size: int, - flags: int = 0): PFuture[string] = + flags = {TSocketFlags.SafeDisconn}): PFuture[string] = ## Reads **up to** ``size`` bytes from ``socket``. Returned future will ## complete once all the data requested is read, a part of the data has been ## read, or the socket has disconnected in which case the future will @@ -364,7 +374,7 @@ when defined(windows) or defined(nimdoc): dataBuf.len = size var bytesReceived: DWord - var flagsio = flags.DWord + var flagsio = flags.toOSFlags().DWord var ol = PCustomOverlapped() GC_ref(ol) ol.data = TCompletionData(sock: socket, cb: @@ -394,7 +404,10 @@ when defined(windows) or defined(nimdoc): dealloc dataBuf.buf dataBuf.buf = nil GC_unref(ol) - retFuture.fail(newException(EOS, osErrorMsg(err))) + if flags.isDisconnectionError(err): + retFuture.complete("") + else: + retFuture.fail(newException(EOS, osErrorMsg(err))) elif ret == 0 and bytesReceived == 0 and dataBuf.buf[0] == '\0': # We have to ensure that the buffer is empty because WSARecv will tell # us immediatelly when it was disconnected, even when there is still @@ -425,7 +438,8 @@ when defined(windows) or defined(nimdoc): # free ``ol``. return retFuture - proc send*(socket: TAsyncFD, data: string): PFuture[void] = + proc send*(socket: TAsyncFD, data: string, + flags = {TSocketFlags.SafeDisconn}): PFuture[void] = ## Sends ``data`` to ``socket``. The returned future will complete once all ## data has been sent. verifyPresence(socket) @@ -435,7 +449,7 @@ when defined(windows) or defined(nimdoc): dataBuf.buf = data # since this is not used in a callback, this is fine dataBuf.len = data.len - var bytesReceived, flags: DWord + var bytesReceived, lowFlags: DWord var ol = PCustomOverlapped() GC_ref(ol) ol.data = TCompletionData(sock: socket, cb: @@ -448,12 +462,15 @@ when defined(windows) or defined(nimdoc): ) let ret = WSASend(socket.TSocketHandle, addr dataBuf, 1, addr bytesReceived, - flags, cast[POverlapped](ol), nil) + lowFlags, cast[POverlapped](ol), nil) if ret == -1: let err = osLastError() if err.int32 != ERROR_IO_PENDING: - retFuture.fail(newException(EOS, osErrorMsg(err))) GC_unref(ol) + if flags.isDisconnectionError(err): + retFuture.complete() + else: + retFuture.fail(newException(EOS, osErrorMsg(err))) else: retFuture.complete() # We don't deallocate ``ol`` here because even though this completed @@ -552,7 +569,18 @@ when defined(windows) or defined(nimdoc): initAll() else: import selectors - from posix import EINTR, EAGAIN, EINPROGRESS, EWOULDBLOCK, MSG_PEEK + when defined(windows): + import winlean + const + EINTR = WSAEINPROGRESS + EINPROGRESS = WSAEINPROGRESS + EWOULDBLOCK = WSAEWOULDBLOCK + EAGAIN = EINPROGRESS + MSG_NOSIGNAL = 0 + else: + from posix import EINTR, EAGAIN, EINPROGRESS, EWOULDBLOCK, MSG_PEEK, + MSG_NOSIGNAL + type TAsyncFD* = distinct cint TCallback = proc (sock: TAsyncFD): bool {.closure,gcsafe.} @@ -686,20 +714,23 @@ else: return retFuture proc recv*(socket: TAsyncFD, size: int, - flags: int = 0): PFuture[string] = + flags = {TSocketFlags.SafeDisconn}): PFuture[string] = var retFuture = newFuture[string]() var readBuffer = newString(size) proc cb(sock: TAsyncFD): bool = result = true - let res = recv(sock.TSocketHandle, addr readBuffer[0], size, - flags.cint) + let res = recv(sock.TSocketHandle, addr readBuffer[0], size.cint, + flags.toOSFlags()) #echo("recv cb res: ", res) if res < 0: let lastError = osLastError() - if lastError.int32 notin {EINTR, EWOULDBLOCK, EAGAIN}: - retFuture.fail(newException(EOS, osErrorMsg(lastError))) + if lastError.int32 notin {EINTR, EWOULDBLOCK, EAGAIN}: + if flags.isDisconnectionError(lastError): + retFuture.complete("") + else: + retFuture.fail(newException(EOS, osErrorMsg(lastError))) else: result = false # We still want this callback to be called. elif res == 0: @@ -708,11 +739,13 @@ else: else: readBuffer.setLen(res) retFuture.complete(readBuffer) - + # TODO: The following causes a massive slowdown. + #if not cb(socket): addRead(socket, cb) return retFuture - proc send*(socket: TAsyncFD, data: string): PFuture[void] = + proc send*(socket: TAsyncFD, data: string, + flags = {TSocketFlags.SafeDisconn}): PFuture[void] = var retFuture = newFuture[void]() var written = 0 @@ -721,11 +754,15 @@ else: result = true let netSize = data.len-written var d = data.cstring - let res = send(sock.TSocketHandle, addr d[written], netSize, 0.cint) + let res = send(sock.TSocketHandle, addr d[written], netSize.cint, + MSG_NOSIGNAL) if res < 0: let lastError = osLastError() if lastError.int32 notin {EINTR, EWOULDBLOCK, EAGAIN}: - retFuture.fail(newException(EOS, osErrorMsg(lastError))) + if flags.isDisconnectionError(lastError): + retFuture.complete() + else: + retFuture.fail(newException(EOS, osErrorMsg(lastError))) else: result = false # We still want this callback to be called. else: @@ -734,6 +771,8 @@ else: result = false # We still have data to send. else: retFuture.complete() + # TODO: The following causes crashes. + #if not cb(socket): addWrite(socket, cb) return retFuture @@ -1006,8 +1045,6 @@ macro async*(prc: stmt): stmt {.immediate.} = result[4].del(i) if subtypeIsVoid: # Add discardable pragma. - if prc.kind == nnkProcDef: # TODO: This is a workaround for #1287 - result[4].add(newIdentNode("discardable")) if returnType.kind == nnkEmpty: # Add PFuture[void] result[3][0] = parseExpr("PFuture[void]") @@ -1043,7 +1080,7 @@ proc recvLine*(socket: TAsyncFD): PFuture[string] {.async.} = if c.len == 0: return "" if c == "\r": - c = await recv(socket, 1, MSG_PEEK) + c = await recv(socket, 1, {TSocketFlags.SafeDisconn, TSocketFlags.Peek}) if c.len > 0 and c == "\L": discard await recv(socket, 1) addNLIfEmpty() diff --git a/lib/pure/asynchttpserver.nim b/lib/pure/asynchttpserver.nim index 1b47cf5f1..ee6658fd1 100644 --- a/lib/pure/asynchttpserver.nim +++ b/lib/pure/asynchttpserver.nim @@ -51,10 +51,15 @@ proc `==`*(protocol: tuple[orig: string, major, minor: int], proc newAsyncHttpServer*(): PAsyncHttpServer = new result -proc sendHeaders*(req: TRequest, headers: PStringTable) {.async.} = - ## Sends the specified headers to the requesting client. +proc addHeaders(msg: var string, headers: PStringTable) = for k, v in headers: - await req.client.send(k & ": " & v & "\c\L") + msg.add(k & ": " & v & "\c\L") + +proc sendHeaders*(req: TRequest, headers: PStringTable): PFuture[void] = + ## Sends the specified headers to the requesting client. + var msg = "" + addHeaders(msg, headers) + return req.client.send(msg) proc respond*(req: TRequest, code: THttpCode, content: string, headers: PStringTable = newStringTable()) {.async.} = @@ -64,9 +69,9 @@ proc respond*(req: TRequest, code: THttpCode, ## This procedure will **not** close the client socket. var customHeaders = headers customHeaders["Content-Length"] = $content.len - await req.client.send("HTTP/1.1 " & $code & "\c\L") - await sendHeaders(req, headers) - await req.client.send("\c\L" & content) + var msg = "HTTP/1.1 " & $code & "\c\L" + msg.addHeaders(customHeaders) + await req.client.send(msg & "\c\L" & content) proc newRequest(): TRequest = result.headers = newStringTable(modeCaseInsensitive) @@ -93,90 +98,90 @@ proc sendStatus(client: PAsyncSocket, status: string): PFuture[void] = proc processClient(client: PAsyncSocket, address: string, callback: proc (request: TRequest): PFuture[void]) {.async.} = - # GET /path HTTP/1.1 - # Header: val - # \n - var request = newRequest() - request.hostname = address - assert client != nil - request.client = client - var runCallback = true - - # First line - GET /path HTTP/1.1 - let line = await client.recvLine() # TODO: Timeouts. - if line == "": - client.close() - return - let lineParts = line.split(' ') - if lineParts.len != 3: - request.respond(Http400, "Invalid request. Got: " & line) - runCallback = false - - let reqMethod = lineParts[0] - let path = lineParts[1] - let protocol = lineParts[2] - - # Headers - var i = 0 while true: - i = 0 - let headerLine = await client.recvLine() - if headerLine == "": - client.close(); return - if headerLine == "\c\L": break - # TODO: Compiler crash - #let (key, value) = parseHeader(headerLine) - let kv = parseHeader(headerLine) - request.headers[kv.key] = kv.value - - request.reqMethod = reqMethod - request.url = parseUrl(path) - try: - request.protocol = protocol.parseProtocol() - except EInvalidValue: - request.respond(Http400, "Invalid request protocol. Got: " & protocol) - runCallback = false - - if reqMethod.normalize == "post": - # Check for Expect header - if request.headers.hasKey("Expect"): - if request.headers["Expect"].toLower == "100-continue": - await client.sendStatus("100 Continue") + # GET /path HTTP/1.1 + # Header: val + # \n + var request = newRequest() + request.hostname = address + assert client != nil + request.client = client + + # First line - GET /path HTTP/1.1 + let line = await client.recvLine() # TODO: Timeouts. + if line == "": + client.close() + return + let lineParts = line.split(' ') + if lineParts.len != 3: + await request.respond(Http400, "Invalid request. Got: " & line) + continue + + let reqMethod = lineParts[0] + let path = lineParts[1] + let protocol = lineParts[2] + + # Headers + var i = 0 + while true: + i = 0 + let headerLine = await client.recvLine() + if headerLine == "": + client.close(); return + if headerLine == "\c\L": break + # TODO: Compiler crash + #let (key, value) = parseHeader(headerLine) + let kv = parseHeader(headerLine) + request.headers[kv.key] = kv.value + + request.reqMethod = reqMethod + request.url = parseUrl(path) + try: + request.protocol = protocol.parseProtocol() + except EInvalidValue: + asyncCheck request.respond(Http400, "Invalid request protocol. Got: " & + protocol) + continue + + if reqMethod.normalize == "post": + # Check for Expect header + if request.headers.hasKey("Expect"): + if request.headers["Expect"].toLower == "100-continue": + await client.sendStatus("100 Continue") + else: + await client.sendStatus("417 Expectation Failed") + + # Read the body + # - Check for Content-length header + if request.headers.hasKey("Content-Length"): + var contentLength = 0 + if parseInt(request.headers["Content-Length"], contentLength) == 0: + await request.respond(Http400, "Bad Request. Invalid Content-Length.") + else: + request.body = await client.recv(contentLength) + assert request.body.len == contentLength else: - await client.sendStatus("417 Expectation Failed") - - # Read the body - # - Check for Content-length header - if request.headers.hasKey("Content-Length"): - var contentLength = 0 - if parseInt(request.headers["Content-Length"], contentLength) == 0: - await request.respond(Http400, "Bad Request. Invalid Content-Length.") - else: - request.body = await client.recv(contentLength) - assert request.body.len == contentLength - else: - await request.respond(Http400, "Bad Request. No Content-Length.") - runCallback = false + await request.respond(Http400, "Bad Request. No Content-Length.") + continue - case reqMethod.normalize - of "get", "post", "head", "put", "delete", "trace", "options", "connect", "patch": - if runCallback: + case reqMethod.normalize + of "get", "post", "head", "put", "delete", "trace", "options", "connect", "patch": await callback(request) - else: - await request.respond(Http400, "Invalid request method. Got: " & reqMethod) - - # Persistent connections - if (request.protocol == HttpVer11 and - request.headers["connection"].normalize != "close") or - (request.protocol == HttpVer10 and - request.headers["connection"].normalize == "keep-alive"): - # In HTTP 1.1 we assume that connection is persistent. Unless connection - # header states otherwise. - # In HTTP 1.0 we assume that the connection should not be persistent. - # Unless the connection header states otherwise. - await processClient(client, address, callback) - else: - request.client.close() + else: + await request.respond(Http400, "Invalid request method. Got: " & reqMethod) + + # Persistent connections + if (request.protocol == HttpVer11 and + request.headers["connection"].normalize != "close") or + (request.protocol == HttpVer10 and + request.headers["connection"].normalize == "keep-alive"): + # In HTTP 1.1 we assume that connection is persistent. Unless connection + # header states otherwise. + # In HTTP 1.0 we assume that the connection should not be persistent. + # Unless the connection header states otherwise. + else: + request.client.close() + break proc serve*(server: PAsyncHttpServer, port: TPort, callback: proc (request: TRequest): PFuture[void], @@ -193,7 +198,7 @@ proc serve*(server: PAsyncHttpServer, port: TPort, # TODO: Causes compiler crash. #var (address, client) = await server.socket.acceptAddr() var fut = await server.socket.acceptAddr() - processClient(fut.client, fut.address, callback) + asyncCheck processClient(fut.client, fut.address, callback) proc close*(server: PAsyncHttpServer) = ## Terminates the async http server instance. @@ -208,5 +213,5 @@ when isMainModule: "Content-type": "text/plain; charset=utf-8"} await req.respond(Http200, "Hello World", headers.newStringTable()) - server.serve(TPort(5555), cb) + asyncCheck server.serve(TPort(5555), cb) runForever() diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim index d16c85c58..374ac77e3 100644 --- a/lib/pure/asyncnet.nim +++ b/lib/pure/asyncnet.nim @@ -80,7 +80,8 @@ proc connect*(socket: PAsyncSocket, address: string, port: TPort, ## or an error occurs. result = connect(socket.fd.TAsyncFD, address, port, af) -proc readIntoBuf(socket: PAsyncSocket, flags: int): PFuture[int] {.async.} = +proc readIntoBuf(socket: PAsyncSocket, + flags: set[TSocketFlags]): PFuture[int] {.async.} = var data = await recv(socket.fd.TAsyncFD, BufferSize, flags) if data.len != 0: copyMem(addr socket.buffer[0], addr data[0], data.len) @@ -89,7 +90,7 @@ proc readIntoBuf(socket: PAsyncSocket, flags: int): PFuture[int] {.async.} = result = data.len proc recv*(socket: PAsyncSocket, size: int, - flags: int = 0): PFuture[string] {.async.} = + flags = {TSocketFlags.SafeDisconn}): PFuture[string] {.async.} = ## Reads ``size`` bytes from ``socket``. Returned future will complete once ## all of the requested data is read. If socket is disconnected during the ## recv operation then the future may complete with only a part of the @@ -100,7 +101,7 @@ proc recv*(socket: PAsyncSocket, size: int, let originalBufPos = socket.currPos if socket.bufLen == 0: - let res = await socket.readIntoBuf(flags and (not MSG_PEEK)) + let res = await socket.readIntoBuf(flags - {TSocketFlags.Peek}) if res == 0: result.setLen(0) return @@ -108,32 +109,31 @@ proc recv*(socket: PAsyncSocket, size: int, var read = 0 while read < size: if socket.currPos >= socket.bufLen: - if (flags and MSG_PEEK) == MSG_PEEK: + if TSocketFlags.Peek in flags: # We don't want to get another buffer if we're peeking. - result.setLen(read) - return - let res = await socket.readIntoBuf(flags and (not MSG_PEEK)) + break + let res = await socket.readIntoBuf(flags - {TSocketFlags.Peek}) if res == 0: - result.setLen(read) - return + break let chunk = min(socket.bufLen-socket.currPos, size-read) copyMem(addr(result[read]), addr(socket.buffer[socket.currPos]), chunk) read.inc(chunk) socket.currPos.inc(chunk) - if (flags and MSG_PEEK) == MSG_PEEK: + if TSocketFlags.Peek in flags: # Restore old buffer cursor position. socket.currPos = originalBufPos result.setLen(read) else: result = await recv(socket.fd.TAsyncFD, size, flags) -proc send*(socket: PAsyncSocket, data: string): PFuture[void] = +proc send*(socket: PAsyncSocket, data: string, + flags = {TSocketFlags.SafeDisconn}): PFuture[void] = ## Sends ``data`` to ``socket``. The returned future will complete once all ## data has been sent. assert socket != nil - result = send(socket.fd.TAsyncFD, data) + result = send(socket.fd.TAsyncFD, data, flags) proc acceptAddr*(socket: PAsyncSocket): PFuture[tuple[address: string, client: PAsyncSocket]] = @@ -168,7 +168,8 @@ proc accept*(socket: PAsyncSocket): PFuture[PAsyncSocket] = retFut.complete(future.read.client) return retFut -proc recvLine*(socket: PAsyncSocket): PFuture[string] {.async.} = +proc recvLine*(socket: PAsyncSocket, + flags = {TSocketFlags.SafeDisconn}): PFuture[string] {.async.} = ## Reads a line of data from ``socket``. Returned future will complete once ## a full line is read or an error occurs. ## @@ -181,28 +182,60 @@ proc recvLine*(socket: PAsyncSocket): PFuture[string] {.async.} = ## If the socket is disconnected in the middle of a line (before ``\r\L`` ## is read) then line will be set to ``""``. ## The partial line **will be lost**. - + ## + ## **Warning**: The ``Peek`` flag is not yet implemented. template addNLIfEmpty(): stmt = if result.len == 0: result.add("\c\L") + assert TSocketFlags.Peek notin flags ## TODO: + if socket.isBuffered: + result = "" + if socket.bufLen == 0: + let res = await socket.readIntoBuf(flags) + if res == 0: + return + + var lastR = false + while true: + if socket.currPos >= socket.bufLen: + let res = await socket.readIntoBuf(flags) + if res == 0: + result = "" + break - result = "" - var c = "" - while true: - c = await recv(socket, 1) - if c.len == 0: - return "" - if c == "\r": - c = await recv(socket, 1, MSG_PEEK) - if c.len > 0 and c == "\L": - let dummy = await recv(socket, 1) - assert dummy == "\L" - addNLIfEmpty() - return - elif c == "\L": - addNLIfEmpty() - return - add(result.string, c) + case socket.buffer[socket.currPos] + of '\r': + lastR = true + addNLIfEmpty() + of '\L': + addNLIfEmpty() + socket.currPos.inc() + return + else: + if lastR: + socket.currPos.inc() + return + else: + result.add socket.buffer[socket.currPos] + socket.currPos.inc() + else: + result = "" + var c = "" + while true: + c = await recv(socket, 1, flags) + if c.len == 0: + return "" + if c == "\r": + c = await recv(socket, 1, flags + {TSocketFlags.Peek}) + if c.len > 0 and c == "\L": + let dummy = await recv(socket, 1, flags) + assert dummy == "\L" + addNLIfEmpty() + return + elif c == "\L": + addNLIfEmpty() + return + add(result.string, c) proc bindAddr*(socket: PAsyncSocket, port = TPort(0), address = "") = ## Binds ``address``:``port`` to the socket. @@ -241,7 +274,7 @@ when isMainModule: break else: echo("Got line: ", line) - main() + asyncCheck main() elif test == LowClient: var sock = newAsyncSocket() var f = connect(sock, "irc.freenode.net", TPort(6667)) diff --git a/lib/pure/httpclient.nim b/lib/pure/httpclient.nim index be06a7b8e..9bacc80d6 100644 --- a/lib/pure/httpclient.nim +++ b/lib/pure/httpclient.nim @@ -654,7 +654,7 @@ when isMainModule: resp = await client.request("http://nimrod-lang.org/download.html") echo("Got response: ", resp.status) - main() + asyncCheck main() runForever() else: diff --git a/lib/pure/net.nim b/lib/pure/net.nim index e34c88327..ddc2bbe2d 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -350,6 +350,30 @@ type ETimeout* = object of ESynch + TSocketFlags* {.pure.} = enum + Peek, + SafeDisconn ## Ensures disconnection exceptions (ECONNRESET, EPIPE etc) are not thrown. + +proc isDisconnectionError*(flags: set[TSocketFlags], + lastError: TOSErrorCode): bool = + ## Determines whether ``lastError`` is a disconnection error. Only does this + ## if flags contains ``SafeDisconn``. + when useWinVersion: + TSocketFlags.SafeDisconn in flags and + lastError.int32 in {WSAECONNRESET, WSAECONNABORTED, WSAENETRESET, + WSAEDISCON} + else: + TSocketFlags.SafeDisconn in flags and + lastError.int32 in {ECONNRESET, EPIPE, ENETRESET} + +proc toOSFlags*(socketFlags: set[TSocketFlags]): cint = + ## Converts the flags into the underlying OS representation. + for f in socketFlags: + case f + of TSocketFlags.Peek: + result = result or MSG_PEEK + of TSocketFlags.SafeDisconn: continue + proc createSocket(fd: TSocketHandle, isBuff: bool): PSocket = assert fd != osInvalidSocket new(result) @@ -470,7 +494,8 @@ when defined(ssl): if SSLSetFd(socket.sslHandle, socket.fd) != 1: SSLError() -proc socketError*(socket: PSocket, err: int = -1, async = false) = +proc socketError*(socket: PSocket, err: int = -1, async = false, + lastError = (-1).TOSErrorCode) = ## Raises an EOS error based on the error code returned by ``SSLGetError`` ## (for SSL sockets) and ``osLastError`` otherwise. ## @@ -500,17 +525,17 @@ proc socketError*(socket: PSocket, err: int = -1, async = false) = else: SSLError("Unknown Error") if err == -1 and not (when defined(ssl): socket.isSSL else: false): - let lastError = osLastError() + let lastE = if lastError.int == -1: osLastError() else: lastError if async: when useWinVersion: - if lastError.int32 == WSAEWOULDBLOCK: + if lastE.int32 == WSAEWOULDBLOCK: return - else: osError(lastError) + else: osError(lastE) else: - if lastError.int32 == EAGAIN or lastError.int32 == EWOULDBLOCK: + if lastE.int32 == EAGAIN or lastE.int32 == EWOULDBLOCK: return - else: osError(lastError) - else: osError(lastError) + else: osError(lastE) + else: osError(lastE) proc listen*(socket: PSocket, backlog = SOMAXCONN) {.tags: [FReadIO].} = ## Marks ``socket`` as accepting connections. @@ -881,7 +906,8 @@ proc recv*(socket: PSocket, data: pointer, size: int, timeout: int): int {. result = read -proc recv*(socket: PSocket, data: var string, size: int, timeout = -1): int = +proc recv*(socket: PSocket, data: var string, size: int, timeout = -1, + flags = {TSocketFlags.SafeDisconn}): int = ## Higher-level version of ``recv``. ## ## When 0 is returned the socket's connection has been closed. @@ -893,11 +919,15 @@ proc recv*(socket: PSocket, data: var string, size: int, timeout = -1): int = ## within the time specified an ETimeout exception will be raised. ## ## **Note**: ``data`` must be initialised. + ## + ## **Warning**: Only the ``SafeDisconn`` flag is currently supported. data.setLen(size) result = recv(socket, cstring(data), size, timeout) if result < 0: data.setLen(0) - socket.socketError(result) + let lastError = osLastError() + if flags.isDisconnectionError(lastError): return + socket.socketError(result, lastError = lastError) data.setLen(result) proc peekChar(socket: PSocket, c: var char): int {.tags: [FReadIO].} = @@ -920,7 +950,8 @@ proc peekChar(socket: PSocket, c: var char): int {.tags: [FReadIO].} = return result = recv(socket.fd, addr(c), 1, MSG_PEEK) -proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1) {. +proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1, + flags = {TSocketFlags.SafeDisconn}) {. tags: [FReadIO, FTime].} = ## Reads a line of data from ``socket``. ## @@ -934,11 +965,18 @@ proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1) {. ## ## A timeout can be specified in miliseconds, if data is not received within ## the specified time an ETimeout exception will be raised. + ## + ## **Warning**: Only the ``SafeDisconn`` flag is currently supported. template addNLIfEmpty(): stmt = if line.len == 0: line.add("\c\L") + template raiseSockError(): stmt {.dirty, immediate.} = + let lastError = osLastError() + if flags.isDisconnectionError(lastError): setLen(line.string, 0); return + socket.socketError(n, lastError = lastError) + var waited = 0.0 setLen(line.string, 0) @@ -946,14 +984,14 @@ proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1) {. var c: char discard waitFor(socket, waited, timeout, 1, "readLine") var n = recv(socket, addr(c), 1) - if n < 0: socket.socketError() - elif n == 0: return + if n < 0: raiseSockError() + elif n == 0: setLen(line.string, 0); return if c == '\r': discard waitFor(socket, waited, timeout, 1, "readLine") n = peekChar(socket, c) if n > 0 and c == '\L': discard recv(socket, addr(c), 1) - elif n <= 0: socket.socketError() + elif n <= 0: raiseSockError() addNLIfEmpty() return elif c == '\L': @@ -1021,11 +1059,14 @@ proc send*(socket: PSocket, data: pointer, size: int): int {. const MSG_NOSIGNAL = 0 result = send(socket.fd, data, size, int32(MSG_NOSIGNAL)) -proc send*(socket: PSocket, data: string) {.tags: [FWriteIO].} = +proc send*(socket: PSocket, data: string, + flags = {TSocketFlags.SafeDisconn}) {.tags: [FWriteIO].} = ## sends data to a socket. let sent = send(socket, cstring(data), data.len) if sent < 0: - socketError(socket) + let lastError = osLastError() + if flags.isDisconnectionError(lastError): return + socketError(socket, lastError = lastError) if sent != data.len: raise newException(EOS, "Could not send all data.") diff --git a/lib/pure/rawsockets.nim b/lib/pure/rawsockets.nim index 94189fd89..d96741846 100644 --- a/lib/pure/rawsockets.nim +++ b/lib/pure/rawsockets.nim @@ -21,11 +21,12 @@ const useWinVersion = defined(Windows) or defined(nimdoc) when useWinVersion: import winlean - export WSAEWOULDBLOCK + export WSAEWOULDBLOCK, WSAECONNRESET, WSAECONNABORTED, WSAENETRESET, + WSAEDISCON else: import posix export fcntl, F_GETFL, O_NONBLOCK, F_SETFL, EAGAIN, EWOULDBLOCK, MSG_NOSIGNAL, - EINTR, EINPROGRESS + EINTR, EINPROGRESS, ECONNRESET, EPIPE, ENETRESET export TSocketHandle, TSockaddr_in, TAddrinfo, INADDR_ANY, TSockAddr, TSockLen, inet_ntoa, recv, `==`, connect, send, accept, recvfrom, sendto diff --git a/lib/pure/selectors.nim b/lib/pure/selectors.nim index 3af5f699c..bd53c2dbf 100644 --- a/lib/pure/selectors.nim +++ b/lib/pure/selectors.nim @@ -163,7 +163,7 @@ elif defined(linux): proc newSelector*(): PSelector = new result result.epollFD = epoll_create(64) - result.events = cast[array[64, epoll_event]](alloc0(sizeof(epoll_event)*64)) + #result.events = cast[array[64, epoll_event]](alloc0(sizeof(epoll_event)*64)) result.fds = initTable[TSocketHandle, PSelectorKey]() if result.epollFD < 0: OSError(OSLastError()) diff --git a/tests/async/tasyncawait.nim b/tests/async/tasyncawait.nim index da4952677..2d65db4bd 100644 --- a/tests/async/tasyncawait.nim +++ b/tests/async/tasyncawait.nim @@ -61,11 +61,11 @@ proc createServer(port: TPort) {.async.} = discard server.TSocketHandle.listen() while true: var client = await accept(server) - readMessages(client) + asyncCheck readMessages(client) # TODO: Test: readMessages(disp, await disp.accept(server)) -createServer(TPort(10335)) -launchSwarm(TPort(10335)) +asyncCheck createServer(TPort(10335)) +asyncCheck launchSwarm(TPort(10335)) while true: poll() if clientCount == swarmSize: break diff --git a/tests/async/tasyncdiscard.nim b/tests/async/tasyncdiscard.nim index 48d8a8c4d..966851acc 100644 --- a/tests/async/tasyncdiscard.nim +++ b/tests/async/tasyncdiscard.nim @@ -36,4 +36,4 @@ proc main {.async.} = discard await g() echo 6 -main() +asyncCheck main() diff --git a/tests/async/tnestedpfuturetypeparam.nim b/tests/async/tnestedpfuturetypeparam.nim index d0d87e567..1db442170 100644 --- a/tests/async/tnestedpfuturetypeparam.nim +++ b/tests/async/tnestedpfuturetypeparam.nim @@ -5,4 +5,4 @@ proc main {.async.} = await newAsyncSocket().connect("www.google.com", TPort(80)) let x = await f() -main() +asyncCheck main() |