summary refs log tree commit diff stats
path: root/lib/pure/net.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/net.nim')
-rw-r--r--lib/pure/net.nim119
1 files changed, 80 insertions, 39 deletions
diff --git a/lib/pure/net.nim b/lib/pure/net.nim
index 16390d9af..24c94b651 100644
--- a/lib/pure/net.nim
+++ b/lib/pure/net.nim
@@ -93,23 +93,24 @@ import std/private/since
 when defined(nimPreviewSlimSystem):
   import std/assertions
 
-import nativesockets
-import os, strutils, times, sets, options, std/monotimes
-import ssl_config
+import std/nativesockets
+import std/[os, strutils, times, sets, options, monotimes]
+import std/ssl_config
 export nativesockets.Port, nativesockets.`$`, nativesockets.`==`
-export Domain, SockType, Protocol
+export Domain, SockType, Protocol, IPPROTO_NONE
 
 const useWinVersion = defined(windows) or defined(nimdoc)
-const useNimNetLite = defined(nimNetLite) or defined(freertos) or defined(zephyr)
+const useNimNetLite = defined(nimNetLite) or defined(freertos) or defined(zephyr) or
+    defined(nuttx)
 const defineSsl = defined(ssl) or defined(nimdoc)
 
 when useWinVersion:
-  from winlean import WSAESHUTDOWN
+  from std/winlean import WSAESHUTDOWN
 
 when defineSsl:
-  import openssl
+  import std/openssl
   when not defined(nimDisableCertificateValidation):
-    from ssl_certs import scanSSLCertificates
+    from std/ssl_certs import scanSSLCertificates
 
 # Note: The enumerations are mapped to Window's constants.
 
@@ -208,7 +209,7 @@ when defined(nimHasStyleChecks):
 
 
 when defined(posix) and not defined(lwip):
-  from posix import TPollfd, POLLIN, POLLPRI, POLLOUT, POLLWRBAND, Tnfds
+  from std/posix import TPollfd, POLLIN, POLLPRI, POLLOUT, POLLWRBAND, Tnfds
 
   template monitorPollEvent(x: var SocketHandle, y: cint, timeout: int): int =
     var tpollfd: TPollfd
@@ -480,8 +481,8 @@ proc parseIPv6Address(addressStr: string): IpAddress =
 proc parseIpAddress*(addressStr: string): IpAddress =
   ## Parses an IP address
   ##
-  ## Raises ValueError on error. 
-  ## 
+  ## Raises ValueError on error.
+  ##
   ## For IPv4 addresses, only the strict form as
   ## defined in RFC 6943 is considered valid, see
   ## https://datatracker.ietf.org/doc/html/rfc6943#section-3.1.1.
@@ -621,10 +622,11 @@ when defineSsl:
 
   proc newContext*(protVersion = protSSLv23, verifyMode = CVerifyPeer,
                    certFile = "", keyFile = "", cipherList = CiphersIntermediate,
-                   caDir = "", caFile = ""): SslContext =
+                   caDir = "", caFile = "", ciphersuites = CiphersModern): SslContext =
     ## Creates an SSL context.
     ##
-    ## protVersion is currently unsed.
+    ## Protocol version is currently ignored by default and TLS is used.
+    ## With `-d:openssl10`, only SSLv23 and TLSv1 may be used.
     ##
     ## There are three options for verify mode:
     ## `CVerifyNone`: certificates are not verified;
@@ -651,7 +653,19 @@ when defineSsl:
     ## or using ECDSA:
     ## - `openssl ecparam -out mykey.pem -name secp256k1 -genkey`
     ## - `openssl req -new -key mykey.pem -x509 -nodes -days 365 -out mycert.pem`
-    let mtd = TLS_method()
+    var mtd: PSSL_METHOD
+    when defined(openssl10):
+      case protVersion
+      of protSSLv23:
+        mtd = SSLv23_method()
+      of protSSLv2:
+        raiseSSLError("SSLv2 is no longer secure and has been deprecated, use protSSLv23")
+      of protSSLv3:
+        raiseSSLError("SSLv3 is no longer secure and has been deprecated, use protSSLv23")
+      of protTLSv1:
+        mtd = TLSv1_method()
+    else:
+      mtd = TLS_method()
     if mtd == nil:
       raiseSSLError("Failed to create TLS context")
     var newCTX = SSL_CTX_new(mtd)
