diff options
Diffstat (limited to 'lib/pure/net.nim')
-rw-r--r-- | lib/pure/net.nim | 466 |
1 files changed, 302 insertions, 164 deletions
diff --git a/lib/pure/net.nim b/lib/pure/net.nim index b37782271..24c94b651 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -44,64 +44,73 @@ ## After you create a socket with the `newSocket` procedure, you can easily ## connect it to a server running at a known hostname (or IP address) and port. ## To do so over TCP, use the example below. -## -## .. code-block:: Nim -## var socket = newSocket() -## socket.connect("google.com", Port(80)) -## -## For SSL, use the following example (and make sure to compile with `-d:ssl`): -## -## .. code-block:: Nim -## var socket = newSocket() -## var ctx = newContext() -## wrapSocket(ctx, socket) -## socket.connect("google.com", Port(443)) -## + +runnableExamples("-r:off"): + let socket = newSocket() + socket.connect("google.com", Port(80)) + +## For SSL, use the following example: + +runnableExamples("-r:off -d:ssl"): + let socket = newSocket() + let ctx = newContext() + wrapSocket(ctx, socket) + socket.connect("google.com", Port(443)) + ## UDP is a connectionless protocol, so UDP sockets don't have to explicitly ## call the `connect <net.html#connect%2CSocket%2Cstring>`_ procedure. They can ## simply start sending data immediately. -## -## .. code-block:: Nim -## var socket = newSocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP) -## socket.sendTo("192.168.0.1", Port(27960), "status\n") -## + +runnableExamples("-r:off"): + let socket = newSocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP) + socket.sendTo("192.168.0.1", Port(27960), "status\n") + +runnableExamples("-r:off"): + let socket = newSocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP) + let ip = parseIpAddress("192.168.0.1") + doAssert socket.sendTo(ip, Port(27960), "status\c\l") == 8 + ## Creating a server ## ----------------- ## ## After you create a socket with the `newSocket` procedure, you can create a ## TCP server by calling the `bindAddr` and `listen` procedures. -## -## .. code-block:: Nim -## var socket = newSocket() -## socket.bindAddr(Port(1234)) -## socket.listen() -## -## You can then begin accepting connections using the `accept` procedure. -## -## .. code-block:: Nim -## var client: Socket -## var address = "" -## while true: -## socket.acceptAddr(client, address) -## echo("Client connected from: ", address) + +runnableExamples("-r:off"): + let socket = newSocket() + socket.bindAddr(Port(1234)) + socket.listen() + + # You can then begin accepting connections using the `accept` procedure. + var client: Socket + var address = "" + while true: + socket.acceptAddr(client, address) + echo "Client connected from: ", address import std/private/since -import nativesockets, os, strutils, times, sets, options, std/monotimes -import ssl_config +when defined(nimPreviewSlimSystem): + import std/assertions + +import std/nativesockets +import std/[os, strutils, times, sets, options, monotimes] +import std/ssl_config export nativesockets.Port, nativesockets.`$`, nativesockets.`==` -export Domain, SockType, Protocol +export Domain, SockType, Protocol, IPPROTO_NONE const useWinVersion = defined(windows) or defined(nimdoc) +const useNimNetLite = defined(nimNetLite) or defined(freertos) or defined(zephyr) or + defined(nuttx) const defineSsl = defined(ssl) or defined(nimdoc) when useWinVersion: - from winlean import WSAESHUTDOWN + from std/winlean import WSAESHUTDOWN when defineSsl: - import openssl + import std/openssl when not defined(nimDisableCertificateValidation): - from ssl_certs import scanSSLCertificates + from std/ssl_certs import scanSSLCertificates # Note: The enumerations are mapped to Window's constants. @@ -198,6 +207,30 @@ type when defined(nimHasStyleChecks): {.pop.} + +when defined(posix) and not defined(lwip): + from std/posix import TPollfd, POLLIN, POLLPRI, POLLOUT, POLLWRBAND, Tnfds + + template monitorPollEvent(x: var SocketHandle, y: cint, timeout: int): int = + var tpollfd: TPollfd + tpollfd.fd = cast[cint](x) + tpollfd.events = y + posix.poll(addr(tpollfd), Tnfds(1), timeout) + +proc timeoutRead(fd: var SocketHandle, timeout = 500): int = + when defined(windows) or defined(lwip): + var fds = @[fd] + selectRead(fds, timeout) + else: + monitorPollEvent(fd, POLLIN or POLLPRI, timeout) + +proc timeoutWrite(fd: var SocketHandle, timeout = 500): int = + when defined(windows) or defined(lwip): + var fds = @[fd] + selectWrite(fds, timeout) + else: + monitorPollEvent(fd, POLLOUT or POLLWRBAND, timeout) + proc socketError*(socket: Socket, err: int = -1, async = false, lastError = (-1).OSErrorCode, flags: set[SocketFlag] = {}) {.gcsafe.} @@ -283,14 +316,20 @@ proc parseIPv4Address(addressStr: string): IpAddress = byteCount = 0 currentByte: uint16 = 0 separatorValid = false + leadingZero = false result = IpAddress(family: IpAddressFamily.IPv4) for i in 0 .. high(addressStr): if addressStr[i] in strutils.Digits: # Character is a number + if leadingZero: + raise newException(ValueError, + "Invalid IP address. Octal numbers are not allowed") currentByte = currentByte * 10 + cast[uint16](ord(addressStr[i]) - ord('0')) - if currentByte > 255'u16: + if currentByte == 0'u16: + leadingZero = true + elif currentByte > 255'u16: raise newException(ValueError, "Invalid IP Address. Value is out of range") separatorValid = true @@ -302,6 +341,7 @@ proc parseIPv4Address(addressStr: string): IpAddress = currentByte = 0 byteCount.inc separatorValid = false + leadingZero = false else: raise newException(ValueError, "Invalid IP Address. Address contains an invalid character") @@ -390,10 +430,16 @@ proc parseIPv6Address(addressStr: string): IpAddress = result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF) groupCount.inc() else: # Must parse IPv4 address + var leadingZero = false for i, c in addressStr[v4StartPos..high(addressStr)]: if c in strutils.Digits: # Character is a number + if leadingZero: + raise newException(ValueError, + "Invalid IP address. Octal numbers not allowed") currentShort = currentShort * 10 + cast[uint32](ord(c) - ord('0')) - if currentShort > 255'u32: + if currentShort == 0'u32: + leadingZero = true + elif currentShort > 255'u32: raise newException(ValueError, "Invalid IP Address. Value is out of range") separatorValid = true @@ -404,6 +450,7 @@ proc parseIPv6Address(addressStr: string): IpAddress = currentShort = 0 byteCount.inc() separatorValid = false + leadingZero = false else: # Invalid character raise newException(ValueError, "Invalid IP Address. Address contains an invalid character") @@ -433,7 +480,12 @@ proc parseIPv6Address(addressStr: string): IpAddress = proc parseIpAddress*(addressStr: string): IpAddress = ## Parses an IP address - ## Raises ValueError on error + ## + ## Raises ValueError on error. + ## + ## For IPv4 addresses, only the strict form as + ## defined in RFC 6943 is considered valid, see + ## https://datatracker.ietf.org/doc/html/rfc6943#section-3.1.1. if addressStr.len == 0: raise newException(ValueError, "IP Address string is empty") if addressStr.contains(':'): @@ -495,13 +547,20 @@ proc fromSockAddr*(sa: Sockaddr_storage | SockAddr | Sockaddr_in | Sockaddr_in6, fromSockAddrAux(cast[ptr Sockaddr_storage](unsafeAddr sa), sl, address, port) when defineSsl: - CRYPTO_malloc_init() - doAssert SslLibraryInit() == 1 - SSL_load_error_strings() - ERR_load_BIO_strings() - OpenSSL_add_all_algorithms() - - proc raiseSSLError*(s = "") = + # OpenSSL >= 1.1.0 does not need explicit init. + when not useOpenssl3: + CRYPTO_malloc_init() + doAssert SslLibraryInit() == 1 + SSL_load_error_strings() + ERR_load_BIO_strings() + OpenSSL_add_all_algorithms() + + proc sslHandle*(self: Socket): SslPtr = + ## Retrieve the ssl pointer of `socket`. + ## Useful for interfacing with `openssl`. + self.sslHandle + + proc raiseSSLError*(s = "") {.raises: [SslError].}= ## Raises a new SSL error. if s != "": raise newException(SslError, s) @@ -563,12 +622,11 @@ when defineSsl: proc newContext*(protVersion = protSSLv23, verifyMode = CVerifyPeer, certFile = "", keyFile = "", cipherList = CiphersIntermediate, - caDir = "", caFile = ""): SslContext = + caDir = "", caFile = "", ciphersuites = CiphersModern): SslContext = ## Creates an SSL context. ## - ## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1 - ## are available with the addition of `protSSLv23` which allows for - ## compatibility with all of them. + ## Protocol version is currently ignored by default and TLS is used. + ## With `-d:openssl10`, only SSLv23 and TLSv1 may be used. ## ## There are three options for verify mode: ## `CVerifyNone`: certificates are not verified; @@ -582,10 +640,10 @@ when defineSsl: ## ## CA certificates will be loaded, in the following order, from: ## - ## - caFile, caDir, parameters, if set - ## - if `verifyMode` is set to `CVerifyPeerUseEnvVars`, - ## the SSL_CERT_FILE and SSL_CERT_DIR environment variables are used - ## - a set of files and directories from the `ssl_certs <ssl_certs.html>`_ file. + ## - caFile, caDir, parameters, if set + ## - if `verifyMode` is set to `CVerifyPeerUseEnvVars`, + ## the SSL_CERT_FILE and SSL_CERT_DIR environment variables are used + ## - a set of files and directories from the `ssl_certs <ssl_certs.html>`_ file. ## ## The last two parameters specify the certificate file path and the key file ## path, a server socket will most likely not work without these. @@ -595,31 +653,39 @@ when defineSsl: ## or using ECDSA: ## - `openssl ecparam -out mykey.pem -name secp256k1 -genkey` ## - `openssl req -new -key mykey.pem -x509 -nodes -days 365 -out mycert.pem` - var newCTX: SslCtx - case protVersion - of protSSLv23: - newCTX = SSL_CTX_new(SSLv23_method()) # SSlv2,3 and TLS1 support. - of protSSLv2: - raiseSSLError("SSLv2 is no longer secure and has been deprecated, use protSSLv23") - of protSSLv3: - raiseSSLError("SSLv3 is no longer secure and has been deprecated, use protSSLv23") - of protTLSv1: - newCTX = SSL_CTX_new(TLSv1_method()) + var mtd: PSSL_METHOD + when defined(openssl10): + case protVersion + of protSSLv23: + mtd = SSLv23_method() + of protSSLv2: + raiseSSLError("SSLv2 is no longer secure and has been deprecated, use protSSLv23") + of protSSLv3: + raiseSSLError("SSLv3 is no longer secure and has been deprecated, use protSSLv23") + of protTLSv1: + mtd = TLSv1_method() + else: + mtd = TLS_method() + if mtd == nil: + raiseSSLError("Failed to create TLS context") + var newCTX = SSL_CTX_new(mtd) + if newCTX == nil: + raiseSSLError("Failed to create TLS context") if newCTX.SSL_CTX_set_cipher_list(cipherList) != 1: raiseSSLError() when not defined(openssl10) and not defined(libressl): let sslVersion = getOpenSSLVersion() - if sslVersion >= 0x010101000 and not sslVersion == 0x020000000: + if sslVersion >= 0x010101000 and sslVersion != 0x020000000: # In OpenSSL >= 1.1.1, TLSv1.3 cipher suites can only be configured via # this API. - if newCTX.SSL_CTX_set_ciphersuites(cipherList) != 1: + if newCTX.SSL_CTX_set_ciphersuites(ciphersuites) != 1: raiseSSLError() # Automatically the best ECDH curve for client exchange. Without this, ECDH # ciphers will be ignored by the server. # # From OpenSSL >= 1.1.0, this setting is set by default and can't be - # overriden. + # overridden. if newCTX.SSL_CTX_set_ecdh_auto(1) != 1: raiseSSLError() @@ -644,15 +710,20 @@ when defineSsl: if verifyMode != CVerifyNone: # Use the caDir and caFile parameters if set if caDir != "" or caFile != "": - if newCTX.SSL_CTX_load_verify_locations(caFile, caDir) != VerifySuccess: + if newCTX.SSL_CTX_load_verify_locations(if caFile == "": nil else: caFile.cstring, if caDir == "": nil else: caDir.cstring) != VerifySuccess: raise newException(IOError, "Failed to load SSL/TLS CA certificate(s).") else: # Scan for certs in known locations. For CVerifyPeerUseEnvVars also scan # the SSL_CERT_FILE and SSL_CERT_DIR env vars var found = false - for fn in scanSSLCertificates(): - if newCTX.SSL_CTX_load_verify_locations(fn, nil) == VerifySuccess: + let useEnvVars = (if verifyMode == CVerifyPeerUseEnvVars: true else: false) + for fn in scanSSLCertificates(useEnvVars = useEnvVars): + if fn.extractFilename == "": + if newCTX.SSL_CTX_load_verify_locations(nil, cstring(fn.normalizePathEnd(false))) == VerifySuccess: + found = true + break + elif newCTX.SSL_CTX_load_verify_locations(cstring(fn), nil) == VerifySuccess: found = true break if not found: @@ -685,17 +756,16 @@ when defineSsl: return ctx.getExtraInternal().clientGetPskFunc proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; - max_identity_len: cuint; psk: ptr cuchar; + max_identity_len: cuint; psk: ptr uint8; max_psk_len: cuint): cuint {.cdecl.} = let ctx = SslContext(context: ssl.SSL_get_SSL_CTX) let hintString = if hint == nil: "" else: $hint let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString) - if psk.len.cuint > max_psk_len: + if pskString.len.cuint > max_psk_len: return 0 if identityString.len.cuint >= max_identity_len: return 0 - - copyMem(identity, identityString.cstring, pskString.len + 1) # with the last zero byte + copyMem(identity, identityString.cstring, identityString.len + 1) # with the last zero byte copyMem(psk, pskString.cstring, pskString.len) return pskString.len.cuint @@ -712,11 +782,11 @@ when defineSsl: proc serverGetPskFunc*(ctx: SslContext): SslServerGetPskFunc = return ctx.getExtraInternal().serverGetPskFunc - proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; + proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr uint8; max_psk_len: cint): cuint {.cdecl.} = let ctx = SslContext(context: ssl.SSL_get_SSL_CTX) let pskString = (ctx.serverGetPskFunc)($identity) - if psk.len.cint > max_psk_len: + if pskString.len.cint > max_psk_len: return 0 copyMem(psk, pskString.cstring, pskString.len) @@ -759,23 +829,28 @@ when defineSsl: if SSL_set_fd(socket.sslHandle, socket.fd) != 1: raiseSSLError() - proc checkCertName(socket: Socket, hostname: string) = + proc checkCertName(socket: Socket, hostname: string) {.raises: [SslError], tags:[RootEffect].} = ## Check if the certificate Subject Alternative Name (SAN) or Subject CommonName (CN) matches hostname. ## Wildcards match only in the left-most label. ## When name starts with a dot it will be matched by a certificate valid for any subdomain when not defined(nimDisableCertificateValidation) and not defined(windows): assert socket.isSsl - let certificate = socket.sslHandle.SSL_get_peer_certificate() - if certificate.isNil: - raiseSSLError("No SSL certificate found.") - - const X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT = 0x1.cuint - const size = 1024 - var peername: string = newString(size) - let match = certificate.X509_check_host(hostname.cstring, hostname.len.cint, - X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT, peername) - if match != 1: - raiseSSLError("SSL Certificate check failed.") + try: + let certificate = socket.sslHandle.SSL_get_peer_certificate() + if certificate.isNil: + raiseSSLError("No SSL certificate found.") + + const X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT = 0x1.cuint + # https://www.openssl.org/docs/man1.1.1/man3/X509_check_host.html + let match = certificate.X509_check_host(hostname.cstring, hostname.len.cint, + X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT, nil) + # https://www.openssl.org/docs/man1.1.1/man3/SSL_get_peer_certificate.html + X509_free(certificate) + if match != 1: + raiseSSLError("SSL Certificate check failed.") + + except LibraryError: + raiseSSLError("SSL import failed") proc wrapConnectedSocket*(ctx: SslContext, socket: Socket, handshake: SslHandshakeType, @@ -802,6 +877,7 @@ when defineSsl: let ret = SSL_connect(socket.sslHandle) socketError(socket, ret) when not defined(nimDisableCertificateValidation) and not defined(windows): + # FIXME: this should be skipped on CVerifyNone if hostname.len > 0 and not isIpAddress(hostname): socket.checkCertName(hostname) of handshakeAsServer: @@ -954,14 +1030,16 @@ proc bindAddr*(socket: Socket, port = Port(0), address = "") {. var aiList = getAddrInfo(realaddr, port, socket.domain) if bindAddr(socket.fd, aiList.ai_addr, aiList.ai_addrlen.SockLen) < 0'i32: - freeaddrinfo(aiList) - raiseOSError(osLastError()) - freeaddrinfo(aiList) + freeAddrInfo(aiList) + var address2: string + address2.addQuoted address + raiseOSError(osLastError(), "address: $# port: $#" % [address2, $port]) + freeAddrInfo(aiList) proc acceptAddr*(server: Socket, client: var owned(Socket), address: var string, flags = {SocketFlag.SafeDisconn}, inheritable = defined(nimInheritHandles)) {. - tags: [ReadIOEffect], gcsafe, locks: 0.} = + tags: [ReadIOEffect], gcsafe.} = ## Blocks until a connection is being made from a client. When a connection ## is made sets `client` to the client socket and `address` to the address ## of the connecting client. @@ -1076,7 +1154,7 @@ proc accept*(server: Socket, client: var owned(Socket), acceptAddr(server, client, addrDummy, flags) when defined(posix) and not defined(lwip): - from posix import Sigset, sigwait, sigismember, sigemptyset, sigaddset, + from std/posix import Sigset, sigwait, sigismember, sigemptyset, sigaddset, sigprocmask, pthread_sigmask, SIGPIPE, SIG_BLOCK, SIG_UNBLOCK template blockSigpipe(body: untyped): untyped = @@ -1190,9 +1268,9 @@ proc close*(socket: Socket, flags = {SocketFlag.SafeDisconn}) = socket.fd = osInvalidSocket when defined(posix): - from posix import TCP_NODELAY + from std/posix import TCP_NODELAY else: - from winlean import TCP_NODELAY + from std/winlean import TCP_NODELAY proc toCInt*(opt: SOBool): cint = ## Converts a `SOBool` into its Socket Option cint representation. @@ -1219,32 +1297,31 @@ proc getLocalAddr*(socket: Socket): (string, Port) = ## This is high-level interface for `getsockname`:idx:. getLocalAddr(socket.fd, socket.domain) -proc getPeerAddr*(socket: Socket): (string, Port) = - ## Get the socket's peer address and port number. - ## - ## This is high-level interface for `getpeername`:idx:. - getPeerAddr(socket.fd, socket.domain) +when not useNimNetLite: + proc getPeerAddr*(socket: Socket): (string, Port) = + ## Get the socket's peer address and port number. + ## + ## This is high-level interface for `getpeername`:idx:. + getPeerAddr(socket.fd, socket.domain) proc setSockOpt*(socket: Socket, opt: SOBool, value: bool, level = SOL_SOCKET) {.tags: [WriteIOEffect].} = ## Sets option `opt` to a boolean value specified by `value`. - ## - ## .. code-block:: Nim - ## var socket = newSocket() - ## socket.setSockOpt(OptReusePort, true) - ## socket.setSockOpt(OptNoDelay, true, level=IPPROTO_TCP.toInt) - ## + runnableExamples("-r:off"): + let socket = newSocket() + socket.setSockOpt(OptReusePort, true) + socket.setSockOpt(OptNoDelay, true, level = IPPROTO_TCP.cint) var valuei = cint(if value: 1 else: 0) setSockOptInt(socket.fd, cint(level), toCInt(opt), valuei) -when defined(posix) or defined(nimdoc): +when defined(nimdoc) or (defined(posix) and not useNimNetLite): proc connectUnix*(socket: Socket, path: string) = ## Connects to Unix socket on `path`. ## This only works on Unix-style systems: Mac OS X, BSD and Linux when not defined(nimdoc): var socketAddr = makeUnixAddr(path) if socket.fd.connect(cast[ptr SockAddr](addr socketAddr), - (sizeof(socketAddr.sun_family) + path.len).SockLen) != 0'i32: + (offsetOf(socketAddr, sun_path) + path.len + 1).SockLen) != 0'i32: raiseOSError(osLastError()) proc bindUnix*(socket: Socket, path: string) = @@ -1253,10 +1330,10 @@ when defined(posix) or defined(nimdoc): when not defined(nimdoc): var socketAddr = makeUnixAddr(path) if socket.fd.bindAddr(cast[ptr SockAddr](addr socketAddr), - (sizeof(socketAddr.sun_family) + path.len).SockLen) != 0'i32: + (offsetOf(socketAddr, sun_path) + path.len + 1).SockLen) != 0'i32: raiseOSError(osLastError()) -when defined(ssl): +when defineSsl: proc gotHandshake*(socket: Socket): bool = ## Determines whether a handshake has occurred between a client (`socket`) ## and the server that `socket` is connected to. @@ -1277,14 +1354,6 @@ proc hasDataBuffered*(s: Socket): bool = if s.isSsl and not result: result = s.sslHasPeekChar -proc select(readfd: Socket, timeout = 500): int = - ## Used for socket operation timeouts. - if readfd.hasDataBuffered: - return 1 - - var fds = @[readfd.fd] - result = selectRead(fds, timeout) - proc isClosed(socket: Socket): bool = socket.fd == osInvalidSocket @@ -1398,7 +1467,9 @@ proc waitFor(socket: Socket, waited: var Duration, timeout, size: int, return min(sslPending, size) var startTime = getMonoTime() - let selRet = select(socket, (timeout - waited.inMilliseconds).int) + let selRet = if socket.hasDataBuffered: 1 + else: + timeoutRead(socket.fd, (timeout - waited.inMilliseconds).int) if selRet < 0: raiseOSError(osLastError()) if selRet != 1: raise newException(TimeoutError, "Call to '" & funcName & "' timed out.") @@ -1426,7 +1497,7 @@ proc recv*(socket: Socket, data: var string, size: int, timeout = -1, flags = {SocketFlag.SafeDisconn}): int = ## Higher-level version of `recv`. ## - ## Reads **up to** `size` bytes from `socket` into `buf`. + ## Reads **up to** `size` bytes from `socket` into `data`. ## ## For buffered sockets this function will attempt to read all the requested ## data. It will read this data in `BufferSize` chunks. @@ -1443,8 +1514,6 @@ proc recv*(socket: Socket, data: var string, size: int, timeout = -1, ## A timeout may be specified in milliseconds, if enough data is not received ## within the time specified a TimeoutError exception will be raised. ## - ## **Note**: `data` must be initialised. - ## ## .. warning:: Only the `SafeDisconn` flag is currently supported. data.setLen(size) result = @@ -1463,7 +1532,7 @@ proc recv*(socket: Socket, size: int, timeout = -1, flags = {SocketFlag.SafeDisconn}): string {.inline.} = ## Higher-level version of `recv` which returns a string. ## - ## Reads **up to** `size` bytes from `socket` into `buf`. + ## Reads **up to** `size` bytes from `socket` into the result. ## ## For buffered sockets this function will attempt to read all the requested ## data. It will read this data in `BufferSize` chunks. @@ -1534,6 +1603,7 @@ proc readLine*(socket: Socket, line: var string, timeout = -1, if flags.isDisconnectionError(lastError): setLen(line, 0) socket.socketError(n, lastError = lastError, flags = flags) + return var waited: Duration @@ -1583,11 +1653,12 @@ proc recvLine*(socket: Socket, timeout = -1, result = "" readLine(socket, result, timeout, flags, maxLength) -proc recvFrom*(socket: Socket, data: var string, length: int, - address: var string, port: var Port, flags = 0'i32): int {. +proc recvFrom*[T: string | IpAddress](socket: Socket, data: var string, length: int, + address: var T, port: var Port, flags = 0'i32): int {. tags: [ReadIOEffect].} = ## Receives data from `socket`. This function should normally be used with - ## connection-less sockets (UDP sockets). + ## connection-less sockets (UDP sockets). The source address of the data + ## packet is stored in the `address` argument as either a string or an IpAddress. ## ## If an error occurs an OSError exception will be raised. Otherwise the return ## value will be the length of data received. @@ -1596,31 +1667,37 @@ proc recvFrom*(socket: Socket, data: var string, length: int, ## so when `socket` is buffered the non-buffered implementation will be ## used. Therefore if `socket` contains something in its buffer this ## function will make no effort to return it. - template adaptRecvFromToDomain(domain: Domain) = - var addrLen = sizeof(sockAddress).SockLen + template adaptRecvFromToDomain(sockAddress: untyped, domain: Domain) = + var addrLen = SockLen(sizeof(sockAddress)) result = recvfrom(socket.fd, cstring(data), length.cint, flags.cint, cast[ptr SockAddr](addr(sockAddress)), addr(addrLen)) if result != -1: data.setLen(result) - address = getAddrString(cast[ptr SockAddr](addr(sockAddress))) - when domain == AF_INET6: - port = ntohs(sockAddress.sin6_port).Port + + when typeof(address) is string: + address = getAddrString(cast[ptr SockAddr](addr(sockAddress))) + when domain == AF_INET6: + port = ntohs(sockAddress.sin6_port).Port + else: + port = ntohs(sockAddress.sin_port).Port else: - port = ntohs(sockAddress.sin_port).Port + data.setLen(result) + sockAddress.fromSockAddr(addrLen, address, port) else: raiseOSError(osLastError()) assert(socket.protocol != IPPROTO_TCP, "Cannot `recvFrom` on a TCP socket") # TODO: Buffered sockets data.setLen(length) + case socket.domain of AF_INET6: var sockAddress: Sockaddr_in6 - adaptRecvFromToDomain(AF_INET6) + adaptRecvFromToDomain(sockAddress, AF_INET6) of AF_INET: var sockAddress: Sockaddr_in - adaptRecvFromToDomain(AF_INET) + adaptRecvFromToDomain(sockAddress, AF_INET) else: raise newException(ValueError, "Unknown socket address family") @@ -1659,15 +1736,36 @@ proc send*(socket: Socket, data: pointer, size: int): int {. result = send(socket.fd, data, size, int32(MSG_NOSIGNAL)) proc send*(socket: Socket, data: string, - flags = {SocketFlag.SafeDisconn}) {.tags: [WriteIOEffect].} = - ## sends data to a socket. - let sent = send(socket, cstring(data), data.len) - if sent < 0: - let lastError = osLastError() - socketError(socket, lastError = lastError, flags = flags) + flags = {SocketFlag.SafeDisconn}, maxRetries = 100) {.tags: [WriteIOEffect].} = + ## Sends data to a socket. Will try to send all the data by handling interrupts + ## and incomplete writes up to `maxRetries`. + var written = 0 + var attempts = 0 + while data.len - written > 0: + let sent = send(socket, cstring(data), data.len) + + if sent < 0: + let lastError = osLastError() + let isBlockingErr = + when defined(nimdoc): + false + elif useWinVersion: + lastError.int32 == WSAEINTR or + lastError.int32 == WSAEWOULDBLOCK + else: + lastError.int32 == EINTR or + lastError.int32 == EWOULDBLOCK or + lastError.int32 == EAGAIN - if sent != data.len: - raiseOSError(osLastError(), "Could not send all data.") + if not isBlockingErr: + let lastError = osLastError() + socketError(socket, lastError = lastError, flags = flags) + else: + attempts.inc() + if attempts > maxRetries: + raiseOSError(osLastError(), "Could not send all data.") + else: + written.inc(sent) template `&=`*(socket: Socket; data: typed) = ## an alias for 'send'. @@ -1683,7 +1781,8 @@ proc sendTo*(socket: Socket, address: string, port: Port, data: pointer, tags: [WriteIOEffect].} = ## This proc sends `data` to the specified `address`, ## which may be an IP address or a hostname, if a hostname is specified - ## this function will try each IP of that hostname. + ## this function will try each IP of that hostname. This function + ## should normally be used with connection-less sockets (UDP sockets). ## ## If an error occurs an OSError exception will be raised. ## @@ -1707,7 +1806,7 @@ proc sendTo*(socket: Socket, address: string, port: Port, data: pointer, it = it.ai_next let osError = osLastError() - freeaddrinfo(aiList) + freeAddrInfo(aiList) if not success: raiseOSError(osError) @@ -1718,11 +1817,37 @@ proc sendTo*(socket: Socket, address: string, port: Port, ## which may be an IP address or a hostname, if a hostname is specified ## this function will try each IP of that hostname. ## + ## Generally for use with connection-less (UDP) sockets. + ## ## If an error occurs an OSError exception will be raised. ## ## This is the high-level version of the above `sendTo` function. socket.sendTo(address, port, cstring(data), data.len, socket.domain) +proc sendTo*(socket: Socket, address: IpAddress, port: Port, + data: string, flags = 0'i32): int {. + discardable, tags: [WriteIOEffect].} = + ## This proc sends `data` to the specified `IpAddress` and returns + ## the number of bytes written. + ## + ## Generally for use with connection-less (UDP) sockets. + ## + ## If an error occurs an OSError exception will be raised. + ## + ## This is the high-level version of the above `sendTo` function. + assert(socket.protocol != IPPROTO_TCP, "Cannot `sendTo` on a TCP socket") + assert(not socket.isClosed, "Cannot `sendTo` on a closed socket") + + var sa: Sockaddr_storage + var sl: SockLen + toSockAddr(address, port, sa, sl) + result = sendto(socket.fd, cstring(data), data.len().cint, flags.cint, + cast[ptr SockAddr](addr sa), sl) + + if result == -1'i32: + let osError = osLastError() + raiseOSError(osError) + proc isSsl*(socket: Socket): bool = ## Determines whether `socket` is a SSL socket. @@ -1734,6 +1859,16 @@ proc isSsl*(socket: Socket): bool = proc getFd*(socket: Socket): SocketHandle = return socket.fd ## Returns the socket's file descriptor +when defined(zephyr) or defined(nimNetSocketExtras): # Remove in future + proc getDomain*(socket: Socket): Domain = return socket.domain + ## Returns the socket's domain + + proc getType*(socket: Socket): SockType = return socket.sockType + ## Returns the socket's type + + proc getProtocol*(socket: Socket): Protocol = return socket.protocol + ## Returns the socket's protocol + when defined(nimHasStyleChecks): {.push styleChecks: off.} @@ -1785,14 +1920,18 @@ proc `==`*(lhs, rhs: IpAddress): bool = proc `$`*(address: IpAddress): string = ## Converts an IpAddress into the textual representation - result = "" case address.family of IpAddressFamily.IPv4: - for i in 0 .. 3: - if i != 0: - result.add('.') - result.add($address.address_v4[i]) + result = newStringOfCap(15) + result.addInt address.address_v4[0] + result.add '.' + result.addInt address.address_v4[1] + result.add '.' + result.addInt address.address_v4[2] + result.add '.' + result.addInt address.address_v4[3] of IpAddressFamily.IPv6: + result = newStringOfCap(39) var currentZeroStart = -1 currentZeroCount = 0 @@ -1886,7 +2025,7 @@ proc dial*(address: string, port: Port, # network system problem (e.g. not enough FDs), and not an unreachable # address. let err = osLastError() - freeaddrinfo(aiList) + freeAddrInfo(aiList) closeUnusedFds() raiseOSError(err) fdPerDomain[ord(domain)] = lastFd @@ -1895,18 +2034,20 @@ proc dial*(address: string, port: Port, break lastError = osLastError() it = it.ai_next - freeaddrinfo(aiList) + freeAddrInfo(aiList) closeUnusedFds(ord(domain)) if success: - result = newSocket(lastFd, domain, sockType, protocol) + result = newSocket(lastFd, domain, sockType, protocol, buffered) elif lastError != 0.OSErrorCode: + lastFd.close() raiseOSError(lastError) else: + lastFd.close() raise newException(IOError, "Couldn't resolve address: " & address) proc connect*(socket: Socket, address: string, - port = Port(0)) {.tags: [ReadIOEffect].} = + port = Port(0)) {.tags: [ReadIOEffect, RootEffect].} = ## Connects socket to `address`:`port`. `Address` can be an IP address or a ## host name. If `address` is a host name, this function will try each IP ## of that host name. `htons` is already performed on `port` so you must @@ -1925,7 +2066,7 @@ proc connect*(socket: Socket, address: string, else: lastError = osLastError() it = it.ai_next - freeaddrinfo(aiList) + freeAddrInfo(aiList) if not success: raiseOSError(lastError) when defineSsl: @@ -1977,11 +2118,11 @@ proc connectAsync(socket: Socket, name: string, port = Port(0), it = it.ai_next - freeaddrinfo(aiList) + freeAddrInfo(aiList) if not success: raiseOSError(lastError) proc connect*(socket: Socket, address: string, port = Port(0), - timeout: int) {.tags: [ReadIOEffect, WriteIOEffect].} = + timeout: int) {.tags: [ReadIOEffect, WriteIOEffect, RootEffect].} = ## Connects to server as specified by `address` on port specified by `port`. ## ## The `timeout` parameter specifies the time in milliseconds to allow for @@ -1989,8 +2130,7 @@ proc connect*(socket: Socket, address: string, port = Port(0), socket.fd.setBlocking(false) socket.connectAsync(address, port, socket.domain) - var s = @[socket.fd] - if selectWrite(s, timeout) != 1: + if timeoutWrite(socket.fd, timeout) != 1: raise newException(TimeoutError, "Call to 'connect' timed out.") else: let res = getSockOptInt(socket.fd, SOL_SOCKET, SO_ERROR) @@ -2021,10 +2161,8 @@ proc getPrimaryIPAddr*(dest = parseIpAddress("8.8.8.8")): IpAddress = ## ## Supports IPv4 and v6. ## Raises OSError if external networking is not set up. - ## - ## .. code-block:: Nim - ## echo $getPrimaryIPAddr() # "192.168.1.2" - + runnableExamples("-r:off"): + echo getPrimaryIPAddr() # "192.168.1.2" let socket = if dest.family == IpAddressFamily.IPv4: newSocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP) |