summary refs log tree commit diff stats
path: root/lib/pure/net.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/net.nim')
-rw-r--r--lib/pure/net.nim165
1 files changed, 115 insertions, 50 deletions
diff --git a/lib/pure/net.nim b/lib/pure/net.nim
index 8afc6c5c5..ffbc6e320 100644
--- a/lib/pure/net.nim
+++ b/lib/pure/net.nim
@@ -1,7 +1,7 @@
 #
 #
 #            Nim's Runtime Library
-#        (c) Copyright 2014 Dominik Picheta
+#        (c) Copyright 2015 Dominik Picheta
 #
 #    See the file "copying.txt", included in this
 #    distribution, for details about the copyright.
@@ -44,23 +44,24 @@ const
 
 type
   SocketImpl* = object ## socket type
-    fd*: SocketHandle
-    case isBuffered*: bool # determines whether this socket is buffered.
+    fd: SocketHandle
+    case isBuffered: bool # determines whether this socket is buffered.
     of true:
-      buffer*: array[0..BufferSize, char]
-      currPos*: int # current index in buffer
-      bufLen*: int # current length of buffer
+      buffer: array[0..BufferSize, char]
+      currPos: int # current index in buffer
+      bufLen: int # current length of buffer
     of false: nil
     when defined(ssl):
-      case isSsl*: bool
+      case isSsl: bool
       of true:
-        sslHandle*: SSLPtr
-        sslContext*: SSLContext
-        sslNoHandshake*: bool # True if needs handshake.
-        sslHasPeekChar*: bool
-        sslPeekChar*: char
+        sslHandle: SSLPtr
+        sslContext: SSLContext
+        sslNoHandshake: bool # True if needs handshake.
+        sslHasPeekChar: bool
+        sslPeekChar: char
       of false: nil
-  
+    lastError: OSErrorCode ## stores the last error on this socket
+
   Socket* = ref SocketImpl
 
   SOBool* = enum ## Boolean socket options.
@@ -80,6 +81,23 @@ type
     TReadLineResult: ReadLineResult, TSOBool: SOBool, PSocket: Socket,
     TSocketImpl: SocketImpl].}
 
+type
+  IpAddressFamily* {.pure.} = enum ## Describes the type of an IP address
+    IPv6, ## IPv6 address
+    IPv4  ## IPv4 address
+
+  TIpAddress* = object ## stores an arbitrary IP address    
+    case family*: IpAddressFamily ## the type of the IP address (IPv4 or IPv6)
+    of IpAddressFamily.IPv6:
+      address_v6*: array[0..15, uint8] ## Contains the IP address in bytes in
+                                       ## case of IPv6
+    of IpAddressFamily.IPv4:
+      address_v4*: array[0..3, uint8] ## Contains the IP address in bytes in
+                                      ## case of IPv4
+
+proc isIpAddress*(address_str: string): bool {.tags: [].}
+proc parseIpAddress*(address_str: string): TIpAddress
+
 proc isDisconnectionError*(flags: set[SocketFlag],
     lastError: OSErrorCode): bool =
   ## Determines whether ``lastError`` is a disconnection error. Only does this
@@ -100,7 +118,8 @@ proc toOSFlags*(socketFlags: set[SocketFlag]): cint =
       result = result or MSG_PEEK
     of SocketFlag.SafeDisconn: continue
 
-proc createSocket(fd: SocketHandle, isBuff: bool): Socket =
+proc newSocket(fd: SocketHandle, isBuff: bool): Socket =
+  ## Creates a new socket as specified by the params.
   assert fd != osInvalidSocket
   new(result)
   result.fd = fd
@@ -115,17 +134,17 @@ proc newSocket*(domain, typ, protocol: cint, buffered = true): Socket =
   let fd = newRawSocket(domain, typ, protocol)
   if fd == osInvalidSocket:
     raiseOSError(osLastError())
-  result = createSocket(fd, buffered)
+  result = newSocket(fd, buffered)
 
 proc newSocket*(domain: Domain = AF_INET, typ: SockType = SOCK_STREAM,
-             protocol: Protocol = IPPROTO_TCP, buffered = true): Socket =
+                protocol: Protocol = IPPROTO_TCP, buffered = true): Socket =
   ## Creates a new socket.
   ##
   ## If an error occurs EOS will be raised.
   let fd = newRawSocket(domain, typ, protocol)
   if fd == osInvalidSocket:
     raiseOSError(osLastError())
-  result = createSocket(fd, buffered)
+  result = newSocket(fd, buffered)
 
 when defined(ssl):
   CRYPTO_malloc_init()
@@ -230,6 +249,15 @@ when defined(ssl):
     if SSLSetFd(socket.sslHandle, socket.fd) != 1:
       raiseSSLError()
 
