diff options
Diffstat (limited to 'lib/pure/net.nim')
-rw-r--r-- | lib/pure/net.nim | 80 |
1 files changed, 52 insertions, 28 deletions
diff --git a/lib/pure/net.nim b/lib/pure/net.nim index 1dfdbbbe6..24c94b651 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -93,23 +93,24 @@ import std/private/since when defined(nimPreviewSlimSystem): import std/assertions -import nativesockets -import os, strutils, times, sets, options, std/monotimes -import ssl_config +import std/nativesockets +import std/[os, strutils, times, sets, options, monotimes] +import std/ssl_config export nativesockets.Port, nativesockets.`$`, nativesockets.`==` -export Domain, SockType, Protocol +export Domain, SockType, Protocol, IPPROTO_NONE const useWinVersion = defined(windows) or defined(nimdoc) -const useNimNetLite = defined(nimNetLite) or defined(freertos) or defined(zephyr) +const useNimNetLite = defined(nimNetLite) or defined(freertos) or defined(zephyr) or + defined(nuttx) const defineSsl = defined(ssl) or defined(nimdoc) when useWinVersion: - from winlean import WSAESHUTDOWN + from std/winlean import WSAESHUTDOWN when defineSsl: - import openssl + import std/openssl when not defined(nimDisableCertificateValidation): - from ssl_certs import scanSSLCertificates + from std/ssl_certs import scanSSLCertificates # Note: The enumerations are mapped to Window's constants. @@ -208,7 +209,7 @@ when defined(nimHasStyleChecks): when defined(posix) and not defined(lwip): - from posix import TPollfd, POLLIN, POLLPRI, POLLOUT, POLLWRBAND, Tnfds + from std/posix import TPollfd, POLLIN, POLLPRI, POLLOUT, POLLWRBAND, Tnfds template monitorPollEvent(x: var SocketHandle, y: cint, timeout: int): int = var tpollfd: TPollfd @@ -621,7 +622,7 @@ when defineSsl: proc newContext*(protVersion = protSSLv23, verifyMode = CVerifyPeer, certFile = "", keyFile = "", cipherList = CiphersIntermediate, - caDir = "", caFile = ""): SslContext = + caDir = "", caFile = "", ciphersuites = CiphersModern): SslContext = ## Creates an SSL context. ## ## Protocol version is currently ignored by default and TLS is used. @@ -675,16 +676,16 @@ when defineSsl: raiseSSLError() when not defined(openssl10) and not defined(libressl): let sslVersion = getOpenSSLVersion() - if sslVersion >= 0x010101000 and not sslVersion == 0x020000000: + if sslVersion >= 0x010101000 and sslVersion != 0x020000000: # In OpenSSL >= 1.1.1, TLSv1.3 cipher suites can only be configured via # this API. - if newCTX.SSL_CTX_set_ciphersuites(cipherList) != 1: + if newCTX.SSL_CTX_set_ciphersuites(ciphersuites) != 1: raiseSSLError() # Automatically the best ECDH curve for client exchange. Without this, ECDH # ciphers will be ignored by the server. # # From OpenSSL >= 1.1.0, this setting is set by default and can't be - # overriden. + # overridden. if newCTX.SSL_CTX_set_ecdh_auto(1) != 1: raiseSSLError() @@ -1153,7 +1154,7 @@ proc accept*(server: Socket, client: var owned(Socket), acceptAddr(server, client, addrDummy, flags) when defined(posix) and not defined(lwip): - from posix import Sigset, sigwait, sigismember, sigemptyset, sigaddset, + from std/posix import Sigset, sigwait, sigismember, sigemptyset, sigaddset, sigprocmask, pthread_sigmask, SIGPIPE, SIG_BLOCK, SIG_UNBLOCK template blockSigpipe(body: untyped): untyped = @@ -1267,9 +1268,9 @@ proc close*(socket: Socket, flags = {SocketFlag.SafeDisconn}) = socket.fd = osInvalidSocket when defined(posix): - from posix import TCP_NODELAY + from std/posix import TCP_NODELAY else: - from winlean import TCP_NODELAY + from std/winlean import TCP_NODELAY proc toCInt*(opt: SOBool): cint = ## Converts a `SOBool` into its Socket Option cint representation. @@ -1320,7 +1321,7 @@ when defined(nimdoc) or (defined(posix) and not useNimNetLite): when not defined(nimdoc): var socketAddr = makeUnixAddr(path) if socket.fd.connect(cast[ptr SockAddr](addr socketAddr), - (sizeof(socketAddr.sun_family) + path.len).SockLen) != 0'i32: + (offsetOf(socketAddr, sun_path) + path.len + 1).SockLen) != 0'i32: raiseOSError(osLastError()) proc bindUnix*(socket: Socket, path: string) = @@ -1329,7 +1330,7 @@ when defined(nimdoc) or (defined(posix) and not useNimNetLite): when not defined(nimdoc): var socketAddr = makeUnixAddr(path) if socket.fd.bindAddr(cast[ptr SockAddr](addr socketAddr), - (sizeof(socketAddr.sun_family) + path.len).SockLen) != 0'i32: + (offsetOf(socketAddr, sun_path) + path.len + 1).SockLen) != 0'i32: raiseOSError(osLastError()) when defineSsl: @@ -1735,15 +1736,36 @@ proc send*(socket: Socket, data: pointer, size: int): int {. result = send(socket.fd, data, size, int32(MSG_NOSIGNAL)) proc send*(socket: Socket, data: string, - flags = {SocketFlag.SafeDisconn}) {.tags: [WriteIOEffect].} = - ## sends data to a socket. - let sent = send(socket, cstring(data), data.len) - if sent < 0: - let lastError = osLastError() - socketError(socket, lastError = lastError, flags = flags) + flags = {SocketFlag.SafeDisconn}, maxRetries = 100) {.tags: [WriteIOEffect].} = + ## Sends data to a socket. Will try to send all the data by handling interrupts + ## and incomplete writes up to `maxRetries`. + var written = 0 + var attempts = 0 + while data.len - written > 0: + let sent = send(socket, cstring(data), data.len) + + if sent < 0: + let lastError = osLastError() + let isBlockingErr = + when defined(nimdoc): + false + elif useWinVersion: + lastError.int32 == WSAEINTR or + lastError.int32 == WSAEWOULDBLOCK + else: + lastError.int32 == EINTR or + lastError.int32 == EWOULDBLOCK or + lastError.int32 == EAGAIN - if sent != data.len: - raiseOSError(osLastError(), "Could not send all data.") + if not isBlockingErr: + let lastError = osLastError() + socketError(socket, lastError = lastError, flags = flags) + else: + attempts.inc() + if attempts > maxRetries: + raiseOSError(osLastError(), "Could not send all data.") + else: + written.inc(sent) template `&=`*(socket: Socket; data: typed) = ## an alias for 'send'. @@ -1900,7 +1922,7 @@ proc `$`*(address: IpAddress): string = ## Converts an IpAddress into the textual representation case address.family of IpAddressFamily.IPv4: - result = newStringOfCap(16) + result = newStringOfCap(15) result.addInt address.address_v4[0] result.add '.' result.addInt address.address_v4[1] @@ -1909,7 +1931,7 @@ proc `$`*(address: IpAddress): string = result.add '.' result.addInt address.address_v4[3] of IpAddressFamily.IPv6: - result = newStringOfCap(48) + result = newStringOfCap(39) var currentZeroStart = -1 currentZeroCount = 0 @@ -2018,8 +2040,10 @@ proc dial*(address: string, port: Port, if success: result = newSocket(lastFd, domain, sockType, protocol, buffered) elif lastError != 0.OSErrorCode: + lastFd.close() raiseOSError(lastError) else: + lastFd.close() raise newException(IOError, "Couldn't resolve address: " & address) proc connect*(socket: Socket, address: string, |