diff options
author | Andreas Rumpf <rumpf_a@web.de> | 2017-05-04 22:25:37 +0200 |
---|---|---|
committer | Andreas Rumpf <rumpf_a@web.de> | 2017-05-04 22:25:37 +0200 |
commit | 73b7edf9de59f97ccfdc59a71a6a4146b230ae81 (patch) | |
tree | fa9d0b2b3f6682f1648fbe7f35b76bbb41c7e9f3 /lib | |
parent | c0669326357cab1af0cc93bf562eeb5b7be484a0 (diff) | |
parent | c24dc7944aae2aa108b2638a9c70a5edf5e47915 (diff) | |
download | Nim-73b7edf9de59f97ccfdc59a71a6a4146b230ae81.tar.gz |
Merge branch 'devel' of github.com:nim-lang/Nim into devel
Diffstat (limited to 'lib')
-rw-r--r-- | lib/pure/asyncnet.nim | 9 | ||||
-rw-r--r-- | lib/pure/httpclient.nim | 8 | ||||
-rw-r--r-- | lib/pure/net.nim | 358 | ||||
-rw-r--r-- | lib/pure/parseutils.nim | 2 |
4 files changed, 196 insertions, 181 deletions
diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim index 11b7998b2..9f73bc3cf 100644 --- a/lib/pure/asyncnet.nim +++ b/lib/pure/asyncnet.nim @@ -647,9 +647,12 @@ when defineSsl: sslSetBio(socket.sslHandle, socket.bioIn, socket.bioOut) proc wrapConnectedSocket*(ctx: SslContext, socket: AsyncSocket, - handshake: SslHandshakeType) = + handshake: SslHandshakeType, + hostname: string = nil) = ## Wraps a connected socket in an SSL context. This function effectively ## turns ``socket`` into an SSL socket. + ## ``hostname`` should be specified so that the client knows which hostname + ## the server certificate should be validated against. ## ## This should be called on a connected socket, and will perform ## an SSL handshake immediately. @@ -660,6 +663,10 @@ when defineSsl: case handshake of handshakeAsClient: + if not hostname.isNil and not isIpAddress(hostname): + # Set the SNI address for this connection. This call can fail if + # we're not using TLSv1+. + discard SSL_set_tlsext_host_name(socket.sslHandle, hostname) sslSetConnectState(socket.sslHandle) of handshakeAsServer: sslSetAcceptState(socket.sslHandle) diff --git a/lib/pure/httpclient.nim b/lib/pure/httpclient.nim index 1b8a20b65..4f43177a8 100644 --- a/lib/pure/httpclient.nim +++ b/lib/pure/httpclient.nim @@ -512,7 +512,7 @@ proc request*(url: string, httpMethod: string, extraHeaders = "", raise newException(HttpRequestError, "The proxy server rejected a CONNECT request, " & "so a secure connection could not be established.") - sslContext.wrapConnectedSocket(s, handshakeAsClient) + sslContext.wrapConnectedSocket(s, handshakeAsClient, hostUrl.hostname) else: raise newException(HttpRequestError, "SSL support not available. Cannot connect via proxy over SSL") else: @@ -1060,7 +1060,8 @@ proc newConnection(client: HttpClient | AsyncHttpClient, when defined(ssl): if isSsl: try: - client.sslContext.wrapConnectedSocket(client.socket, handshakeAsClient) + client.sslContext.wrapConnectedSocket( + client.socket, handshakeAsClient, url.hostname) except: client.socket.close() raise getCurrentException() @@ -1102,7 +1103,8 @@ proc requestAux(client: HttpClient | AsyncHttpClient, url: string, raise newException(HttpRequestError, "The proxy server rejected a CONNECT request, " & "so a secure connection could not be established.") - client.sslContext.wrapConnectedSocket(client.socket, handshakeAsClient) + client.sslContext.wrapConnectedSocket( + client.socket, handshakeAsClient, requestUrl.hostname) client.proxy = nil else: raise newException(HttpRequestError, diff --git a/lib/pure/net.nim b/lib/pure/net.nim index d175bd537..629e916fa 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -237,6 +237,180 @@ proc newSocket*(domain: Domain = AF_INET, sockType: SockType = SOCK_STREAM, raiseOSError(osLastError()) result = newSocket(fd, domain, sockType, protocol, buffered) +proc parseIPv4Address(address_str: string): IpAddress = + ## Parses IPv4 adresses + ## Raises EInvalidValue on errors + var + byteCount = 0 + currentByte:uint16 = 0 + seperatorValid = false + + result.family = IpAddressFamily.IPv4 + + for i in 0 .. high(address_str): + if address_str[i] in strutils.Digits: # Character is a number + currentByte = currentByte * 10 + + cast[uint16](ord(address_str[i]) - ord('0')) + if currentByte > 255'u16: + raise newException(ValueError, + "Invalid IP Address. Value is out of range") + seperatorValid = true + elif address_str[i] == '.': # IPv4 address separator + if not seperatorValid or byteCount >= 3: + raise newException(ValueError, + "Invalid IP Address. The address consists of too many groups") + result.address_v4[byteCount] = cast[uint8](currentByte) + currentByte = 0 + byteCount.inc + seperatorValid = false + else: + raise newException(ValueError, + "Invalid IP Address. Address contains an invalid character") + + if byteCount != 3 or not seperatorValid: + raise newException(ValueError, "Invalid IP Address") + result.address_v4[byteCount] = cast[uint8](currentByte) + +proc parseIPv6Address(address_str: string): IpAddress = + ## Parses IPv6 adresses + ## Raises EInvalidValue on errors + result.family = IpAddressFamily.IPv6 + if address_str.len < 2: + raise newException(ValueError, "Invalid IP Address") + + var + groupCount = 0 + currentGroupStart = 0 + currentShort:uint32 = 0 + seperatorValid = true + dualColonGroup = -1 + lastWasColon = false + v4StartPos = -1 + byteCount = 0 + + for i,c in address_str: + if c == ':': + if not seperatorValid: + raise newException(ValueError, + "Invalid IP Address. Address contains an invalid seperator") + if lastWasColon: + if dualColonGroup != -1: + raise newException(ValueError, + "Invalid IP Address. Address contains more than one \"::\" seperator") + dualColonGroup = groupCount + seperatorValid = false + elif i != 0 and i != high(address_str): + if groupCount >= 8: + raise newException(ValueError, + "Invalid IP Address. The address consists of too many groups") + result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8) + result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF) + currentShort = 0 + groupCount.inc() + if dualColonGroup != -1: seperatorValid = false + elif i == 0: # only valid if address starts with :: + if address_str[1] != ':': + raise newException(ValueError, + "Invalid IP Address. Address may not start with \":\"") + else: # i == high(address_str) - only valid if address ends with :: + if address_str[high(address_str)-1] != ':': + raise newException(ValueError, + "Invalid IP Address. Address may not end with \":\"") + lastWasColon = true + currentGroupStart = i + 1 + elif c == '.': # Switch to parse IPv4 mode + if i < 3 or not seperatorValid or groupCount >= 7: + raise newException(ValueError, "Invalid IP Address") + v4StartPos = currentGroupStart + currentShort = 0 + seperatorValid = false + break + elif c in strutils.HexDigits: + if c in strutils.Digits: # Normal digit + currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('0')) + elif c >= 'a' and c <= 'f': # Lower case hex + currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('a')) + 10 + else: # Upper case hex + currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('A')) + 10 + if currentShort > 65535'u32: + raise newException(ValueError, + "Invalid IP Address. Value is out of range") + lastWasColon = false + seperatorValid = true + else: + raise newException(ValueError, + "Invalid IP Address. Address contains an invalid character") + + + if v4StartPos == -1: # Don't parse v4. Copy the remaining v6 stuff + if seperatorValid: # Copy remaining data + if groupCount >= 8: + raise newException(ValueError, + "Invalid IP Address. The address consists of too many groups") + result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8) + result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF) + groupCount.inc() + else: # Must parse IPv4 address + for i,c in address_str[v4StartPos..high(address_str)]: + if c in strutils.Digits: # Character is a number + currentShort = currentShort * 10 + cast[uint32](ord(c) - ord('0')) + if currentShort > 255'u32: + raise newException(ValueError, + "Invalid IP Address. Value is out of range") + seperatorValid = true + elif c == '.': # IPv4 address separator + if not seperatorValid or byteCount >= 3: + raise newException(ValueError, "Invalid IP Address") + result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort) + currentShort = 0 + byteCount.inc() + seperatorValid = false + else: # Invalid character + raise newException(ValueError, + "Invalid IP Address. Address contains an invalid character") + + if byteCount != 3 or not seperatorValid: + raise newException(ValueError, "Invalid IP Address") + result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort) + groupCount += 2 + + # Shift and fill zeros in case of :: + if groupCount > 8: + raise newException(ValueError, + "Invalid IP Address. The address consists of too many groups") + elif groupCount < 8: # must fill + if dualColonGroup == -1: + raise newException(ValueError, + "Invalid IP Address. The address consists of too few groups") + var toFill = 8 - groupCount # The number of groups to fill + var toShift = groupCount - dualColonGroup # Nr of known groups after :: + for i in 0..2*toShift-1: # shift + result.address_v6[15-i] = result.address_v6[groupCount*2-i-1] + for i in 0..2*toFill-1: # fill with 0s + result.address_v6[dualColonGroup*2+i] = 0 + elif dualColonGroup != -1: + raise newException(ValueError, + "Invalid IP Address. The address consists of too many groups") + +proc parseIpAddress*(address_str: string): IpAddress = + ## Parses an IP address + ## Raises EInvalidValue on error + if address_str == nil: + raise newException(ValueError, "IP Address string is nil") + if address_str.contains(':'): + return parseIPv6Address(address_str) + else: + return parseIPv4Address(address_str) + +proc isIpAddress*(address_str: string): bool {.tags: [].} = + ## Checks if a string is an IP address + ## Returns true if it is, false otherwise + try: + discard parseIpAddress(address_str) + except ValueError: + return false + return true + when defineSsl: CRYPTO_malloc_init() SslLibraryInit() @@ -438,9 +612,12 @@ when defineSsl: raiseSSLError() proc wrapConnectedSocket*(ctx: SSLContext, socket: Socket, - handshake: SslHandshakeType) = + handshake: SslHandshakeType, + hostname: string = nil) = ## Wraps a connected socket in an SSL context. This function effectively ## turns ``socket`` into an SSL socket. + ## ``hostname`` should be specified so that the client knows which hostname + ## the server certificate should be validated against. ## ## This should be called on a connected socket, and will perform ## an SSL handshake immediately. @@ -450,6 +627,10 @@ when defineSsl: wrapSocket(ctx, socket) case handshake of handshakeAsClient: + if not hostname.isNil and not isIpAddress(hostname): + # Discard result in case OpenSSL version doesn't support SNI, or we're + # not using TLSv1+ + discard SSL_set_tlsext_host_name(socket.sslHandle, hostname) let ret = SSLConnect(socket.sslHandle) socketError(socket, ret) of handshakeAsServer: @@ -1302,181 +1483,6 @@ proc `$`*(address: IpAddress): string = mask = mask shr 4 printedLastGroup = true -proc parseIPv4Address(address_str: string): IpAddress = - ## Parses IPv4 adresses - ## Raises EInvalidValue on errors - var - byteCount = 0 - currentByte:uint16 = 0 - seperatorValid = false - - result.family = IpAddressFamily.IPv4 - - for i in 0 .. high(address_str): - if address_str[i] in strutils.Digits: # Character is a number - currentByte = currentByte * 10 + - cast[uint16](ord(address_str[i]) - ord('0')) - if currentByte > 255'u16: - raise newException(ValueError, - "Invalid IP Address. Value is out of range") - seperatorValid = true - elif address_str[i] == '.': # IPv4 address separator - if not seperatorValid or byteCount >= 3: - raise newException(ValueError, - "Invalid IP Address. The address consists of too many groups") - result.address_v4[byteCount] = cast[uint8](currentByte) - currentByte = 0 - byteCount.inc - seperatorValid = false - else: - raise newException(ValueError, - "Invalid IP Address. Address contains an invalid character") - - if byteCount != 3 or not seperatorValid: - raise newException(ValueError, "Invalid IP Address") - result.address_v4[byteCount] = cast[uint8](currentByte) - -proc parseIPv6Address(address_str: string): IpAddress = - ## Parses IPv6 adresses - ## Raises EInvalidValue on errors - result.family = IpAddressFamily.IPv6 - if address_str.len < 2: - raise newException(ValueError, "Invalid IP Address") - - var - groupCount = 0 - currentGroupStart = 0 - currentShort:uint32 = 0 - seperatorValid = true - dualColonGroup = -1 - lastWasColon = false - v4StartPos = -1 - byteCount = 0 - - for i,c in address_str: - if c == ':': - if not seperatorValid: - raise newException(ValueError, - "Invalid IP Address. Address contains an invalid seperator") - if lastWasColon: - if dualColonGroup != -1: - raise newException(ValueError, - "Invalid IP Address. Address contains more than one \"::\" seperator") - dualColonGroup = groupCount - seperatorValid = false - elif i != 0 and i != high(address_str): - if groupCount >= 8: - raise newException(ValueError, - "Invalid IP Address. The address consists of too many groups") - result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8) - result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF) - currentShort = 0 - groupCount.inc() - if dualColonGroup != -1: seperatorValid = false - elif i == 0: # only valid if address starts with :: - if address_str[1] != ':': - raise newException(ValueError, - "Invalid IP Address. Address may not start with \":\"") - else: # i == high(address_str) - only valid if address ends with :: - if address_str[high(address_str)-1] != ':': - raise newException(ValueError, - "Invalid IP Address. Address may not end with \":\"") - lastWasColon = true - currentGroupStart = i + 1 - elif c == '.': # Switch to parse IPv4 mode - if i < 3 or not seperatorValid or groupCount >= 7: - raise newException(ValueError, "Invalid IP Address") - v4StartPos = currentGroupStart - currentShort = 0 - seperatorValid = false - break - elif c in strutils.HexDigits: - if c in strutils.Digits: # Normal digit - currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('0')) - elif c >= 'a' and c <= 'f': # Lower case hex - currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('a')) + 10 - else: # Upper case hex - currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('A')) + 10 - if currentShort > 65535'u32: - raise newException(ValueError, - "Invalid IP Address. Value is out of range") - lastWasColon = false - seperatorValid = true - else: - raise newException(ValueError, - "Invalid IP Address. Address contains an invalid character") - - - if v4StartPos == -1: # Don't parse v4. Copy the remaining v6 stuff - if seperatorValid: # Copy remaining data - if groupCount >= 8: - raise newException(ValueError, - "Invalid IP Address. The address consists of too many groups") - result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8) - result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF) - groupCount.inc() - else: # Must parse IPv4 address - for i,c in address_str[v4StartPos..high(address_str)]: - if c in strutils.Digits: # Character is a number - currentShort = currentShort * 10 + cast[uint32](ord(c) - ord('0')) - if currentShort > 255'u32: - raise newException(ValueError, - "Invalid IP Address. Value is out of range") - seperatorValid = true - elif c == '.': # IPv4 address separator - if not seperatorValid or byteCount >= 3: - raise newException(ValueError, "Invalid IP Address") - result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort) - currentShort = 0 - byteCount.inc() - seperatorValid = false - else: # Invalid character - raise newException(ValueError, - "Invalid IP Address. Address contains an invalid character") - - if byteCount != 3 or not seperatorValid: - raise newException(ValueError, "Invalid IP Address") - result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort) - groupCount += 2 - - # Shift and fill zeros in case of :: - if groupCount > 8: - raise newException(ValueError, - "Invalid IP Address. The address consists of too many groups") - elif groupCount < 8: # must fill - if dualColonGroup == -1: - raise newException(ValueError, - "Invalid IP Address. The address consists of too few groups") - var toFill = 8 - groupCount # The number of groups to fill - var toShift = groupCount - dualColonGroup # Nr of known groups after :: - for i in 0..2*toShift-1: # shift - result.address_v6[15-i] = result.address_v6[groupCount*2-i-1] - for i in 0..2*toFill-1: # fill with 0s - result.address_v6[dualColonGroup*2+i] = 0 - elif dualColonGroup != -1: - raise newException(ValueError, - "Invalid IP Address. The address consists of too many groups") - - -proc parseIpAddress*(address_str: string): IpAddress = - ## Parses an IP address - ## Raises EInvalidValue on error - if address_str == nil: - raise newException(ValueError, "IP Address string is nil") - if address_str.contains(':'): - return parseIPv6Address(address_str) - else: - return parseIPv4Address(address_str) - -proc isIpAddress*(address_str: string): bool {.tags: [].} = - ## Checks if a string is an IP address - ## Returns true if it is, false otherwise - try: - discard parseIpAddress(address_str) - except ValueError: - return false - return true - proc dial*(address: string, port: Port, protocol = IPPROTO_TCP, buffered = true): Socket {.tags: [ReadIOEffect, WriteIOEffect].} = diff --git a/lib/pure/parseutils.nim b/lib/pure/parseutils.nim index 8d53a0360..b78e8d000 100644 --- a/lib/pure/parseutils.nim +++ b/lib/pure/parseutils.nim @@ -201,7 +201,7 @@ proc parseWhile*(s: string, token: var string, validChars: set[char], proc captureBetween*(s: string, first: char, second = '\0', start = 0): string = ## Finds the first occurrence of ``first``, then returns everything from there - ## up to ``second``(if ``second`` is '\0', then ``first`` is used). + ## up to ``second`` (if ``second`` is '\0', then ``first`` is used). var i = skipUntil(s, first, start)+1+start result = "" discard s.parseUntil(result, if second == '\0': first else: second, i) |