+proc getSocketError*(socket: Socket): OSErrorCode =
+  ## Checks ``osLastError`` for a valid error. If it has been reset it uses
+  ## the last error stored in the socket object.
+  result = osLastError()
+  if result == 0.OSErrorCode:
+    result = socket.lastError
+  if result == 0.OSErrorCode:
+    raise newException(OSError, "No valid socket error code available")
+
 proc socketError*(socket: Socket, err: int = -1, async = false,
                   lastError = (-1).OSErrorCode) =
   ## Raises an OSError based on the error code returned by ``SSLGetError``
@@ -256,12 +284,26 @@ proc socketError*(socket: Socket, err: int = -1, async = false,
           else: raiseSSLError("Not enough data on socket.")
         of SSL_ERROR_WANT_X509_LOOKUP:
           raiseSSLError("Function for x509 lookup has been called.")
-        of SSL_ERROR_SYSCALL, SSL_ERROR_SSL:
+        of SSL_ERROR_SYSCALL:
+          var errStr = "IO error has occurred "
+          let sslErr = ErrPeekLastError()
+          if sslErr == 0 and err == 0:
+            errStr.add "because an EOF was observed that violates the protocol"
+          elif sslErr == 0 and err == -1:
+            errStr.add "in the BIO layer"
+          else:
+            let errStr = $ErrErrorString(sslErr, nil)
+            raiseSSLError(errStr & ": " & errStr)
+          let osMsg = osErrorMsg osLastError()
+          if osMsg != "":
+            errStr.add ". The OS reports: " & osMsg
+          raise newException(OSError, errStr)
+        of SSL_ERROR_SSL:
           raiseSSLError()
         else: raiseSSLError("Unknown Error")
   
   if err == -1 and not (when defined(ssl): socket.isSSL else: false):
-    let lastE = if lastError.int == -1: osLastError() else: lastError
+    var lastE = if lastError.int == -1: getSocketError(socket) else: lastError
     if async:
       when useWinVersion:
         if lastE.int32 == WSAEWOULDBLOCK:
@@ -279,7 +321,8 @@ proc listen*(socket: Socket, backlog = SOMAXCONN) {.tags: [ReadIOEffect].} =
   ## queue of pending connections.
   ##
   ## Raises an EOS error upon failure.
-  if listen(socket.fd, backlog) < 0'i32: raiseOSError(osLastError())
+  if rawsockets.listen(socket.fd, backlog) < 0'i32:
+    raiseOSError(osLastError())
 
 proc bindAddr*(socket: Socket, port = Port(0), address = "") {.
   tags: [ReadIOEffect].} =
@@ -306,7 +349,8 @@ proc bindAddr*(socket: Socket, port = Port(0), address = "") {.
     dealloc(aiList)
 
 proc acceptAddr*(server: Socket, client: var Socket, address: var string,
-                 flags = {SocketFlag.SafeDisconn}) {.tags: [ReadIOEffect].} =
+                 flags = {SocketFlag.SafeDisconn}) {.
+                 tags: [ReadIOEffect], gcsafe, locks: 0.} =
   ## Blocks until a connection is being made from a client. When a connection
   ## is made sets ``client`` to the client socket and ``address`` to the address
   ## of the connecting client.
@@ -418,15 +462,23 @@ proc accept*(server: Socket, client: var Socket,
 
 proc close*(socket: Socket) =
   ## Closes a socket.
-  socket.fd.close()
-  when defined(ssl):
-    if socket.isSSL:
-      let res = SSLShutdown(socket.sslHandle)
-      if res == 0:
-        if SSLShutdown(socket.sslHandle) != 1:
-          socketError(socket)
-      elif res != 1:
-        socketError(socket)
+  try:
+    when defined(ssl):
+      if socket.isSSL:
+        ErrClearError()
+        # As we are closing the underlying socket immediately afterwards,
+        # it is valid, under the TLS standard, to perform a unidirectional
+        # shutdown i.e not wait for the peers "close notify" alert with a second
+        # call to SSLShutdown
+        let res = SSLShutdown(socket.sslHandle)
+        SSLFree(socket.sslHandle)
+        socket.sslHandle = nil
+        if res == 0:
+          discard
+        elif res != 1:
+          socketError(socket, res)
+  finally:
+    socket.fd.close()
 
 proc toCInt*(opt: SOBool): cint =
   ## Converts a ``SOBool`` into its Socket Option cint representation.
@@ -476,6 +528,12 @@ proc connect*(socket: Socket, address: string, port = Port(0),
   
   when defined(ssl):
     if socket.isSSL:
+      # RFC3546 for SNI specifies that IP addresses are not allowed.
+      if not isIpAddress(address):
+        # Discard result in case OpenSSL version doesn't support SNI, or we're
+        # not using TLSv1+
+        discard SSL_set_tlsext_host_name(socket.sslHandle, address)
+
       let ret = SSLConnect(socket.sslHandle)
       socketError(socket, ret)
 
@@ -547,6 +605,10 @@ proc readIntoBuf(socket: Socket, flags: int32): int =
       result = recv(socket.fd, addr(socket.buffer), cint(socket.buffer.high), flags)
   else:
     result = recv(socket.fd, addr(socket.buffer), cint(socket.buffer.high), flags)
+  if result < 0:
+    # Save it in case it gets reset (the Nim codegen occassionally may call
+    # Win API functions which reset it).
+    socket.lastError = osLastError()
   if result <= 0:
     socket.bufLen = 0
     socket.currPos = 0
@@ -602,6 +664,9 @@ proc recv*(socket: Socket, data: pointer, size: int): int {.tags: [ReadIOEffect]
         result = recv(socket.fd, data, size.cint, 0'i32)
     else:
       result = recv(socket.fd, data, size.cint, 0'i32)
+    if result < 0:
+      # Save the error in case it gets reset.
+      socket.lastError = osLastError()
 
 proc waitFor(socket: Socket, waited: var float, timeout, size: int,
              funcName: string): int {.tags: [TimeEffect].} =
@@ -674,7 +739,7 @@ proc recv*(socket: Socket, data: var string, size: int, timeout = -1,
   result = recv(socket, cstring(data), size, timeout)
   if result < 0:
     data.setLen(0)
-    let lastError = osLastError()
+    let lastError = getSocketError(socket)
     if flags.isDisconnectionError(lastError): return
     socket.socketError(result, lastError = lastError)
   data.setLen(result)
@@ -722,7 +787,7 @@ proc readLine*(socket: Socket, line: var TaintedString, timeout = -1,
       line.add("\c\L")
 
   template raiseSockError(): stmt {.dirty, immediate.} =
-    let lastError = osLastError()
+    let lastError = getSocketError(socket)
     if flags.isDisconnectionError(lastError): setLen(line.string, 0); return
     socket.socketError(n, lastError = lastError)
 
@@ -865,7 +930,7 @@ proc connectAsync(socket: Socket, name: string, port = Port(0),
                   af: Domain = AF_INET) {.tags: [ReadIOEffect].} =
   ## A variant of ``connect`` for non-blocking sockets.
   ##
-  ## This procedure will immediatelly return, it will not block until a connection
+  ## This procedure will immediately return, it will not block until a connection
   ## is made. It is up to the caller to make sure the connection has been established
   ## by checking (using ``select``) whether the socket is writeable.
   ##
@@ -917,26 +982,16 @@ proc connect*(socket: Socket, address: string, port = Port(0), timeout: int,
         doAssert socket.handshake()
   socket.fd.setBlocking(true)
 
-proc isSSL*(socket: Socket): bool = return socket.isSSL
+proc isSsl*(socket: Socket): bool = 
   ## Determines whether ``socket`` is a SSL socket.
+  when defined(ssl):
+    result = socket.isSSL
+  else:
+    result = false
 
-proc getFD*(socket: Socket): SocketHandle = return socket.fd
+proc getFd*(socket: Socket): SocketHandle = return socket.fd
   ## Returns the socket's file descriptor
 
-type
-  IpAddressFamily* {.pure.} = enum ## Describes the type of an IP address
-    IPv6, ## IPv6 address
-    IPv4  ## IPv4 address
-
-  TIpAddress* = object ## stores an arbitrary IP address    
-    case family*: IpAddressFamily ## the type of the IP address (IPv4 or IPv6)
-    of IpAddressFamily.IPv6:
-      address_v6*: array[0..15, uint8] ## Contains the IP address in bytes in
-                                       ## case of IPv6
-    of IpAddressFamily.IPv4:
-      address_v4*: array[0..3, uint8] ## Contains the IP address in bytes in
-                                      ## case of IPv4
-
 proc IPv4_any*(): TIpAddress =
   ## Returns the IPv4 any address, which can be used to listen on all available
   ## network adapters
@@ -1195,7 +1250,7 @@ proc parseIPv6Address(address_str: string): TIpAddress =
     raise newException(ValueError,
       "Invalid IP Address. The address consists of too many groups")
 
-proc parseIpAddress*(address_str: string): TIpAddress =
+proc parseIpAddress(address_str: string): TIpAddress =
   ## Parses an IP address
   ## Raises EInvalidValue on error
   if address_str == nil:
@@ -1204,3 +1259,13 @@ proc parseIpAddress*(address_str: string): TIpAddress =
     return parseIPv6Address(address_str)
   else:
     return parseIPv4Address(address_str)
+
+
+proc isIpAddress(address_str: string): bool =
+  ## Checks if a string is an IP address
+  ## Returns true if it is, false otherwise
+  try:
+    discard parseIpAddress(address_str)
+  except ValueError:
+    return false
+  return true