diff options
Diffstat (limited to 'lib/pure/net.nim')
-rw-r--r-- | lib/pure/net.nim | 411 |
1 files changed, 237 insertions, 174 deletions
diff --git a/lib/pure/net.nim b/lib/pure/net.nim index 56f8b9399..629e916fa 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -66,7 +66,7 @@ ## {.deadCodeElim: on.} -import nativesockets, os, strutils, parseutils, times, sets +import nativesockets, os, strutils, parseutils, times, sets, options export Port, `$`, `==` export Domain, SockType, Protocol @@ -237,6 +237,180 @@ proc newSocket*(domain: Domain = AF_INET, sockType: SockType = SOCK_STREAM, raiseOSError(osLastError()) result = newSocket(fd, domain, sockType, protocol, buffered) +proc parseIPv4Address(address_str: string): IpAddress = + ## Parses IPv4 adresses + ## Raises EInvalidValue on errors + var + byteCount = 0 + currentByte:uint16 = 0 + seperatorValid = false + + result.family = IpAddressFamily.IPv4 + + for i in 0 .. high(address_str): + if address_str[i] in strutils.Digits: # Character is a number + currentByte = currentByte * 10 + + cast[uint16](ord(address_str[i]) - ord('0')) + if currentByte > 255'u16: + raise newException(ValueError, + "Invalid IP Address. Value is out of range") + seperatorValid = true + elif address_str[i] == '.': # IPv4 address separator + if not seperatorValid or byteCount >= 3: + raise newException(ValueError, + "Invalid IP Address. The address consists of too many groups") + result.address_v4[byteCount] = cast[uint8](currentByte) + currentByte = 0 + byteCount.inc + seperatorValid = false + else: + raise newException(ValueError, + "Invalid IP Address. Address contains an invalid character") + + if byteCount != 3 or not seperatorValid: + raise newException(ValueError, "Invalid IP Address") + result.address_v4[byteCount] = cast[uint8](currentByte) + +proc parseIPv6Address(address_str: string): IpAddress = + ## Parses IPv6 adresses + ## Raises EInvalidValue on errors + result.family = IpAddressFamily.IPv6 + if address_str.len < 2: + raise newException(ValueError, "Invalid IP Address") + + var + groupCount = 0 + currentGroupStart = 0 + currentShort:uint32 = 0 + seperatorValid = true + dualColonGroup = -1 + lastWasColon = false + v4StartPos = -1 + byteCount = 0 + + for i,c in address_str: + if c == ':': + if not seperatorValid: + raise newException(ValueError, + "Invalid IP Address. Address contains an invalid seperator") + if lastWasColon: + if dualColonGroup != -1: + raise newException(ValueError, + "Invalid IP Address. Address contains more than one \"::\" seperator") + dualColonGroup = groupCount + seperatorValid = false + elif i != 0 and i != high(address_str): + if groupCount >= 8: + raise newException(ValueError, + "Invalid IP Address. The address consists of too many groups") + result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8) + result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF) + currentShort = 0 + groupCount.inc() + if dualColonGroup != -1: seperatorValid = false + elif i == 0: # only valid if address starts with :: + if address_str[1] != ':': + raise newException(ValueError, + "Invalid IP Address. Address may not start with \":\"") + else: # i == high(address_str) - only valid if address ends with :: + if address_str[high(address_str)-1] != ':': + raise newException(ValueError, + "Invalid IP Address. Address may not end with \":\"") + lastWasColon = true + currentGroupStart = i + 1 + elif c == '.': # Switch to parse IPv4 mode + if i < 3 or not seperatorValid or groupCount >= 7: + raise newException(ValueError, "Invalid IP Address") + v4StartPos = currentGroupStart + currentShort = 0 + seperatorValid = false + break + elif c in strutils.HexDigits: + if c in strutils.Digits: # Normal digit + currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('0')) + elif c >= 'a' and c <= 'f': # Lower case hex + currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('a')) + 10 + else: # Upper case hex + currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('A')) + 10 + if currentShort > 65535'u32: + raise newException(ValueError, + "Invalid IP Address. Value is out of range") + lastWasColon = false + seperatorValid = true + else: + raise newException(ValueError, + "Invalid IP Address. Address contains an invalid character") + + + if v4StartPos == -1: # Don't parse v4. Copy the remaining v6 stuff + if seperatorValid: # Copy remaining data + if groupCount >= 8: + raise newException(ValueError, + "Invalid IP Address. The address consists of too many groups") + result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8) + result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF) + groupCount.inc() + else: # Must parse IPv4 address + for i,c in address_str[v4StartPos..high(address_str)]: + if c in strutils.Digits: # Character is a number + currentShort = currentShort * 10 + cast[uint32](ord(c) - ord('0')) + if currentShort > 255'u32: + raise newException(ValueError, + "Invalid IP Address. Value is out of range") + seperatorValid = true + elif c == '.': # IPv4 address separator + if not seperatorValid or byteCount >= 3: + raise newException(ValueError, "Invalid IP Address") + result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort) + currentShort = 0 + byteCount.inc() + seperatorValid = false + else: # Invalid character + raise newException(ValueError, + "Invalid IP Address. Address contains an invalid character") + + if byteCount != 3 or not seperatorValid: + raise newException(ValueError, "Invalid IP Address") + result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort) + groupCount += 2 + + # Shift and fill zeros in case of :: + if groupCount > 8: + raise newException(ValueError, + "Invalid IP Address. The address consists of too many groups") + elif groupCount < 8: # must fill + if dualColonGroup == -1: + raise newException(ValueError, + "Invalid IP Address. The address consists of too few groups") + var toFill = 8 - groupCount # The number of groups to fill + var toShift = groupCount - dualColonGroup # Nr of known groups after :: + for i in 0..2*toShift-1: # shift + result.address_v6[15-i] = result.address_v6[groupCount*2-i-1] + for i in 0..2*toFill-1: # fill with 0s + result.address_v6[dualColonGroup*2+i] = 0 + elif dualColonGroup != -1: + raise newException(ValueError, + "Invalid IP Address. The address consists of too many groups") + +proc parseIpAddress*(address_str: string): IpAddress = + ## Parses an IP address + ## Raises EInvalidValue on error + if address_str == nil: + raise newException(ValueError, "IP Address string is nil") + if address_str.contains(':'): + return parseIPv6Address(address_str) + else: + return parseIPv4Address(address_str) + +proc isIpAddress*(address_str: string): bool {.tags: [].} = + ## Checks if a string is an IP address + ## Returns true if it is, false otherwise + try: + discard parseIpAddress(address_str) + except ValueError: + return false + return true + when defineSsl: CRYPTO_malloc_init() SslLibraryInit() @@ -438,9 +612,12 @@ when defineSsl: raiseSSLError() proc wrapConnectedSocket*(ctx: SSLContext, socket: Socket, - handshake: SslHandshakeType) = + handshake: SslHandshakeType, + hostname: string = nil) = ## Wraps a connected socket in an SSL context. This function effectively ## turns ``socket`` into an SSL socket. + ## ``hostname`` should be specified so that the client knows which hostname + ## the server certificate should be validated against. ## ## This should be called on a connected socket, and will perform ## an SSL handshake immediately. @@ -450,6 +627,10 @@ when defineSsl: wrapSocket(ctx, socket) case handshake of handshakeAsClient: + if not hostname.isNil and not isIpAddress(hostname): + # Discard result in case OpenSSL version doesn't support SNI, or we're + # not using TLSv1+ + discard SSL_set_tlsext_host_name(socket.sslHandle, hostname) let ret = SSLConnect(socket.sslHandle) socketError(socket, ret) of handshakeAsServer: @@ -669,7 +850,7 @@ proc close*(socket: Socket) = ## Closes a socket. try: when defineSsl: - if socket.isSSL: + if socket.isSSL and socket.sslHandle != nil: ErrClearError() # As we are closing the underlying socket immediately afterwards, # it is valid, under the TLS standard, to perform a unidirectional @@ -1302,181 +1483,63 @@ proc `$`*(address: IpAddress): string = mask = mask shr 4 printedLastGroup = true -proc parseIPv4Address(address_str: string): IpAddress = - ## Parses IPv4 adresses - ## Raises EInvalidValue on errors - var - byteCount = 0 - currentByte:uint16 = 0 - seperatorValid = false +proc dial*(address: string, port: Port, + protocol = IPPROTO_TCP, buffered = true): Socket + {.tags: [ReadIOEffect, WriteIOEffect].} = + ## Establishes connection to the specified ``address``:``port`` pair via the + ## specified protocol. The procedure iterates through possible + ## resolutions of the ``address`` until it succeeds, meaning that it + ## seamlessly works with both IPv4 and IPv6. + ## Returns Socket ready to send or receive data. + let sockType = protocol.toSockType() + + let aiList = getAddrInfo(address, port, AF_UNSPEC, sockType, protocol) + + var fdPerDomain: array[low(Domain).ord..high(Domain).ord, SocketHandle] + for i in low(fdPerDomain)..high(fdPerDomain): + fdPerDomain[i] = osInvalidSocket + template closeUnusedFds(domainToKeep = -1) {.dirty.} = + for i, fd in fdPerDomain: + if fd != osInvalidSocket and i != domainToKeep: + fd.close() - result.family = IpAddressFamily.IPv4 - - for i in 0 .. high(address_str): - if address_str[i] in strutils.Digits: # Character is a number - currentByte = currentByte * 10 + - cast[uint16](ord(address_str[i]) - ord('0')) - if currentByte > 255'u16: - raise newException(ValueError, - "Invalid IP Address. Value is out of range") - seperatorValid = true - elif address_str[i] == '.': # IPv4 address separator - if not seperatorValid or byteCount >= 3: - raise newException(ValueError, - "Invalid IP Address. The address consists of too many groups") - result.address_v4[byteCount] = cast[uint8](currentByte) - currentByte = 0 - byteCount.inc - seperatorValid = false - else: - raise newException(ValueError, - "Invalid IP Address. Address contains an invalid character") - - if byteCount != 3 or not seperatorValid: - raise newException(ValueError, "Invalid IP Address") - result.address_v4[byteCount] = cast[uint8](currentByte) - -proc parseIPv6Address(address_str: string): IpAddress = - ## Parses IPv6 adresses - ## Raises EInvalidValue on errors - result.family = IpAddressFamily.IPv6 - if address_str.len < 2: - raise newException(ValueError, "Invalid IP Address") - - var - groupCount = 0 - currentGroupStart = 0 - currentShort:uint32 = 0 - seperatorValid = true - dualColonGroup = -1 - lastWasColon = false - v4StartPos = -1 - byteCount = 0 - - for i,c in address_str: - if c == ':': - if not seperatorValid: - raise newException(ValueError, - "Invalid IP Address. Address contains an invalid seperator") - if lastWasColon: - if dualColonGroup != -1: - raise newException(ValueError, - "Invalid IP Address. Address contains more than one \"::\" seperator") - dualColonGroup = groupCount - seperatorValid = false - elif i != 0 and i != high(address_str): - if groupCount >= 8: - raise newException(ValueError, - "Invalid IP Address. The address consists of too many groups") - result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8) - result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF) - currentShort = 0 - groupCount.inc() - if dualColonGroup != -1: seperatorValid = false - elif i == 0: # only valid if address starts with :: - if address_str[1] != ':': - raise newException(ValueError, - "Invalid IP Address. Address may not start with \":\"") - else: # i == high(address_str) - only valid if address ends with :: - if address_str[high(address_str)-1] != ':': - raise newException(ValueError, - "Invalid IP Address. Address may not end with \":\"") - lastWasColon = true - currentGroupStart = i + 1 - elif c == '.': # Switch to parse IPv4 mode - if i < 3 or not seperatorValid or groupCount >= 7: - raise newException(ValueError, "Invalid IP Address") - v4StartPos = currentGroupStart - currentShort = 0 - seperatorValid = false + var success = false + var lastError: OSErrorCode + var it = aiList + var domain: Domain + var lastFd: SocketHandle + while it != nil: + let domainOpt = it.ai_family.toKnownDomain() + if domainOpt.isNone: + it = it.ai_next + continue + domain = domainOpt.unsafeGet() + lastFd = fdPerDomain[ord(domain)] + if lastFd == osInvalidSocket: + lastFd = newNativeSocket(domain, sockType, protocol) + if lastFd == osInvalidSocket: + # we always raise if socket creation failed, because it means a + # network system problem (e.g. not enough FDs), and not an unreachable + # address. + let err = osLastError() + freeAddrInfo(aiList) + closeUnusedFds() + raiseOSError(err) + fdPerDomain[ord(domain)] = lastFd + if connect(lastFd, it.ai_addr, it.ai_addrlen.SockLen) == 0'i32: + success = true break - elif c in strutils.HexDigits: - if c in strutils.Digits: # Normal digit - currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('0')) - elif c >= 'a' and c <= 'f': # Lower case hex - currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('a')) + 10 - else: # Upper case hex - currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('A')) + 10 - if currentShort > 65535'u32: - raise newException(ValueError, - "Invalid IP Address. Value is out of range") - lastWasColon = false - seperatorValid = true - else: - raise newException(ValueError, - "Invalid IP Address. Address contains an invalid character") - - - if v4StartPos == -1: # Don't parse v4. Copy the remaining v6 stuff - if seperatorValid: # Copy remaining data - if groupCount >= 8: - raise newException(ValueError, - "Invalid IP Address. The address consists of too many groups") - result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8) - result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF) - groupCount.inc() - else: # Must parse IPv4 address - for i,c in address_str[v4StartPos..high(address_str)]: - if c in strutils.Digits: # Character is a number - currentShort = currentShort * 10 + cast[uint32](ord(c) - ord('0')) - if currentShort > 255'u32: - raise newException(ValueError, - "Invalid IP Address. Value is out of range") - seperatorValid = true - elif c == '.': # IPv4 address separator - if not seperatorValid or byteCount >= 3: - raise newException(ValueError, "Invalid IP Address") - result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort) - currentShort = 0 - byteCount.inc() - seperatorValid = false - else: # Invalid character - raise newException(ValueError, - "Invalid IP Address. Address contains an invalid character") - - if byteCount != 3 or not seperatorValid: - raise newException(ValueError, "Invalid IP Address") - result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort) - groupCount += 2 - - # Shift and fill zeros in case of :: - if groupCount > 8: - raise newException(ValueError, - "Invalid IP Address. The address consists of too many groups") - elif groupCount < 8: # must fill - if dualColonGroup == -1: - raise newException(ValueError, - "Invalid IP Address. The address consists of too few groups") - var toFill = 8 - groupCount # The number of groups to fill - var toShift = groupCount - dualColonGroup # Nr of known groups after :: - for i in 0..2*toShift-1: # shift - result.address_v6[15-i] = result.address_v6[groupCount*2-i-1] - for i in 0..2*toFill-1: # fill with 0s - result.address_v6[dualColonGroup*2+i] = 0 - elif dualColonGroup != -1: - raise newException(ValueError, - "Invalid IP Address. The address consists of too many groups") - + lastError = osLastError() + it = it.ai_next + freeAddrInfo(aiList) + closeUnusedFds(ord(domain)) -proc parseIpAddress*(address_str: string): IpAddress = - ## Parses an IP address - ## Raises EInvalidValue on error - if address_str == nil: - raise newException(ValueError, "IP Address string is nil") - if address_str.contains(':'): - return parseIPv6Address(address_str) + if success: + result = newSocket(lastFd, domain, sockType, protocol) + elif lastError != 0.OSErrorCode: + raiseOSError(lastError) else: - return parseIPv4Address(address_str) - -proc isIpAddress*(address_str: string): bool {.tags: [].} = - ## Checks if a string is an IP address - ## Returns true if it is, false otherwise - try: - discard parseIpAddress(address_str) - except ValueError: - return false - return true - + raise newException(IOError, "Couldn't resolve address: " & address) proc connect*(socket: Socket, address: string, port = Port(0)) {.tags: [ReadIOEffect].} = |