summary refs log tree commit diff stats
path: root/lib/pure/asyncnet.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/asyncnet.nim')
-rw-r--r--lib/pure/asyncnet.nim131
1 files changed, 117 insertions, 14 deletions
diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim
index a1988f4a6..3b64c278f 100644
--- a/lib/pure/asyncnet.nim
+++ b/lib/pure/asyncnet.nim
@@ -62,6 +62,8 @@ import os
 
 export SOBool
 
+# TODO: Remove duplication introduced by PR #4683.
+
 const defineSsl = defined(ssl) or defined(nimdoc)
 
 when defineSsl:
@@ -157,15 +159,23 @@ when defineSsl:
       await socket.fd.AsyncFd.send(data, flags)
 
   proc appeaseSsl(socket: AsyncSocket, flags: set[SocketFlag],
-                  sslError: cint) {.async.} =
+                  sslError: cint): Future[bool] {.async.} =
+    ## Returns ``true`` if ``socket`` is still connected, otherwise ``false``.
+    result = true
     case sslError
     of SSL_ERROR_WANT_WRITE:
       await sendPendingSslData(socket, flags)
     of SSL_ERROR_WANT_READ:
       var data = await recv(socket.fd.AsyncFD, BufferSize, flags)
-      let ret = bioWrite(socket.bioIn, addr data[0], data.len.cint)
-      if ret < 0:
-        raiseSSLError()
+      let length = len(data)
+      if length > 0:
+        let ret = bioWrite(socket.bioIn, addr data[0], data.len.cint)
+        if ret < 0:
+          raiseSSLError()
+      elif length == 0:
+        # connection not properly closed by remote side or connection dropped
+        SSL_set_shutdown(socket.sslHandle, SSL_RECEIVED_SHUTDOWN)
+        result = false
     else:
       raiseSSLError("Cannot appease SSL.")
 
@@ -173,13 +183,27 @@ when defineSsl:
                    op: expr) =
     var opResult {.inject.} = -1.cint
     while opResult < 0:
+      # Call the desired operation.
       opResult = op
       # Bit hackish here.
       # TODO: Introduce an async template transformation pragma?
+
+      # Send any remaining pending SSL data.
       yield sendPendingSslData(socket, flags)
+
+      # If the operation failed, try to see if SSL has some data to read
+      # or write.
       if opResult < 0:
         let err = getSslError(socket.sslHandle, opResult.cint)
-        yield appeaseSsl(socket, flags, err.cint)
+        let fut = appeaseSsl(socket, flags, err.cint)
+        yield fut
+        if not fut.read():
+          # Socket disconnected.
+          if SocketFlag.SafeDisconn in flags:
+            break
+          else:
+            raiseSSLError("Socket has been disconnected")
+
 
 proc connect*(socket: AsyncSocket, address: string, port: Port) {.async.} =
   ## Connects ``socket`` to server at ``address:port``.
@@ -193,7 +217,7 @@ proc connect*(socket: AsyncSocket, address: string, port: Port) {.async.} =
       sslSetConnectState(socket.sslHandle)
       sslLoop(socket, flags, sslDoHandshake(socket.sslHandle))
 
