summary refs log tree commit diff stats
path: root/lib/pure/net.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/net.nim')
-rw-r--r--lib/pure/net.nim80
1 files changed, 52 insertions, 28 deletions
diff --git a/lib/pure/net.nim b/lib/pure/net.nim
index 1dfdbbbe6..24c94b651 100644
--- a/lib/pure/net.nim
+++ b/lib/pure/net.nim
@@ -93,23 +93,24 @@ import std/private/since
 when defined(nimPreviewSlimSystem):
   import std/assertions
 
-import nativesockets
-import os, strutils, times, sets, options, std/monotimes
-import ssl_config
+import std/nativesockets
+import std/[os, strutils, times, sets, options, monotimes]
+import std/ssl_config
 export nativesockets.Port, nativesockets.`$`, nativesockets.`==`
-export Domain, SockType, Protocol
+export Domain, SockType, Protocol, IPPROTO_NONE
 
 const useWinVersion = defined(windows) or defined(nimdoc)
-const useNimNetLite = defined(nimNetLite) or defined(freertos) or defined(zephyr)
+const useNimNetLite = defined(nimNetLite) or defined(freertos) or defined(zephyr) or
+    defined(nuttx)
 const defineSsl = defined(ssl) or defined(nimdoc)
 
 when useWinVersion:
-  from winlean import WSAESHUTDOWN
+  from std/winlean import WSAESHUTDOWN
 
 when defineSsl:
-  import openssl
+  import std/openssl
   when not defined(nimDisableCertificateValidation):
-    from ssl_certs import scanSSLCertificates
+    from std/ssl_certs import scanSSLCertificates
 
 # Note: The enumerations are mapped to Window's constants.
 
@@ -208,7 +209,7 @@ when defined(nimHasStyleChecks):
 
 
 when defined(posix) and not defined(lwip):
-  from posix import TPollfd, POLLIN, POLLPRI, POLLOUT, POLLWRBAND, Tnfds
+  from std/posix import TPollfd, POLLIN, POLLPRI, POLLOUT, POLLWRBAND, Tnfds
 
   template monitorPollEvent(x: var SocketHandle, y: cint, timeout: int): int =
     var tpollfd: TPollfd
@@ -621,7 +622,7 @@ when defineSsl:
 
   proc newContext*(protVersion = protSSLv23, verifyMode = CVerifyPeer,
                    certFile = "", keyFile = "", cipherList = CiphersIntermediate,
-                   caDir = "", caFile = ""): SslContext =
+                   caDir = "", caFile = "", ciphersuites = CiphersModern): SslContext =
     ## Creates an SSL context.
     ##
     ## Protocol version is currently ignored by default and TLS is used.
@@ -675,16 +676,16 @@ when defineSsl:
       raiseSSLError()
     when not defined(openssl10) and not defined(libressl):
       let sslVersion = getOpenSSLVersion()
-      if sslVersion >= 0x010101000 and not sslVersion == 0x020000000:
+      if sslVersion >= 0x010101000 and sslVersion != 0x020000000:
         # In OpenSSL >= 1.1.1, TLSv1.3 cipher suites can only be configured via
         # this API.
-        if newCTX.SSL_CTX_set_ciphersuites(cipherList) != 1:
+        if newCTX.SSL_CTX_set_ciphersuites(ciphersuites) != 1:
           raiseSSLError()
     # Automatically the best ECDH curve for client exchange. Without this, ECDH
     # ciphers will be ignored by the server.
     #
     # From OpenSSL >= 1.1.0, this setting is set by default and can't be
-    # overriden.
+    # overridden.
     if newCTX.SSL_CTX_set_ecdh_auto(1) != 1:
       raiseSSLError()
 
