summary refs log tree commit diff stats
path: root/lib/pure/net.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/net.nim')
-rw-r--r--lib/pure/net.nim75
1 files changed, 59 insertions, 16 deletions
diff --git a/lib/pure/net.nim b/lib/pure/net.nim
index 74739630b..ddc2bbe2d 100644
--- a/lib/pure/net.nim
+++ b/lib/pure/net.nim
@@ -11,7 +11,7 @@
 
 {.deadCodeElim: on.}
 import rawsockets, os, strutils, unsigned, parseutils, times
-export TPort
+export TPort, `$`
 
 const useWinVersion = defined(Windows) or defined(nimdoc)
 
@@ -350,6 +350,30 @@ type
 
   ETimeout* = object of ESynch
 
+  TSocketFlags* {.pure.} = enum
+    Peek,
+    SafeDisconn ## Ensures disconnection exceptions (ECONNRESET, EPIPE etc) are not thrown.
+
+proc isDisconnectionError*(flags: set[TSocketFlags],
+    lastError: TOSErrorCode): bool =
+  ## Determines whether ``lastError`` is a disconnection error. Only does this
+  ## if flags contains ``SafeDisconn``.
+  when useWinVersion:
+    TSocketFlags.SafeDisconn in flags and
+      lastError.int32 in {WSAECONNRESET, WSAECONNABORTED, WSAENETRESET,
+                          WSAEDISCON}
+  else:
+    TSocketFlags.SafeDisconn in flags and
+      lastError.int32 in {ECONNRESET, EPIPE, ENETRESET} 
+
+proc toOSFlags*(socketFlags: set[TSocketFlags]): cint =
+  ## Converts the flags into the underlying OS representation.
+  for f in socketFlags:
+    case f
+    of TSocketFlags.Peek:
+      result = result or MSG_PEEK
+    of TSocketFlags.SafeDisconn: continue
+
 proc createSocket(fd: TSocketHandle, isBuff: bool): PSocket =
   assert fd != osInvalidSocket
   new(result)
@@ -470,7 +494,8 @@ when defined(ssl):
     if SSLSetFd(socket.sslHandle, socket.fd) != 1:
       SSLError()
 
-proc socketError*(socket: PSocket, err: int = -1, async = false) =
+proc socketError*(socket: PSocket, err: int = -1, async = false,
+                  lastError = (-1).TOSErrorCode) =
   ## Raises an EOS error based on the error code returned by ``SSLGetError``
   ## (for SSL sockets) and ``osLastError`` otherwise.
   ##
@@ -500,17 +525,17 @@ proc socketError*(socket: PSocket, err: int = -1, async = false) =
         else: SSLError("Unknown Error")
   
   if err == -1 and not (when defined(ssl): socket.isSSL else: false):
-    let lastError = osLastError()
+    let lastE = if lastError.int == -1: osLastError() else: lastError
     if async:
       when useWinVersion:
-        if lastError.int32 == WSAEWOULDBLOCK:
+        if lastE.int32 == WSAEWOULDBLOCK:
           return
-        else: osError(lastError)
+        else: osError(lastE)
       else:
-        if lastError.int32 == EAGAIN or lastError.int32 == EWOULDBLOCK:
+        if lastE.int32 == EAGAIN or lastE.int32 == EWOULDBLOCK:
           return
-        else: osError(lastError)
-    else: osError(lastError)
+        else: osError(lastE)
+    else: osError(lastE)
 
 proc listen*(socket: PSocket, backlog = SOMAXCONN) {.tags: [FReadIO].} =
   ## Marks ``socket`` as accepting connections. 
@@ -805,6 +830,7 @@ proc recv*(socket: PSocket, data: pointer, size: int): int {.tags: [FReadIO].} =
     
       let chunk = min(socket.bufLen-socket.currPos, size-read)
       var d = cast[cstring](data)
+      assert size-read >= chunk
       copyMem(addr(d[read]), addr(socket.buffer[socket.currPos]), chunk)
       read.inc(chunk)
       socket.currPos.inc(chunk)
