summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--lib/pure/asyncdispatch.nim43
-rw-r--r--lib/pure/asyncnet.nim37
-rw-r--r--lib/pure/net.nim71
-rw-r--r--lib/pure/rawsockets.nim5
4 files changed, 108 insertions, 48 deletions
diff --git a/lib/pure/asyncdispatch.nim b/lib/pure/asyncdispatch.nim
index d93afce6c..208e83872 100644
--- a/lib/pure/asyncdispatch.nim
+++ b/lib/pure/asyncdispatch.nim
@@ -11,8 +11,9 @@ include "system/inclrtl"
 
 import os, oids, tables, strutils, macros
 
-import rawsockets
-export TPort
+import rawsockets, net
+
+export TPort, TSocketFlags
 
 #{.injectStmt: newGcInvariant().}
 
@@ -353,7 +354,7 @@ when defined(windows) or defined(nimdoc):
     return retFuture
 
   proc recv*(socket: TAsyncFD, size: int,
-             flags: int = 0): PFuture[string] =
+             flags = {TSocketFlags.SafeDisconn}): PFuture[string] =
     ## Reads **up to** ``size`` bytes from ``socket``. Returned future will
     ## complete once all the data requested is read, a part of the data has been
     ## read, or the socket has disconnected in which case the future will
@@ -373,7 +374,7 @@ when defined(windows) or defined(nimdoc):
     dataBuf.len = size
     
     var bytesReceived: DWord
-    var flagsio = flags.DWord
+    var flagsio = flags.toOSFlags().DWord
     var ol = PCustomOverlapped()
     GC_ref(ol)
     ol.data = TCompletionData(sock: socket, cb:
@@ -403,7 +404,10 @@ when defined(windows) or defined(nimdoc):
           dealloc dataBuf.buf
           dataBuf.buf = nil
         GC_unref(ol)
-        retFuture.fail(newException(EOS, osErrorMsg(err)))
+        if flags.isDisconnectionError(err):
+          retFuture.complete("")
+        else:
+          retFuture.fail(newException(EOS, osErrorMsg(err)))
     elif ret == 0 and bytesReceived == 0 and dataBuf.buf[0] == '\0':
       # We have to ensure that the buffer is empty because WSARecv will tell
       # us immediatelly when it was disconnected, even when there is still
@@ -434,7 +438,8 @@ when defined(windows) or defined(nimdoc):
       # free ``ol``.
     return retFuture
 
-  proc send*(socket: TAsyncFD, data: string): PFuture[void] =
+  proc send*(socket: TAsyncFD, data: string,
+             flags = {TSocketFlags.SafeDisconn}): PFuture[void] =
     ## Sends ``data`` to ``socket``. The returned future will complete once all
     ## data has been sent.
     verifyPresence(socket)
@@ -444,7 +449,7 @@ when defined(windows) or defined(nimdoc):
     dataBuf.buf = data # since this is not used in a callback, this is fine
     dataBuf.len = data.len
 
-    var bytesReceived, flags: DWord
+    var bytesReceived, lowFlags: DWord
     var ol = PCustomOverlapped()
     GC_ref(ol)
     ol.data = TCompletionData(sock: socket, cb:
@@ -457,12 +462,15 @@ when defined(windows) or defined(nimdoc):
     )
 
     let ret = WSASend(socket.TSocketHandle, addr dataBuf, 1, addr bytesReceived,
-                      flags, cast[POverlapped](ol), nil)
+                      lowFlags, cast[POverlapped](ol), nil)
     if ret == -1:
       let err = osLastError()
       if err.int32 != ERROR_IO_PENDING:
-        retFuture.fail(newException(EOS, osErrorMsg(err)))
         GC_unref(ol)
+        if flags.isDisconnectionError(err):
+          retFuture.complete()
+        else:
+          retFuture.fail(newException(EOS, osErrorMsg(err)))
     else:
       retFuture.complete()
       # We don't deallocate ``ol`` here because even though this completed