@@ -1153,7 +1154,7 @@ proc accept*(server: Socket, client: var owned(Socket),
   acceptAddr(server, client, addrDummy, flags)
 
 when defined(posix) and not defined(lwip):
-  from posix import Sigset, sigwait, sigismember, sigemptyset, sigaddset,
+  from std/posix import Sigset, sigwait, sigismember, sigemptyset, sigaddset,
     sigprocmask, pthread_sigmask, SIGPIPE, SIG_BLOCK, SIG_UNBLOCK
 
 template blockSigpipe(body: untyped): untyped =
@@ -1267,9 +1268,9 @@ proc close*(socket: Socket, flags = {SocketFlag.SafeDisconn}) =
     socket.fd = osInvalidSocket
 
 when defined(posix):
-  from posix import TCP_NODELAY
+  from std/posix import TCP_NODELAY
 else:
-  from winlean import TCP_NODELAY
+  from std/winlean import TCP_NODELAY
 
 proc toCInt*(opt: SOBool): cint =
   ## Converts a `SOBool` into its Socket Option cint representation.
@@ -1320,7 +1321,7 @@ when defined(nimdoc) or (defined(posix) and not useNimNetLite):
     when not defined(nimdoc):
       var socketAddr = makeUnixAddr(path)
       if socket.fd.connect(cast[ptr SockAddr](addr socketAddr),
-          (sizeof(socketAddr.sun_family) + path.len).SockLen) != 0'i32:
+          (offsetOf(socketAddr, sun_path) + path.len + 1).SockLen) != 0'i32:
         raiseOSError(osLastError())
 
   proc bindUnix*(socket: Socket, path: string) =
@@ -1329,7 +1330,7 @@ when defined(nimdoc) or (defined(posix) and not useNimNetLite):
     when not defined(nimdoc):
       var socketAddr = makeUnixAddr(path)
       if socket.fd.bindAddr(cast[ptr SockAddr](addr socketAddr),
-          (sizeof(socketAddr.sun_family) + path.len).SockLen) != 0'i32:
+          (offsetOf(socketAddr, sun_path) + path.len + 1).SockLen) != 0'i32:
         raiseOSError(osLastError())
 
 when defineSsl:
@@ -1735,15 +1736,36 @@ proc send*(socket: Socket, data: pointer, size: int): int {.
     result = send(socket.fd, data, size, int32(MSG_NOSIGNAL))
 
 proc send*(socket: Socket, data: string,
-           flags = {SocketFlag.SafeDisconn}) {.tags: [WriteIOEffect].} =
-  ## sends data to a socket.
-  let sent = send(socket, cstring(data), data.len)
-  if sent < 0:
-    let lastError = osLastError()
-    socketError(socket, lastError = lastError, flags = flags)
+           flags = {SocketFlag.SafeDisconn}, maxRetries = 100) {.tags: [WriteIOEffect].} =
+  ## Sends data to a socket. Will try to send all the data by handling interrupts
+  ## and incomplete writes up to `maxRetries`.
+  var written = 0
+  var attempts = 0
+  while data.len - written > 0:
+    let sent = send(socket, cstring(data), data.len)
+
+    if sent < 0:
+      let lastError = osLastError()
+      let isBlockingErr =
+        when defined(nimdoc):
+          false
+        elif useWinVersion:
+          lastError.int32 == WSAEINTR or
+          lastError.int32 == WSAEWOULDBLOCK
+        else:
+          lastError.int32 == EINTR or
+          lastError.int32 == EWOULDBLOCK or
+          lastError.int32 == EAGAIN
 
-  if sent != data.len:
-    raiseOSError(osLastError(), "Could not send all data.")
+      if not isBlockingErr:
+        let lastError = osLastError()
+        socketError(socket, lastError = lastError, flags = flags)
+      else:
+        attempts.inc()
+        if attempts > maxRetries:
+          raiseOSError(osLastError(), "Could not send all data.")
+    else:
+      written.inc(sent)
 
 template `&=`*(socket: Socket; data: typed) =
   ## an alias for 'send'.
@@ -1900,7 +1922,7 @@ proc `$`*(address: IpAddress): string =
   ## Converts an IpAddress into the textual representation
   case address.family
   of IpAddressFamily.IPv4:
-    result = newStringOfCap(16)
+    result = newStringOfCap(15)
     result.addInt address.address_v4[0]
     result.add '.'
     result.addInt address.address_v4[1]
@@ -1909,7 +1931,7 @@ proc `$`*(address: IpAddress): string =
     result.add '.'
     result.addInt address.address_v4[3]
   of IpAddressFamily.IPv6:
-    result = newStringOfCap(48)
+    result = newStringOfCap(39)
     var
       currentZeroStart = -1
       currentZeroCount = 0
@@ -2018,8 +2040,10 @@ proc dial*(address: string, port: Port,
   if success:
     result = newSocket(lastFd, domain, sockType, protocol, buffered)
   elif lastError != 0.OSErrorCode:
+    lastFd.close()
     raiseOSError(lastError)
   else:
+    lastFd.close()
     raise newException(IOError, "Couldn't resolve address: " & address)
 
 proc connect*(socket: Socket, address: string,