summary refs log tree commit diff stats
path: root/lib/pure/net.nim
diff options
context:
space:
mode:
authorDaniil Yarancev <21169548+Yardanico@users.noreply.github.com>2018-06-05 21:25:45 +0300
committerGitHub <noreply@github.com>2018-06-05 21:25:45 +0300
commit642641359821b6a63c6cf7edaaa45873b7ea59c7 (patch)
tree627af3020528cb916b3174bd94304307ca875c77 /lib/pure/net.nim
parentfb44c522e6173528efa8035ecc459c84887d0167 (diff)
parent3cbc07ac7877b03c605498760fe198e3200cc197 (diff)
downloadNim-642641359821b6a63c6cf7edaaa45873b7ea59c7.tar.gz
Merge pull request #2 from nim-lang/devel
Update
Diffstat (limited to 'lib/pure/net.nim')
-rw-r--r--lib/pure/net.nim133
1 files changed, 94 insertions, 39 deletions
diff --git a/lib/pure/net.nim b/lib/pure/net.nim
index aad6ab3e8..bf5f3f57e 100644
--- a/lib/pure/net.nim
+++ b/lib/pure/net.nim
@@ -64,8 +64,11 @@
 ##     socket.acceptAddr(client, address)
 ##     echo("Client connected from: ", address)
 ##
+## **Note:** The ``client`` variable is initialised with ``new Socket`` **not**
+## ``newSocket()``. The difference is that the latter creates a new file
+## descriptor.
 
-{.deadCodeElim: on.}
+{.deadCodeElim: on.}  # dce option deprecated
 import nativesockets, os, strutils, parseutils, times, sets, options
 export Port, `$`, `==`
 export Domain, SockType, Protocol
@@ -107,9 +110,6 @@ when defineSsl:
       serverGetPskFunc: SslServerGetPskFunc
       clientGetPskFunc: SslClientGetPskFunc
 
-  {.deprecated: [ESSL: SSLError, TSSLCVerifyMode: SSLCVerifyMode,
-    TSSLProtVersion: SSLProtVersion, PSSLContext: SSLContext,
-    TSSLAcceptResult: SSLAcceptResult].}
 else:
   type
     SslContext* = void # TODO: Workaround #4797.
@@ -156,10 +156,6 @@ type
     Peek,
     SafeDisconn ## Ensures disconnection exceptions (ECONNRESET, EPIPE etc) are not thrown.
 
-{.deprecated: [TSocketFlags: SocketFlag, ETimeout: TimeoutError,
-    TReadLineResult: ReadLineResult, TSOBool: SOBool, PSocket: Socket,
-    TSocketImpl: SocketImpl].}
-
 type
   IpAddressFamily* {.pure.} = enum ## Describes the type of an IP address
     IPv6, ## IPv6 address
@@ -173,8 +169,6 @@ type
     of IpAddressFamily.IPv4:
       address_v4*: array[0..3, uint8] ## Contains the IP address in bytes in
                                       ## case of IPv4
-{.deprecated: [TIpAddress: IpAddress].}
-
 
 proc socketError*(socket: Socket, err: int = -1, async = false,
                   lastError = (-1).OSErrorCode): void {.gcsafe.}
@@ -221,7 +215,7 @@ proc newSocket*(domain, sockType, protocol: cint, buffered = true): Socket =
   ## Creates a new socket.
   ##
   ## If an error occurs EOS will be raised.