@@ -706,7 +714,7 @@ else:
     return retFuture
 
   proc recv*(socket: TAsyncFD, size: int,
-             flags: int = 0): PFuture[string] =
+             flags = {TSocketFlags.SafeDisconn}): PFuture[string] =
     var retFuture = newFuture[string]()
     
     var readBuffer = newString(size)
@@ -719,7 +727,10 @@ else:
       if res < 0:
         let lastError = osLastError()
         if lastError.int32 notin {EINTR, EWOULDBLOCK, EAGAIN}:
-          retFuture.fail(newException(EOS, osErrorMsg(lastError)))
+          if flags.isDisconnectionError(lastError):
+            retFuture.complete("")
+          else:
+            retFuture.fail(newException(EOS, osErrorMsg(lastError)))
         else:
           result = false # We still want this callback to be called.
       elif res == 0:
@@ -733,7 +744,8 @@ else:
     addRead(socket, cb)
     return retFuture
 
-  proc send*(socket: TAsyncFD, data: string): PFuture[void] =
+  proc send*(socket: TAsyncFD, data: string,
+             flags = {TSocketFlags.SafeDisconn}): PFuture[void] =
     var retFuture = newFuture[void]()
     
     var written = 0
@@ -747,7 +759,10 @@ else:
       if res < 0:
         let lastError = osLastError()
         if lastError.int32 notin {EINTR, EWOULDBLOCK, EAGAIN}:
-          retFuture.fail(newException(EOS, osErrorMsg(lastError)))
+          if flags.isDisconnectionError(lastError):
+            retFuture.complete("")
+          else:
+            retFuture.fail(newException(EOS, osErrorMsg(lastError)))
         else:
           result = false # We still want this callback to be called.
       else:
@@ -1065,7 +1080,7 @@ proc recvLine*(socket: TAsyncFD): PFuture[string] {.async.} =
     if c.len == 0:
       return ""
     if c == "\r":
-      c = await recv(socket, 1, MSG_PEEK)
+      c = await recv(socket, 1, {TSocketFlags.SafeDisconn, TSocketFlags.Peek})
       if c.len > 0 and c == "\L":
         discard await recv(socket, 1)
       addNLIfEmpty()
diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim
index fb9f1a26b..374ac77e3 100644
--- a/lib/pure/asyncnet.nim
+++ b/lib/pure/asyncnet.nim
@@ -80,7 +80,8 @@ proc connect*(socket: PAsyncSocket, address: string, port: TPort,
   ## or an error occurs.
   result = connect(socket.fd.TAsyncFD, address, port, af)
 
-proc readIntoBuf(socket: PAsyncSocket, flags: int): PFuture[int] {.async.} =
+proc readIntoBuf(socket: PAsyncSocket,
+    flags: set[TSocketFlags]): PFuture[int] {.async.} =
   var data = await recv(socket.fd.TAsyncFD, BufferSize, flags)
   if data.len != 0:
     copyMem(addr socket.buffer[0], addr data[0], data.len)
@@ -89,7 +90,7 @@ proc readIntoBuf(socket: PAsyncSocket, flags: int): PFuture[int] {.async.} =
   result = data.len
 
 proc recv*(socket: PAsyncSocket, size: int,
-           flags: int = 0): PFuture[string] {.async.} =
+           flags = {TSocketFlags.SafeDisconn}): PFuture[string] {.async.} =
   ## Reads ``size`` bytes from ``socket``. Returned future will complete once
   ## all of the requested data is read. If socket is disconnected during the
   ## recv operation then the future may complete with only a part of the