@@ -662,16 +676,16 @@ when defineSsl:
       raiseSSLError()
     when not defined(openssl10) and not defined(libressl):
       let sslVersion = getOpenSSLVersion()
-      if sslVersion >= 0x010101000 and not sslVersion == 0x020000000:
+      if sslVersion >= 0x010101000 and sslVersion != 0x020000000:
         # In OpenSSL >= 1.1.1, TLSv1.3 cipher suites can only be configured via
         # this API.
-        if newCTX.SSL_CTX_set_ciphersuites(cipherList) != 1:
+        if newCTX.SSL_CTX_set_ciphersuites(ciphersuites) != 1:
           raiseSSLError()
     # Automatically the best ECDH curve for client exchange. Without this, ECDH
     # ciphers will be ignored by the server.
     #
     # From OpenSSL >= 1.1.0, this setting is set by default and can't be
-    # overriden.
+    # overridden.
     if newCTX.SSL_CTX_set_ecdh_auto(1) != 1:
       raiseSSLError()
 
@@ -1025,7 +1039,7 @@ proc bindAddr*(socket: Socket, port = Port(0), address = "") {.
 proc acceptAddr*(server: Socket, client: var owned(Socket), address: var string,
                  flags = {SocketFlag.SafeDisconn},
                  inheritable = defined(nimInheritHandles)) {.
-                 tags: [ReadIOEffect], gcsafe, locks: 0.} =
+                 tags: [ReadIOEffect], gcsafe.} =
   ## Blocks until a connection is being made from a client. When a connection
   ## is made sets `client` to the client socket and `address` to the address
   ## of the connecting client.
@@ -1140,7 +1154,7 @@ proc accept*(server: Socket, client: var owned(Socket),
   acceptAddr(server, client, addrDummy, flags)
 
 when defined(posix) and not defined(lwip):
-  from posix import Sigset, sigwait, sigismember, sigemptyset, sigaddset,
+  from std/posix import Sigset, sigwait, sigismember, sigemptyset, sigaddset,
     sigprocmask, pthread_sigmask, SIGPIPE, SIG_BLOCK, SIG_UNBLOCK
 
 template blockSigpipe(body: untyped): untyped =
@@ -1254,9 +1268,9 @@ proc close*(socket: Socket, flags = {SocketFlag.SafeDisconn}) =
     socket.fd = osInvalidSocket
 
 when defined(posix):
-  from posix import TCP_NODELAY
+  from std/posix import TCP_NODELAY
 else:
-  from winlean import TCP_NODELAY
+  from std/winlean import TCP_NODELAY
 
 proc toCInt*(opt: SOBool): cint =
   ## Converts a `SOBool` into its Socket Option cint representation.
@@ -1307,7 +1321,7 @@ when defined(nimdoc) or (defined(posix) and not useNimNetLite):
     when not defined(nimdoc):
       var socketAddr = makeUnixAddr(path)
       if socket.fd.connect(cast[ptr SockAddr](addr socketAddr),
-          (sizeof(socketAddr.sun_family) + path.len).SockLen) != 0'i32:
+          (offsetOf(socketAddr, sun_path) + path.len + 1).SockLen) != 0'i32:
         raiseOSError(osLastError())
 
   proc bindUnix*(socket: Socket, path: string) =
@@ -1316,7 +1330,7 @@ when defined(nimdoc) or (defined(posix) and not useNimNetLite):
     when not defined(nimdoc):
       var socketAddr = makeUnixAddr(path)
       if socket.fd.bindAddr(cast[ptr SockAddr](addr socketAddr),
-          (sizeof(socketAddr.sun_family) + path.len).SockLen) != 0'i32:
+          (offsetOf(socketAddr, sun_path) + path.len + 1).SockLen) != 0'i32:
         raiseOSError(osLastError())
 
 when defineSsl:
@@ -1722,15 +1736,36 @@ proc send*(socket: Socket, data: pointer, size: int): int {.
     result = send(socket.fd, data, size, int32(MSG_NOSIGNAL))
 
 proc send*(socket: Socket, data: string,
-           flags = {SocketFlag.SafeDisconn}) {.tags: [WriteIOEffect].} =
-  ## sends data to a socket.
-  let sent = send(socket, cstring(data), data.len)
-  if sent < 0:
-    let lastError = osLastError()
-    socketError(socket, lastError = lastError, flags = flags)
+           flags = {SocketFlag.SafeDisconn}, maxRetries = 100) {.tags: [WriteIOEffect].} =
+  ## Sends data to a socket. Will try to send all the data by handling interrupts
+  ## and incomplete writes up to `maxRetries`.
+  var written = 0
+  var attempts = 0
+  while data.len - written > 0:
+    let sent = send(socket, cstring(data), data.len)
+
+    if sent < 0:
+      let lastError = osLastError()
+      let isBlockingErr =
+        when defined(nimdoc):
+          false
+        elif useWinVersion:
+          lastError.int32 == WSAEINTR or
+          lastError.int32 == WSAEWOULDBLOCK
+        else:
+          lastError.int32 == EINTR or
+          lastError.int32 == EWOULDBLOCK or
+          lastError.int32 == EAGAIN
 
-  if sent != data.len:
-    raiseOSError(osLastError(), "Could not send all data.")
+      if not isBlockingErr:
+        let lastError = osLastError()
+        socketError(socket, lastError = lastError, flags = flags)
+      else:
+        attempts.inc()
+        if attempts > maxRetries:
+          raiseOSError(osLastError(), "Could not send all data.")
+    else:
+      written.inc(sent)
 
 template `&=`*(socket: Socket; data: typed) =
   ## an alias for 'send'.
@@ -1781,7 +1816,7 @@ proc sendTo*(socket: Socket, address: string, port: Port,
   ## This proc sends `data` to the specified `address`,
   ## which may be an IP address or a hostname, if a hostname is specified
   ## this function will try each IP of that hostname.
-  ## 
+  ##
   ## Generally for use with connection-less (UDP) sockets.
   ##
   ## If an error occurs an OSError exception will be raised.
@@ -1793,9 +1828,9 @@ proc sendTo*(socket: Socket, address: IpAddress, port: Port,
              data: string, flags = 0'i32): int {.
               discardable, tags: [WriteIOEffect].} =
   ## This proc sends `data` to the specified `IpAddress` and returns
-  ## the number of bytes written. 
+  ## the number of bytes written.
   ##
-  ## Generally for use with connection-less (UDP) sockets. 
+  ## Generally for use with connection-less (UDP) sockets.
   ##
   ## If an error occurs an OSError exception will be raised.
   ##
@@ -1885,14 +1920,18 @@ proc `==`*(lhs, rhs: IpAddress): bool =
 
 proc `$`*(address: IpAddress): string =
   ## Converts an IpAddress into the textual representation
-  result = ""
   case address.family
   of IpAddressFamily.IPv4:
-    for i in 0 .. 3:
-      if i != 0:
-        result.add('.')
-      result.add($address.address_v4[i])
+    result = newStringOfCap(15)
+    result.addInt address.address_v4[0]
+    result.add '.'
+    result.addInt address.address_v4[1]
+    result.add '.'
+    result.addInt address.address_v4[2]
+    result.add '.'
+    result.addInt address.address_v4[3]
   of IpAddressFamily.IPv6:
+    result = newStringOfCap(39)
     var
       currentZeroStart = -1
       currentZeroCount = 0
@@ -2001,8 +2040,10 @@ proc dial*(address: string, port: Port,
   if success:
     result = newSocket(lastFd, domain, sockType, protocol, buffered)
   elif lastError != 0.OSErrorCode:
+    lastFd.close()
     raiseOSError(lastError)
   else:
+    lastFd.close()
     raise newException(IOError, "Couldn't resolve address: " & address)
 
 proc connect*(socket: Socket, address: string,