@@ -871,6 +897,7 @@ proc recv*(socket: PSocket, data: pointer, size: int, timeout: int): int {.
   while read < size:
     let avail = waitFor(socket, waited, timeout, size-read, "recv")
     var d = cast[cstring](data)
+    assert avail <= size-read
     result = recv(socket, addr(d[read]), avail)
     if result == 0: break
     if result < 0:
@@ -879,7 +906,8 @@ proc recv*(socket: PSocket, data: pointer, size: int, timeout: int): int {.
   
   result = read
 
-proc recv*(socket: PSocket, data: var string, size: int, timeout = -1): int =
+proc recv*(socket: PSocket, data: var string, size: int, timeout = -1,
+           flags = {TSocketFlags.SafeDisconn}): int =
   ## Higher-level version of ``recv``.
   ##
   ## When 0 is returned the socket's connection has been closed.
@@ -891,11 +919,15 @@ proc recv*(socket: PSocket, data: var string, size: int, timeout = -1): int =
   ## within the time specified an ETimeout exception will be raised.
   ##
   ## **Note**: ``data`` must be initialised.
+  ##
+  ## **Warning**: Only the ``SafeDisconn`` flag is currently supported.
   data.setLen(size)
   result = recv(socket, cstring(data), size, timeout)
   if result < 0:
     data.setLen(0)
-    socket.socketError(result)
+    let lastError = osLastError()
+    if flags.isDisconnectionError(lastError): return
+    socket.socketError(result, lastError = lastError)
   data.setLen(result)
 
 proc peekChar(socket: PSocket, c: var char): int {.tags: [FReadIO].} =
@@ -918,7 +950,8 @@ proc peekChar(socket: PSocket, c: var char): int {.tags: [FReadIO].} =
         return
     result = recv(socket.fd, addr(c), 1, MSG_PEEK)
 
-proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1) {.
+proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1,
+               flags = {TSocketFlags.SafeDisconn}) {.
   tags: [FReadIO, FTime].} =
   ## Reads a line of data from ``socket``.
   ##
@@ -932,11 +965,18 @@ proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1) {.
   ##
   ## A timeout can be specified in miliseconds, if data is not received within
   ## the specified time an ETimeout exception will be raised.
+  ##
+  ## **Warning**: Only the ``SafeDisconn`` flag is currently supported.
   
   template addNLIfEmpty(): stmt =
     if line.len == 0:
       line.add("\c\L")
 
+  template raiseSockError(): stmt {.dirty, immediate.} =
+    let lastError = osLastError()
+    if flags.isDisconnectionError(lastError): setLen(line.string, 0); return
+    socket.socketError(n, lastError = lastError)
+
   var waited = 0.0
 
   setLen(line.string, 0)
@@ -944,14 +984,14 @@ proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1) {.
     var c: char
     discard waitFor(socket, waited, timeout, 1, "readLine")
     var n = recv(socket, addr(c), 1)
-    if n < 0: socket.socketError()
-    elif n == 0: return
+    if n < 0: raiseSockError()
+    elif n == 0: setLen(line.string, 0); return
     if c == '\r':
       discard waitFor(socket, waited, timeout, 1, "readLine")
       n = peekChar(socket, c)
       if n > 0 and c == '\L':
         discard recv(socket, addr(c), 1)
-      elif n <= 0: socket.socketError()
+      elif n <= 0: raiseSockError()
       addNLIfEmpty()
       return
     elif c == '\L': 
@@ -1019,11 +1059,14 @@ proc send*(socket: PSocket, data: pointer, size: int): int {.
       const MSG_NOSIGNAL = 0
     result = send(socket.fd, data, size, int32(MSG_NOSIGNAL))
 
-proc send*(socket: PSocket, data: string) {.tags: [FWriteIO].} =
+proc send*(socket: PSocket, data: string,
+           flags = {TSocketFlags.SafeDisconn}) {.tags: [FWriteIO].} =
   ## sends data to a socket.
   let sent = send(socket, cstring(data), data.len)
   if sent < 0:
-    socketError(socket)
+    let lastError = osLastError()
+    if flags.isDisconnectionError(lastError): return
+    socketError(socket, lastError = lastError)
 
   if sent != data.len:
     raise newException(EOS, "Could not send all data.")