@@ -100,7 +101,7 @@ proc recv*(socket: PAsyncSocket, size: int,
     let originalBufPos = socket.currPos
 
     if socket.bufLen == 0:
-      let res = await socket.readIntoBuf(flags and (not MSG_PEEK))
+      let res = await socket.readIntoBuf(flags - {TSocketFlags.Peek})
       if res == 0:
         result.setLen(0)
         return
@@ -108,10 +109,10 @@ proc recv*(socket: PAsyncSocket, size: int,
     var read = 0
     while read < size:
       if socket.currPos >= socket.bufLen:
-        if (flags and MSG_PEEK) == MSG_PEEK:
+        if TSocketFlags.Peek in flags:
           # We don't want to get another buffer if we're peeking.
           break
-        let res = await socket.readIntoBuf(flags and (not MSG_PEEK))
+        let res = await socket.readIntoBuf(flags - {TSocketFlags.Peek})
         if res == 0:
           break
 
@@ -120,18 +121,19 @@ proc recv*(socket: PAsyncSocket, size: int,
       read.inc(chunk)
       socket.currPos.inc(chunk)
 
-    if (flags and MSG_PEEK) == MSG_PEEK:
+    if TSocketFlags.Peek in flags:
       # Restore old buffer cursor position.
       socket.currPos = originalBufPos
     result.setLen(read)
   else:
     result = await recv(socket.fd.TAsyncFD, size, flags)
 
-proc send*(socket: PAsyncSocket, data: string): PFuture[void] =
+proc send*(socket: PAsyncSocket, data: string,
+           flags = {TSocketFlags.SafeDisconn}): PFuture[void] =
   ## Sends ``data`` to ``socket``. The returned future will complete once all
   ## data has been sent.
   assert socket != nil
-  result = send(socket.fd.TAsyncFD, data)
+  result = send(socket.fd.TAsyncFD, data, flags)
 
 proc acceptAddr*(socket: PAsyncSocket): 
       PFuture[tuple[address: string, client: PAsyncSocket]] =
@@ -166,7 +168,8 @@ proc accept*(socket: PAsyncSocket): PFuture[PAsyncSocket] =
         retFut.complete(future.read.client)
   return retFut
 
-proc recvLine*(socket: PAsyncSocket): PFuture[string] {.async.} =
+proc recvLine*(socket: PAsyncSocket,
+    flags = {TSocketFlags.SafeDisconn}): PFuture[string] {.async.} =
   ## Reads a line of data from ``socket``. Returned future will complete once
   ## a full line is read or an error occurs.
   ##
@@ -179,21 +182,23 @@ proc recvLine*(socket: PAsyncSocket): PFuture[string] {.async.} =
   ## If the socket is disconnected in the middle of a line (before ``\r\L``
   ## is read) then line will be set to ``""``.
   ## The partial line **will be lost**.
+  ##
+  ## **Warning**: The ``Peek`` flag is not yet implemented.
   template addNLIfEmpty(): stmt =
     if result.len == 0:
       result.add("\c\L")
-
+  assert TSocketFlags.Peek notin flags ## TODO:
   if socket.isBuffered:
     result = ""
     if socket.bufLen == 0:
-      let res = await socket.readIntoBuf(0)
+      let res = await socket.readIntoBuf(flags)
       if res == 0:
         return
 
     var lastR = false
     while true:
       if socket.currPos >= socket.bufLen:
-        let res = await socket.readIntoBuf(0)
+        let res = await socket.readIntoBuf(flags)
         if res == 0:
           result = ""
           break
@@ -214,18 +219,16 @@ proc recvLine*(socket: PAsyncSocket): PFuture[string] {.async.} =
           result.add socket.buffer[socket.currPos]
       socket.currPos.inc()
   else:
-    
-
     result = ""
     var c = ""
     while true:
-      c = await recv(socket, 1)
+      c = await recv(socket, 1, flags)
       if c.len == 0:
         return ""
       if c == "\r":
-        c = await recv(socket, 1, MSG_PEEK)
+        c = await recv(socket, 1, flags + {TSocketFlags.Peek})
         if c.len > 0 and c == "\L":
-          let dummy = await recv(socket, 1)
+          let dummy = await recv(socket, 1, flags)
           assert dummy == "\L"
         addNLIfEmpty()
         return
diff --git a/lib/pure/net.nim b/lib/pure/net.nim
index e34c88327..ddc2bbe2d 100644
--- a/lib/pure/net.nim
+++ b/lib/pure/net.nim
@@ -350,6 +350,30 @@ type
 
   ETimeout* = object of ESynch
 
+  TSocketFlags* {.pure.} = enum
+    Peek,
+    SafeDisconn ## Ensures disconnection exceptions (ECONNRESET, EPIPE etc) are not thrown.
+
+proc isDisconnectionError*(flags: set[TSocketFlags],
+    lastError: TOSErrorCode): bool =
+  ## Determines whether ``lastError`` is a disconnection error. Only does this
+  ## if flags contains ``SafeDisconn``.
+  when useWinVersion:
+    TSocketFlags.SafeDisconn in flags and
+      lastError.int32 in {WSAECONNRESET, WSAECONNABORTED, WSAENETRESET,
+                          WSAEDISCON}
+  else:
+    TSocketFlags.SafeDisconn in flags and
+      lastError.int32 in {ECONNRESET, EPIPE, ENETRESET} 
+
+proc toOSFlags*(socketFlags: set[TSocketFlags]): cint =
+  ## Converts the flags into the underlying OS representation.
+  for f in socketFlags:
+    case f
+    of TSocketFlags.Peek:
+      result = result or MSG_PEEK
+    of TSocketFlags.SafeDisconn: continue
+
 proc createSocket(fd: TSocketHandle, isBuff: bool): PSocket =
   assert fd != osInvalidSocket
   new(result)
@@ -470,7 +494,8 @@ when defined(ssl):
     if SSLSetFd(socket.sslHandle, socket.fd) != 1:
       SSLError()
 
-proc socketError*(socket: PSocket, err: int = -1, async = false) =
+proc socketError*(socket: PSocket, err: int = -1, async = false,
+                  lastError = (-1).TOSErrorCode) =
   ## Raises an EOS error based on the error code returned by ``SSLGetError``
   ## (for SSL sockets) and ``osLastError`` otherwise.
   ##
@@ -500,17 +525,17 @@ proc socketError*(socket: PSocket, err: int = -1, async = false) =
         else: SSLError("Unknown Error")
   
   if err == -1 and not (when defined(ssl): socket.isSSL else: false):
-    let lastError = osLastError()
+    let lastE = if lastError.int == -1: osLastError() else: lastError
     if async:
       when useWinVersion:
-        if lastError.int32 == WSAEWOULDBLOCK:
+        if lastE.int32 == WSAEWOULDBLOCK:
           return
-        else: osError(lastError)
+        else: osError(lastE)
       else:
-        if lastError.int32 == EAGAIN or lastError.int32 == EWOULDBLOCK:
+        if lastE.int32 == EAGAIN or lastE.int32 == EWOULDBLOCK:
           return
-        else: osError(lastError)
-    else: osError(lastError)
+        else: osError(lastE)
+    else: osError(lastE)
 
 proc listen*(socket: PSocket, backlog = SOMAXCONN) {.tags: [FReadIO].} =
   ## Marks ``socket`` as accepting connections. 
@@ -881,7 +906,8 @@ proc recv*(socket: PSocket, data: pointer, size: int, timeout: int): int {.
   
   result = read
 
-proc recv*(socket: PSocket, data: var string, size: int, timeout = -1): int =
+proc recv*(socket: PSocket, data: var string, size: int, timeout = -1,
+           flags = {TSocketFlags.SafeDisconn}): int =
   ## Higher-level version of ``recv``.
   ##
   ## When 0 is returned the socket's connection has been closed.
@@ -893,11 +919,15 @@ proc recv*(socket: PSocket, data: var string, size: int, timeout = -1): int =
   ## within the time specified an ETimeout exception will be raised.
   ##
   ## **Note**: ``data`` must be initialised.
+  ##
+  ## **Warning**: Only the ``SafeDisconn`` flag is currently supported.
   data.setLen(size)
   result = recv(socket, cstring(data), size, timeout)
   if result < 0:
     data.setLen(0)
-    socket.socketError(result)
+    let lastError = osLastError()
+    if flags.isDisconnectionError(lastError): return
+    socket.socketError(result, lastError = lastError)
   data.setLen(result)
 
 proc peekChar(socket: PSocket, c: var char): int {.tags: [FReadIO].} =
@@ -920,7 +950,8 @@ proc peekChar(socket: PSocket, c: var char): int {.tags: [FReadIO].} =
         return
     result = recv(socket.fd, addr(c), 1, MSG_PEEK)
 
-proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1) {.
+proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1,
+               flags = {TSocketFlags.SafeDisconn}) {.
   tags: [FReadIO, FTime].} =
   ## Reads a line of data from ``socket``.
   ##
@@ -934,11 +965,18 @@ proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1) {.
   ##
   ## A timeout can be specified in miliseconds, if data is not received within
   ## the specified time an ETimeout exception will be raised.
+  ##
+  ## **Warning**: Only the ``SafeDisconn`` flag is currently supported.
   
   template addNLIfEmpty(): stmt =
     if line.len == 0:
       line.add("\c\L")
 
+  template raiseSockError(): stmt {.dirty, immediate.} =
+    let lastError = osLastError()
+    if flags.isDisconnectionError(lastError): setLen(line.string, 0); return
+    socket.socketError(n, lastError = lastError)
+
   var waited = 0.0
 
   setLen(line.string, 0)
@@ -946,14 +984,14 @@ proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1) {.
     var c: char
     discard waitFor(socket, waited, timeout, 1, "readLine")
     var n = recv(socket, addr(c), 1)
-    if n < 0: socket.socketError()
-    elif n == 0: return
+    if n < 0: raiseSockError()
+    elif n == 0: setLen(line.string, 0); return
     if c == '\r':
       discard waitFor(socket, waited, timeout, 1, "readLine")
       n = peekChar(socket, c)
       if n > 0 and c == '\L':
         discard recv(socket, addr(c), 1)
-      elif n <= 0: socket.socketError()
+      elif n <= 0: raiseSockError()
       addNLIfEmpty()
       return
     elif c == '\L': 
@@ -1021,11 +1059,14 @@ proc send*(socket: PSocket, data: pointer, size: int): int {.
       const MSG_NOSIGNAL = 0
     result = send(socket.fd, data, size, int32(MSG_NOSIGNAL))
 
-proc send*(socket: PSocket, data: string) {.tags: [FWriteIO].} =
+proc send*(socket: PSocket, data: string,
+           flags = {TSocketFlags.SafeDisconn}) {.tags: [FWriteIO].} =
   ## sends data to a socket.
   let sent = send(socket, cstring(data), data.len)
   if sent < 0:
-    socketError(socket)
+    let lastError = osLastError()
+    if flags.isDisconnectionError(lastError): return
+    socketError(socket, lastError = lastError)
 
   if sent != data.len:
     raise newException(EOS, "Could not send all data.")
diff --git a/lib/pure/rawsockets.nim b/lib/pure/rawsockets.nim
index 94189fd89..d96741846 100644
--- a/lib/pure/rawsockets.nim
+++ b/lib/pure/rawsockets.nim
@@ -21,11 +21,12 @@ const useWinVersion = defined(Windows) or defined(nimdoc)
 
 when useWinVersion:
   import winlean
-  export WSAEWOULDBLOCK
+  export WSAEWOULDBLOCK, WSAECONNRESET, WSAECONNABORTED, WSAENETRESET,
+         WSAEDISCON
 else:
   import posix
   export fcntl, F_GETFL, O_NONBLOCK, F_SETFL, EAGAIN, EWOULDBLOCK, MSG_NOSIGNAL,
-    EINTR, EINPROGRESS
+    EINTR, EINPROGRESS, ECONNRESET, EPIPE, ENETRESET
 
 export TSocketHandle, TSockaddr_in, TAddrinfo, INADDR_ANY, TSockAddr, TSockLen,
   inet_ntoa, recv, `==`, connect, send, accept, recvfrom, sendto