-template readInto(buf: cstring, size: int, socket: AsyncSocket,
+template readInto(buf: pointer, size: int, socket: AsyncSocket,
                   flags: set[SocketFlag]): int =
   ## Reads **up to** ``size`` bytes from ``socket`` into ``buf``. Note that
   ## this is a template and not a proc.
@@ -202,10 +226,10 @@ template readInto(buf: cstring, size: int, socket: AsyncSocket,
     when defineSsl:
       # SSL mode.
       sslLoop(socket, flags,
-        sslRead(socket.sslHandle, buf, size.cint))
+        sslRead(socket.sslHandle, cast[cstring](buf), size.cint))
       res = opResult
   else:
-    var recvIntoFut = recvInto(socket.fd.AsyncFD, buf, size, flags)
+    var recvIntoFut = asyncdispatch.recvInto(socket.fd.AsyncFD, buf, size, flags)
     yield recvIntoFut
     # Not in SSL mode.
     res = recvIntoFut.read()
@@ -218,6 +242,54 @@ template readIntoBuf(socket: AsyncSocket,
   socket.bufLen = size
   size
 
+proc recvInto*(socket: AsyncSocket, buf: pointer, size: int,
+           flags = {SocketFlag.SafeDisconn}): Future[int] {.async.} =
+  ## Reads **up to** ``size`` bytes from ``socket`` into ``buf``.
+  ##
+  ## For buffered sockets this function will attempt to read all the requested
+  ## data. It will read this data in ``BufferSize`` chunks.
+  ##
+  ## For unbuffered sockets this function makes no effort to read
+  ## all the data requested. It will return as much data as the operating system
+  ## gives it.
+  ##
+  ## If socket is disconnected during the
+  ## recv operation then the future may complete with only a part of the
+  ## requested data.
+  ##
+  ## If socket is disconnected and no data is available
+  ## to be read then the future will complete with a value of ``0``.
+  if socket.isBuffered:
+    let originalBufPos = socket.currPos
+
+    if socket.bufLen == 0:
+      let res = socket.readIntoBuf(flags - {SocketFlag.Peek})
+      if res == 0:
+        return 0
+
+    var read = 0
+    var cbuf = cast[cstring](buf)
+    while read < size:
+      if socket.currPos >= socket.bufLen:
+        if SocketFlag.Peek in flags:
+          # We don't want to get another buffer if we're peeking.
+          break
+        let res = socket.readIntoBuf(flags - {SocketFlag.Peek})
+        if res == 0:
+          break
+
+      let chunk = min(socket.bufLen-socket.currPos, size-read)
+      copyMem(addr(cbuf[read]), addr(socket.buffer[socket.currPos]), chunk)
+      read.inc(chunk)
+      socket.currPos.inc(chunk)
+
+    if SocketFlag.Peek in flags:
+      # Restore old buffer cursor position.
+      socket.currPos = originalBufPos
+    result = read
+  else:
+    result = readInto(buf, size, socket, flags)
+
 proc recv*(socket: AsyncSocket, size: int,
            flags = {SocketFlag.SafeDisconn}): Future[string] {.async.} =
   ## Reads **up to** ``size`` bytes from ``socket``.
@@ -270,6 +342,19 @@ proc recv*(socket: AsyncSocket, size: int,
     let read = readInto(addr result[0], size, socket, flags)
     result.setLen(read)
 
+proc send*(socket: AsyncSocket, buf: pointer, size: int,
+            flags = {SocketFlag.SafeDisconn}) {.async.} =
+  ## Sends ``size`` bytes from ``buf`` to ``socket``. The returned future will complete once all
+  ## data has been sent.
+  assert socket != nil
+  if socket.isSsl:
+    when defineSsl:
+      sslLoop(socket, flags,
+              sslWrite(socket.sslHandle, cast[cstring](buf), size.cint))
+      await sendPendingSslData(socket, flags)
+  else:
+    await send(socket.fd.AsyncFD, buf, size, flags)
+
 proc send*(socket: AsyncSocket, data: string,
            flags = {SocketFlag.SafeDisconn}) {.async.} =
   ## Sends ``data`` to ``socket``. The returned future will complete once all
@@ -320,7 +405,7 @@ proc accept*(socket: AsyncSocket,
   return retFut
 
 proc recvLineInto*(socket: AsyncSocket, resString: FutureVar[string],
-    flags = {SocketFlag.SafeDisconn}) {.async.} =
+    flags = {SocketFlag.SafeDisconn}, maxLength = MaxLineLength) {.async.} =
   ## Reads a line of data from ``socket`` into ``resString``.
   ##
   ## If a full line is read ``\r\L`` is not
@@ -333,13 +418,14 @@ proc recvLineInto*(socket: AsyncSocket, resString: FutureVar[string],
   ## is read) then line will be set to ``""``.
   ## The partial line **will be lost**.
   ##
+  ## The ``maxLength`` parameter determines the maximum amount of characters
+  ## that can be read before a ``ValueError`` is raised. This prevents Denial
+  ## of Service (DOS) attacks.
+  ##
   ## **Warning**: The ``Peek`` flag is not yet implemented.
   ##
   ## **Warning**: ``recvLineInto`` on unbuffered sockets assumes that the
   ## protocol uses ``\r\L`` to delimit a new line.
-  ##
-  ## **Warning**: ``recvLineInto`` currently uses a raw pointer to a string for
-  ## performance reasons. This will likely change soon to use FutureVars.
   assert SocketFlag.Peek notin flags ## TODO:
   assert(not resString.mget.isNil(),
          "String inside resString future needs to be initialised")
@@ -386,6 +472,12 @@ proc recvLineInto*(socket: AsyncSocket, resString: FutureVar[string],
         else:
           resString.mget.add socket.buffer[socket.currPos]
       socket.currPos.inc()
+
+      # Verify that this isn't a DOS attack: #3847.
+      if resString.mget.len > maxLength:
+        let msg = "recvLine received more than the specified `maxLength` " &
+                  "allowed."
+        raise newException(ValueError, msg)
   else:
     var c = ""
     while true:
@@ -407,10 +499,17 @@ proc recvLineInto*(socket: AsyncSocket, resString: FutureVar[string],
         resString.complete()
         return
       resString.mget.add c
+
+      # Verify that this isn't a DOS attack: #3847.
+      if resString.mget.len > maxLength:
+        let msg = "recvLine received more than the specified `maxLength` " &
+                  "allowed."
+        raise newException(ValueError, msg)
   resString.complete()
 
 proc recvLine*(socket: AsyncSocket,
-    flags = {SocketFlag.SafeDisconn}): Future[string] {.async.} =
+    flags = {SocketFlag.SafeDisconn},
+    maxLength = MaxLineLength): Future[string] {.async.} =
   ## Reads a line of data from ``socket``. Returned future will complete once
   ## a full line is read or an error occurs.
   ##
@@ -424,6 +523,10 @@ proc recvLine*(socket: AsyncSocket,
   ## is read) then line will be set to ``""``.
   ## The partial line **will be lost**.
   ##
+  ## The ``maxLength`` parameter determines the maximum amount of characters
+  ## that can be read before a ``ValueError`` is raised. This prevents Denial
+  ## of Service (DOS) attacks.
+  ##
   ## **Warning**: The ``Peek`` flag is not yet implemented.
   ##
   ## **Warning**: ``recvLine`` on unbuffered sockets assumes that the protocol
@@ -433,7 +536,7 @@ proc recvLine*(socket: AsyncSocket,
   # TODO: Optimise this
   var resString = newFutureVar[string]("asyncnet.recvLine")
   resString.mget() = ""
-  await socket.recvLineInto(resString, flags)
+  await socket.recvLineInto(resString, flags, maxLength)
   result = resString.mget()
 
 proc listen*(socket: AsyncSocket, backlog = SOMAXCONN) {.tags: [ReadIOEffect].} =