-  let fd = newNativeSocket(domain, sockType, protocol)
+  let fd = createNativeSocket(domain, sockType, protocol)
   if fd == osInvalidSocket:
     raiseOSError(osLastError())
   result = newSocket(fd, domain.Domain, sockType.SockType, protocol.Protocol,
@@ -232,7 +226,7 @@ proc newSocket*(domain: Domain = AF_INET, sockType: SockType = SOCK_STREAM,
   ## Creates a new socket.
   ##
   ## If an error occurs EOS will be raised.
-  let fd = newNativeSocket(domain, sockType, protocol)
+  let fd = createNativeSocket(domain, sockType, protocol)
   if fd == osInvalidSocket:
     raiseOSError(osLastError())
   result = newSocket(fd, domain, sockType, protocol, buffered)
@@ -411,9 +405,46 @@ proc isIpAddress*(address_str: string): bool {.tags: [].} =
     return false
   return true
 
+proc toSockAddr*(address: IpAddress, port: Port, sa: var Sockaddr_storage, sl: var Socklen) =
+  ## Converts `IpAddress` and `Port` to `SockAddr` and `Socklen`
+  let port = htons(uint16(port))
+  case address.family
+  of IpAddressFamily.IPv4:
+    sl = sizeof(Sockaddr_in).Socklen
+    let s = cast[ptr Sockaddr_in](addr sa)
+    s.sin_family = type(s.sin_family)(AF_INET)
+    s.sin_port = port
+    copyMem(addr s.sin_addr, unsafeAddr address.address_v4[0], sizeof(s.sin_addr))
+  of IpAddressFamily.IPv6:
+    sl = sizeof(Sockaddr_in6).Socklen
+    let s = cast[ptr Sockaddr_in6](addr sa)
+    s.sin6_family = type(s.sin6_family)(AF_INET6)
+    s.sin6_port = port
+    copyMem(addr s.sin6_addr, unsafeAddr address.address_v6[0], sizeof(s.sin6_addr))
+
+proc fromSockAddrAux(sa: ptr Sockaddr_storage, sl: Socklen, address: var IpAddress, port: var Port) =
+  if sa.ss_family.int == AF_INET.int and sl == sizeof(Sockaddr_in).Socklen:
+    address = IpAddress(family: IpAddressFamily.IPv4)
+    let s = cast[ptr Sockaddr_in](sa)
+    copyMem(addr address.address_v4[0], addr s.sin_addr, sizeof(address.address_v4))
+    port = ntohs(s.sin_port).Port
+  elif sa.ss_family.int == AF_INET6.int and sl == sizeof(Sockaddr_in6).Socklen:
+    address = IpAddress(family: IpAddressFamily.IPv6)
+    let s = cast[ptr Sockaddr_in6](sa)
+    copyMem(addr address.address_v6[0], addr s.sin6_addr, sizeof(address.address_v6))
+    port = ntohs(s.sin6_port).Port
+  else:
+    raise newException(ValueError, "Neither IPv4 nor IPv6")
+
+proc fromSockAddr*(sa: Sockaddr_storage | SockAddr | Sockaddr_in | Sockaddr_in6,
+    sl: Socklen, address: var IpAddress, port: var Port) {.inline.} =
+  ## Converts `SockAddr` and `Socklen` to `IpAddress` and `Port`. Raises
+  ## `ObjectConversionError` in case of invalid `sa` and `sl` arguments.
+  fromSockAddrAux(unsafeAddr sa, sl, address, port)
+
 when defineSsl:
   CRYPTO_malloc_init()
-  SslLibraryInit()
+  doAssert SslLibraryInit() == 1
   SslLoadErrorStrings()
   ErrLoadBioStrings()
   OpenSSL_add_all_algorithms()
@@ -427,8 +458,14 @@ when defineSsl:
       raise newException(SSLError, "No error reported.")
     if err == -1:
       raiseOSError(osLastError())
-    var errStr = ErrErrorString(err, nil)
-    raise newException(SSLError, $errStr)
+    var errStr = $ErrErrorString(err, nil)
+    case err
+    of 336032814, 336032784:
+      errStr = "Please upgrade your OpenSSL library, it does not support the " &
+               "necessary protocols. OpenSSL error is: " & errStr
+    else:
+      discard
+    raise newException(SSLError, errStr)
 
   proc getExtraData*(ctx: SSLContext, index: int): RootRef =
     ## Retrieves arbitrary data stored inside SSLContext.
@@ -753,10 +790,10 @@ proc acceptAddr*(server: Socket, client: var Socket, address: var string,
   ## flag is specified then this error will not be raised and instead
   ## accept will be called again.
   assert(client != nil)
-  var sockAddress: Sockaddr_in
-  var addrLen = sizeof(sockAddress).SockLen
-  var sock = accept(server.fd, cast[ptr SockAddr](addr(sockAddress)),
-                    addr(addrLen))
+  assert client.fd.int <= 0, "Client socket needs to be initialised with " &
+                             "`new`, not `newSocket`."
+  let ret = accept(server.fd)
+  let sock = ret[0]
 
   if sock == osInvalidSocket:
     let err = osLastError()
@@ -764,7 +801,9 @@ proc acceptAddr*(server: Socket, client: var Socket, address: var string,
       acceptAddr(server, client, address, flags)
     raiseOSError(err)
   else:
+    address = ret[1]
     client.fd = sock
+    client.domain = getSockDomain(sock)
     client.isBuffered = server.isBuffered
 
     # Handle SSL.
@@ -776,9 +815,6 @@ proc acceptAddr*(server: Socket, client: var Socket, address: var string,
         let ret = SSLAccept(client.sslHandle)
         socketError(client, ret, false)
 
-    # Client socket is set above.
-    address = $inet_ntoa(sockAddress.sin_addr)
-
 when false: #defineSsl:
   proc acceptAddrSSL*(server: Socket, client: var Socket,
                       address: var string): SSLAcceptResult {.
@@ -868,6 +904,7 @@ proc close*(socket: Socket) =
         socket.sslHandle = nil
 
     socket.fd.close()
+    socket.fd = osInvalidSocket
 
 when defined(posix):
   from posix import TCP_NODELAY
@@ -924,7 +961,7 @@ when defined(posix) and not defined(nimdoc):
       raise newException(ValueError, "socket path too long")
     copyMem(addr result.sun_path, path.cstring, path.len + 1)
 
-when defined(posix):
+when defined(posix) or defined(nimdoc):
   proc connectUnix*(socket: Socket, path: string) =
     ## Connects to Unix socket on `path`.
     ## This only works on Unix-style systems: Mac OS X, BSD and Linux
@@ -1005,15 +1042,25 @@ proc select(readfd: Socket, timeout = 500): int =
   var fds = @[readfd.fd]
   result = select(fds, timeout)
 
-proc readIntoBuf(socket: Socket, flags: int32): int =
+proc isClosed(socket: Socket): bool =
+  socket.fd == osInvalidSocket
+
+proc uniRecv(socket: Socket, buffer: pointer, size, flags: cint): int =
+  ## Handles SSL and non-ssl recv in a nice package.
+  ##
+  ## In particular handles the case where socket has been closed properly
+  ## for both SSL and non-ssl.
   result = 0
+  assert(not socket.isClosed, "Cannot `recv` on a closed socket")
   when defineSsl:
-    if socket.isSSL:
-      result = SSLRead(socket.sslHandle, addr(socket.buffer), int(socket.buffer.high))
-    else:
-      result = recv(socket.fd, addr(socket.buffer), cint(socket.buffer.high), flags)
-  else:
-    result = recv(socket.fd, addr(socket.buffer), cint(socket.buffer.high), flags)
+    if socket.isSsl:
+      return SSLRead(socket.sslHandle, buffer, size)
+
+  return recv(socket.fd, buffer, size, flags)
+
+proc readIntoBuf(socket: Socket, flags: int32): int =
+  result = 0
+  result = uniRecv(socket, addr(socket.buffer), socket.buffer.high, flags)
   if result < 0:
     # Save it in case it gets reset (the Nim codegen occasionally may call
     # Win API functions which reset it).
@@ -1059,16 +1106,16 @@ proc recv*(socket: Socket, data: pointer, size: int): int {.tags: [ReadIOEffect]
   else:
     when defineSsl:
       if socket.isSSL:
-        if socket.sslHasPeekChar:
+        if socket.sslHasPeekChar: # TODO: Merge this peek char mess into uniRecv
           copyMem(data, addr(socket.sslPeekChar), 1)
           socket.sslHasPeekChar = false
           if size-1 > 0:
             var d = cast[cstring](data)
-            result = SSLRead(socket.sslHandle, addr(d[1]), size-1) + 1
+            result = uniRecv(socket, addr(d[1]), cint(size-1), 0'i32) + 1
           else:
             result = 1
         else:
-          result = SSLRead(socket.sslHandle, data, size)
+          result = uniRecv(socket, data, size.cint, 0'i32)
       else:
         result = recv(socket.fd, data, size.cint, 0'i32)
     else:
@@ -1145,7 +1192,11 @@ proc recv*(socket: Socket, data: var string, size: int, timeout = -1,
   ##
   ## **Warning**: Only the ``SafeDisconn`` flag is currently supported.
   data.setLen(size)
-  result = recv(socket, cstring(data), size, timeout)
+  result =
+    if timeout == -1:
+      recv(socket, cstring(data), size)
+    else:
+      recv(socket, cstring(data), size, timeout)
   if result < 0:
     data.setLen(0)
     let lastError = getSocketError(socket)
@@ -1182,7 +1233,7 @@ proc peekChar(socket: Socket, c: var char): int {.tags: [ReadIOEffect].} =
     when defineSsl:
       if socket.isSSL:
         if not socket.sslHasPeekChar:
-          result = SSLRead(socket.sslHandle, addr(socket.sslPeekChar), 1)
+          result = uniRecv(socket, addr(socket.sslPeekChar), 1, 0'i32)
           socket.sslHasPeekChar = true
 
         c = socket.sslPeekChar
@@ -1316,6 +1367,7 @@ proc send*(socket: Socket, data: pointer, size: int): int {.
   ##
   ## **Note**: This is a low-level version of ``send``. You likely should use
   ## the version below.
+  assert(not socket.isClosed, "Cannot `send` on a closed socket")
   when defineSsl:
     if socket.isSSL:
       return SSLWrite(socket.sslHandle, cast[cstring](data), size)
@@ -1360,8 +1412,8 @@ proc sendTo*(socket: Socket, address: string, port: Port, data: pointer,
   ## which is defined below.
   ##
   ## **Note:** This proc is not available for SSL sockets.
-  var aiList = getAddrInfo(address, port, af)
-
+  assert(not socket.isClosed, "Cannot `sendTo` on a closed socket")
+  var aiList = getAddrInfo(address, port, af, socket.sockType, socket.protocol)
   # try all possibilities:
   var success = false
   var it = aiList
@@ -1382,7 +1434,7 @@ proc sendTo*(socket: Socket, address: string, port: Port,
   ## this function will try each IP of that hostname.
   ##
   ## This is the high-level version of the above ``sendTo`` function.
-  result = socket.sendTo(address, port, cstring(data), data.len)
+  result = socket.sendTo(address, port, cstring(data), data.len, socket.domain )
 
 
 proc isSsl*(socket: Socket): bool =
@@ -1531,7 +1583,7 @@ proc dial*(address: string, port: Port,
     domain = domainOpt.unsafeGet()
     lastFd = fdPerDomain[ord(domain)]
     if lastFd == osInvalidSocket:
-      lastFd = newNativeSocket(domain, sockType, protocol)
+      lastFd = createNativeSocket(domain, sockType, protocol)
       if lastFd == osInvalidSocket:
         # we always raise if socket creation failed, because it means a
         # network system problem (e.g. not enough FDs), and not an unreachable
@@ -1640,6 +1692,9 @@ proc connect*(socket: Socket, address: string, port = Port(0),
   if selectWrite(s, timeout) != 1:
     raise newException(TimeoutError, "Call to 'connect' timed out.")
   else:
+    let res = getSockOptInt(socket.fd, SOL_SOCKET, SO_ERROR)
+    if res != 0:
+      raiseOSError(OSErrorCode(res))
     when defineSsl and not defined(nimdoc):
       if socket.isSSL:
         socket.fd.setBlocking(true)