diff options
Diffstat (limited to 'lib/pure/asyncnet.nim')
-rw-r--r-- | lib/pure/asyncnet.nim | 131 |
1 files changed, 117 insertions, 14 deletions
diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim index a1988f4a6..3b64c278f 100644 --- a/lib/pure/asyncnet.nim +++ b/lib/pure/asyncnet.nim @@ -62,6 +62,8 @@ import os export SOBool +# TODO: Remove duplication introduced by PR #4683. + const defineSsl = defined(ssl) or defined(nimdoc) when defineSsl: @@ -157,15 +159,23 @@ when defineSsl: await socket.fd.AsyncFd.send(data, flags) proc appeaseSsl(socket: AsyncSocket, flags: set[SocketFlag], - sslError: cint) {.async.} = + sslError: cint): Future[bool] {.async.} = + ## Returns ``true`` if ``socket`` is still connected, otherwise ``false``. + result = true case sslError of SSL_ERROR_WANT_WRITE: await sendPendingSslData(socket, flags) of SSL_ERROR_WANT_READ: var data = await recv(socket.fd.AsyncFD, BufferSize, flags) - let ret = bioWrite(socket.bioIn, addr data[0], data.len.cint) - if ret < 0: - raiseSSLError() + let length = len(data) + if length > 0: + let ret = bioWrite(socket.bioIn, addr data[0], data.len.cint) + if ret < 0: + raiseSSLError() + elif length == 0: + # connection not properly closed by remote side or connection dropped + SSL_set_shutdown(socket.sslHandle, SSL_RECEIVED_SHUTDOWN) + result = false else: raiseSSLError("Cannot appease SSL.") @@ -173,13 +183,27 @@ when defineSsl: op: expr) = var opResult {.inject.} = -1.cint while opResult < 0: + # Call the desired operation. opResult = op # Bit hackish here. # TODO: Introduce an async template transformation pragma? + + # Send any remaining pending SSL data. yield sendPendingSslData(socket, flags) + + # If the operation failed, try to see if SSL has some data to read + # or write. if opResult < 0: let err = getSslError(socket.sslHandle, opResult.cint) - yield appeaseSsl(socket, flags, err.cint) + let fut = appeaseSsl(socket, flags, err.cint) + yield fut + if not fut.read(): + # Socket disconnected. + if SocketFlag.SafeDisconn in flags: + break + else: + raiseSSLError("Socket has been disconnected") + proc connect*(socket: AsyncSocket, address: string, port: Port) {.async.} = ## Connects ``socket`` to server at ``address:port``. @@ -193,7 +217,7 @@ proc connect*(socket: AsyncSocket, address: string, port: Port) {.async.} = sslSetConnectState(socket.sslHandle) sslLoop(socket, flags, sslDoHandshake(socket.sslHandle)) -template readInto(buf: cstring, size: int, socket: AsyncSocket, +template readInto(buf: pointer, size: int, socket: AsyncSocket, flags: set[SocketFlag]): int = ## Reads **up to** ``size`` bytes from ``socket`` into ``buf``. Note that ## this is a template and not a proc. @@ -202,10 +226,10 @@ template readInto(buf: cstring, size: int, socket: AsyncSocket, when defineSsl: # SSL mode. sslLoop(socket, flags, - sslRead(socket.sslHandle, buf, size.cint)) + sslRead(socket.sslHandle, cast[cstring](buf), size.cint)) res = opResult else: - var recvIntoFut = recvInto(socket.fd.AsyncFD, buf, size, flags) + var recvIntoFut = asyncdispatch.recvInto(socket.fd.AsyncFD, buf, size, flags) yield recvIntoFut # Not in SSL mode. res = recvIntoFut.read() @@ -218,6 +242,54 @@ template readIntoBuf(socket: AsyncSocket, socket.bufLen = size size +proc recvInto*(socket: AsyncSocket, buf: pointer, size: int, + flags = {SocketFlag.SafeDisconn}): Future[int] {.async.} = + ## Reads **up to** ``size`` bytes from ``socket`` into ``buf``. + ## + ## For buffered sockets this function will attempt to read all the requested + ## data. It will read this data in ``BufferSize`` chunks. + ## + ## For unbuffered sockets this function makes no effort to read + ## all the data requested. It will return as much data as the operating system + ## gives it. + ## + ## If socket is disconnected during the + ## recv operation then the future may complete with only a part of the + ## requested data. + ## + ## If socket is disconnected and no data is available + ## to be read then the future will complete with a value of ``0``. + if socket.isBuffered: + let originalBufPos = socket.currPos + + if socket.bufLen == 0: + let res = socket.readIntoBuf(flags - {SocketFlag.Peek}) + if res == 0: + return 0 + + var read = 0 + var cbuf = cast[cstring](buf) + while read < size: + if socket.currPos >= socket.bufLen: + if SocketFlag.Peek in flags: + # We don't want to get another buffer if we're peeking. + break + let res = socket.readIntoBuf(flags - {SocketFlag.Peek}) + if res == 0: + break + + let chunk = min(socket.bufLen-socket.currPos, size-read) + copyMem(addr(cbuf[read]), addr(socket.buffer[socket.currPos]), chunk) + read.inc(chunk) + socket.currPos.inc(chunk) + + if SocketFlag.Peek in flags: + # Restore old buffer cursor position. + socket.currPos = originalBufPos + result = read + else: + result = readInto(buf, size, socket, flags) + proc recv*(socket: AsyncSocket, size: int, flags = {SocketFlag.SafeDisconn}): Future[string] {.async.} = ## Reads **up to** ``size`` bytes from ``socket``. @@ -270,6 +342,19 @@ proc recv*(socket: AsyncSocket, size: int, let read = readInto(addr result[0], size, socket, flags) result.setLen(read) +proc send*(socket: AsyncSocket, buf: pointer, size: int, + flags = {SocketFlag.SafeDisconn}) {.async.} = + ## Sends ``size`` bytes from ``buf`` to ``socket``. The returned future will complete once all + ## data has been sent. + assert socket != nil + if socket.isSsl: + when defineSsl: + sslLoop(socket, flags, + sslWrite(socket.sslHandle, cast[cstring](buf), size.cint)) + await sendPendingSslData(socket, flags) + else: + await send(socket.fd.AsyncFD, buf, size, flags) + proc send*(socket: AsyncSocket, data: string, flags = {SocketFlag.SafeDisconn}) {.async.} = ## Sends ``data`` to ``socket``. The returned future will complete once all @@ -320,7 +405,7 @@ proc accept*(socket: AsyncSocket, return retFut proc recvLineInto*(socket: AsyncSocket, resString: FutureVar[string], - flags = {SocketFlag.SafeDisconn}) {.async.} = + flags = {SocketFlag.SafeDisconn}, maxLength = MaxLineLength) {.async.} = ## Reads a line of data from ``socket`` into ``resString``. ## ## If a full line is read ``\r\L`` is not @@ -333,13 +418,14 @@ proc recvLineInto*(socket: AsyncSocket, resString: FutureVar[string], ## is read) then line will be set to ``""``. ## The partial line **will be lost**. ## + ## The ``maxLength`` parameter determines the maximum amount of characters + ## that can be read before a ``ValueError`` is raised. This prevents Denial + ## of Service (DOS) attacks. + ## ## **Warning**: The ``Peek`` flag is not yet implemented. ## ## **Warning**: ``recvLineInto`` on unbuffered sockets assumes that the ## protocol uses ``\r\L`` to delimit a new line. - ## - ## **Warning**: ``recvLineInto`` currently uses a raw pointer to a string for - ## performance reasons. This will likely change soon to use FutureVars. assert SocketFlag.Peek notin flags ## TODO: assert(not resString.mget.isNil(), "String inside resString future needs to be initialised") @@ -386,6 +472,12 @@ proc recvLineInto*(socket: AsyncSocket, resString: FutureVar[string], else: resString.mget.add socket.buffer[socket.currPos] socket.currPos.inc() + + # Verify that this isn't a DOS attack: #3847. + if resString.mget.len > maxLength: + let msg = "recvLine received more than the specified `maxLength` " & + "allowed." + raise newException(ValueError, msg) else: var c = "" while true: @@ -407,10 +499,17 @@ proc recvLineInto*(socket: AsyncSocket, resString: FutureVar[string], resString.complete() return resString.mget.add c + + # Verify that this isn't a DOS attack: #3847. + if resString.mget.len > maxLength: + let msg = "recvLine received more than the specified `maxLength` " & + "allowed." + raise newException(ValueError, msg) resString.complete() proc recvLine*(socket: AsyncSocket, - flags = {SocketFlag.SafeDisconn}): Future[string] {.async.} = + flags = {SocketFlag.SafeDisconn}, + maxLength = MaxLineLength): Future[string] {.async.} = ## Reads a line of data from ``socket``. Returned future will complete once ## a full line is read or an error occurs. ## @@ -424,6 +523,10 @@ proc recvLine*(socket: AsyncSocket, ## is read) then line will be set to ``""``. ## The partial line **will be lost**. ## + ## The ``maxLength`` parameter determines the maximum amount of characters + ## that can be read before a ``ValueError`` is raised. This prevents Denial + ## of Service (DOS) attacks. + ## ## **Warning**: The ``Peek`` flag is not yet implemented. ## ## **Warning**: ``recvLine`` on unbuffered sockets assumes that the protocol @@ -433,7 +536,7 @@ proc recvLine*(socket: AsyncSocket, # TODO: Optimise this var resString = newFutureVar[string]("asyncnet.recvLine") resString.mget() = "" - await socket.recvLineInto(resString, flags) + await socket.recvLineInto(resString, flags, maxLength) result = resString.mget() proc listen*(socket: AsyncSocket, backlog = SOMAXCONN) {.tags: [ReadIOEffect].} = |