summary refs log tree commit diff stats
path: root/lib/pure/asyncnet.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/asyncnet.nim')
-rw-r--r--lib/pure/asyncnet.nim51
1 files changed, 43 insertions, 8 deletions
diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim
index 334f95baa..3b64c278f 100644
--- a/lib/pure/asyncnet.nim
+++ b/lib/pure/asyncnet.nim
@@ -159,7 +159,9 @@ when defineSsl:
       await socket.fd.AsyncFd.send(data, flags)
 
   proc appeaseSsl(socket: AsyncSocket, flags: set[SocketFlag],
-                  sslError: cint) {.async.} =
+                  sslError: cint): Future[bool] {.async.} =
+    ## Returns ``true`` if ``socket`` is still connected, otherwise ``false``.
+    result = true
     case sslError
     of SSL_ERROR_WANT_WRITE:
       await sendPendingSslData(socket, flags)
@@ -173,6 +175,7 @@ when defineSsl:
       elif length == 0:
         # connection not properly closed by remote side or connection dropped
         SSL_set_shutdown(socket.sslHandle, SSL_RECEIVED_SHUTDOWN)
+        result = false
     else:
       raiseSSLError("Cannot appease SSL.")
 
@@ -180,13 +183,27 @@ when defineSsl:
                    op: expr) =
     var opResult {.inject.} = -1.cint
     while opResult < 0:
+      # Call the desired operation.
       opResult = op
       # Bit hackish here.
       # TODO: Introduce an async template transformation pragma?
+
+      # Send any remaining pending SSL data.
       yield sendPendingSslData(socket, flags)
+
+      # If the operation failed, try to see if SSL has some data to read
+      # or write.
       if opResult < 0:
         let err = getSslError(socket.sslHandle, opResult.cint)
-        yield appeaseSsl(socket, flags, err.cint)
+        let fut = appeaseSsl(socket, flags, err.cint)
+        yield fut
+        if not fut.read():
+          # Socket disconnected.
+          if SocketFlag.SafeDisconn in flags:
+            break
+          else:
+            raiseSSLError("Socket has been disconnected")
+
 
 proc connect*(socket: AsyncSocket, address: string, port: Port) {.async.} =
   ## Connects ``socket`` to server at ``address:port``.
@@ -388,7 +405,7 @@ proc accept*(socket: AsyncSocket,
   return retFut
 
 proc recvLineInto*(socket: AsyncSocket, resString: FutureVar[string],
-    flags = {SocketFlag.SafeDisconn}) {.async.} =
+    flags = {SocketFlag.SafeDisconn}, maxLength = MaxLineLength) {.async.} =
   ## Reads a line of data from ``socket`` into ``resString``.
   ##
   ## If a full line is read ``\r\L`` is not
@@ -401,13 +418,14 @@ proc recvLineInto*(socket: AsyncSocket, resString: FutureVar[string],
   ## is read) then line will be set to ``""``.
   ## The partial line **will be lost**.
   ##
+  ## The ``maxLength`` parameter determines the maximum amount of characters
+  ## that can be read before a ``ValueError`` is raised. This prevents Denial
+  ## of Service (DOS) attacks.
+  ##
   ## **Warning**: The ``Peek`` flag is not yet implemented.
   ##
   ## **Warning**: ``recvLineInto`` on unbuffered sockets assumes that the
   ## protocol uses ``\r\L`` to delimit a new line.
-  ##
-  ## **Warning**: ``recvLineInto`` currently uses a raw pointer to a string for
-  ## performance reasons. This will likely change soon to use FutureVars.
   assert SocketFlag.Peek notin flags ## TODO:
   assert(not resString.mget.isNil(),
          "String inside resString future needs to be initialised")
@@ -454,6 +472,12 @@ proc recvLineInto*(socket: AsyncSocket, resString: FutureVar[string],
         else:
           resString.mget.add socket.buffer[socket.currPos]
       socket.currPos.inc()
+
+      # Verify that this isn't a DOS attack: #3847.
+      if resString.mget.len > maxLength:
+        let msg = "recvLine received more than the specified `maxLength` " &
+                  "allowed."
+        raise newException(ValueError, msg)
   else:
     var c = ""
     while true:
@@ -475,10 +499,17 @@ proc recvLineInto*(socket: AsyncSocket, resString: FutureVar[string],
         resString.complete()
         return
       resString.mget.add c
+
+      # Verify that this isn't a DOS attack: #3847.
+      if resString.mget.len > maxLength:
+        let msg = "recvLine received more than the specified `maxLength` " &
+                  "allowed."
+        raise newException(ValueError, msg)
   resString.complete()
 
 proc recvLine*(socket: AsyncSocket,
-    flags = {SocketFlag.SafeDisconn}): Future[string] {.async.} =
+    flags = {SocketFlag.SafeDisconn},
+    maxLength = MaxLineLength): Future[string] {.async.} =
   ## Reads a line of data from ``socket``. Returned future will complete once
   ## a full line is read or an error occurs.
   ##
@@ -492,6 +523,10 @@ proc recvLine*(socket: AsyncSocket,
   ## is read) then line will be set to ``""``.
   ## The partial line **will be lost**.
   ##
+  ## The ``maxLength`` parameter determines the maximum amount of characters
+  ## that can be read before a ``ValueError`` is raised. This prevents Denial
+  ## of Service (DOS) attacks.
+  ##
   ## **Warning**: The ``Peek`` flag is not yet implemented.
   ##
   ## **Warning**: ``recvLine`` on unbuffered sockets assumes that the protocol
@@ -501,7 +536,7 @@ proc recvLine*(socket: AsyncSocket,
   # TODO: Optimise this
   var resString = newFutureVar[string]("asyncnet.recvLine")
   resString.mget() = ""
-  await socket.recvLineInto(resString, flags)
+  await socket.recvLineInto(resString, flags, maxLength)
   result = resString.mget()
 
 proc listen*(socket: AsyncSocket, backlog = SOMAXCONN) {.tags: [ReadIOEffect].} =