summary refs log tree commit diff stats
path: root/lib/pure/net.nim
diff options
context:
space:
mode:
authorRuslan Mustakov <r.mustakov@gmail.com>2017-05-04 16:27:08 +0700
committerRuslan Mustakov <r.mustakov@gmail.com>2017-05-04 16:27:08 +0700
commite0059287bb971a52a11d9ba5129dbbf01a6762df (patch)
tree83e168382d0d09f8f462be7d075f5e48d68d28c2 /lib/pure/net.nim
parent27b571dd9598d144a1d15963d9c73e75d7d764c4 (diff)
downloadNim-e0059287bb971a52a11d9ba5129dbbf01a6762df.tar.gz
Add 'hostname' param to wrapConnectedSocket
Diffstat (limited to 'lib/pure/net.nim')
-rw-r--r--lib/pure/net.nim358
1 files changed, 182 insertions, 176 deletions
diff --git a/lib/pure/net.nim b/lib/pure/net.nim
index d175bd537..629e916fa 100644
--- a/lib/pure/net.nim
+++ b/lib/pure/net.nim
@@ -237,6 +237,180 @@ proc newSocket*(domain: Domain = AF_INET, sockType: SockType = SOCK_STREAM,
     raiseOSError(osLastError())
   result = newSocket(fd, domain, sockType, protocol, buffered)
 
+proc parseIPv4Address(address_str: string): IpAddress =
+  ## Parses IPv4 adresses
+  ## Raises EInvalidValue on errors
+  var
+    byteCount = 0
+    currentByte:uint16 = 0
+    seperatorValid = false
+
+  result.family = IpAddressFamily.IPv4
+
+  for i in 0 .. high(address_str):
+    if address_str[i] in strutils.Digits: # Character is a number
+      currentByte = currentByte * 10 +
+        cast[uint16](ord(address_str[i]) - ord('0'))
+      if currentByte > 255'u16:
+        raise newException(ValueError,
+          "Invalid IP Address. Value is out of range")
+      seperatorValid = true
+    elif address_str[i] == '.': # IPv4 address separator
+      if not seperatorValid or byteCount >= 3:
+        raise newException(ValueError,
+          "Invalid IP Address. The address consists of too many groups")
+      result.address_v4[byteCount] = cast[uint8](currentByte)
+      currentByte = 0
+      byteCount.inc
+      seperatorValid = false
+    else:
+      raise newException(ValueError,
+        "Invalid IP Address. Address contains an invalid character")
+
+  if byteCount != 3 or not seperatorValid:
+    raise newException(ValueError, "Invalid IP Address")
+  result.address_v4[byteCount] = cast[uint8](currentByte)
+
+proc parseIPv6Address(address_str: string): IpAddress =
+  ## Parses IPv6 adresses
+  ## Raises EInvalidValue on errors
+  result.family = IpAddressFamily.IPv6
+  if address_str.len < 2:
+    raise newException(ValueError, "Invalid IP Address")
+
+  var
+    groupCount = 0
+    currentGroupStart = 0
+    currentShort:uint32 = 0
+    seperatorValid = true
+    dualColonGroup = -1
+    lastWasColon = false
+    v4StartPos = -1
+    byteCount = 0
+
+  for i,c in address_str:
+    if c == ':':
+      if not seperatorValid:
+        raise newException(ValueError,
+          "Invalid IP Address. Address contains an invalid seperator")
+      if lastWasColon:
+        if dualColonGroup != -1:
+          raise newException(ValueError,
+            "Invalid IP Address. Address contains more than one \"::\" seperator")
+        dualColonGroup = groupCount
+        seperatorValid = false
+      elif i != 0 and i != high(address_str):
+        if groupCount >= 8:
+          raise newException(ValueError,
+            "Invalid IP Address. The address consists of too many groups")
+        result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8)
+        result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF)
+        currentShort = 0
+        groupCount.inc()
+        if dualColonGroup != -1: seperatorValid = false
+      elif i == 0: # only valid if address starts with ::
+        if address_str[1] != ':':
+          raise newException(ValueError,
+            "Invalid IP Address. Address may not start with \":\"")
+      else: # i == high(address_str) - only valid if address ends with ::
+        if address_str[high(address_str)-1] != ':':
+          raise newException(ValueError,
+            "Invalid IP Address. Address may not end with \":\"")
+      lastWasColon = true
+      currentGroupStart = i + 1
+    elif c == '.': # Switch to parse IPv4 mode
+      if i < 3 or not seperatorValid or groupCount >= 7:
+        raise newException(ValueError, "Invalid IP Address")
+      v4StartPos = currentGroupStart
+      currentShort = 0
+      seperatorValid = false
+      break
+    elif c in strutils.HexDigits:
+      if c in strutils.Digits: # Normal digit
+        currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('0'))
+      elif c >= 'a' and c <= 'f': # Lower case hex
+        currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('a')) + 10
+      else: # Upper case hex
+        currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('A')) + 10
+      if currentShort > 65535'u32:
+        raise newException(ValueError,
+          "Invalid IP Address. Value is out of range")
+      lastWasColon = false
+      seperatorValid = true
+    else:
+      raise newException(ValueError,
+        "Invalid IP Address. Address contains an invalid character")
+
+
+  if v4StartPos == -1: # Don't parse v4. Copy the remaining v6 stuff
+    if seperatorValid: # Copy remaining data
+      if groupCount >= 8:
+        raise newException(ValueError,
+          "Invalid IP Address. The address consists of too many groups")
+      result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8)
+      result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF)
+      groupCount.inc()
+  else: # Must parse IPv4 address
+    for i,c in address_str[v4StartPos..high(address_str)]:
+      if c in strutils.Digits: # Character is a number
+        currentShort = currentShort * 10 + cast[uint32](ord(c) - ord('0'))
+        if currentShort > 255'u32:
+          raise newException(ValueError,
+            "Invalid IP Address. Value is out of range")
+        seperatorValid = true
+      elif c == '.': # IPv4 address separator
+        if not seperatorValid or byteCount >= 3:
+          raise newException(ValueError, "Invalid IP Address")
+        result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort)
+        currentShort = 0
+        byteCount.inc()
+        seperatorValid = false
+      else: # Invalid character
+        raise newException(ValueError,
+          "Invalid IP Address. Address contains an invalid character")
+
+    if byteCount != 3 or not seperatorValid:
+      raise newException(ValueError, "Invalid IP Address")
+    result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort)
+    groupCount += 2
+
+  # Shift and fill zeros in case of ::
+  if groupCount > 8:
+    raise newException(ValueError,
+      "Invalid IP Address. The address consists of too many groups")
+  elif groupCount < 8: # must fill
+    if dualColonGroup == -1:
+      raise newException(ValueError,
+        "Invalid IP Address. The address consists of too few groups")
+    var toFill = 8 - groupCount # The number of groups to fill
+    var toShift = groupCount - dualColonGroup # Nr of known groups after ::
+    for i in 0..2*toShift-1: # shift
+      result.address_v6[15-i] = result.address_v6[groupCount*2-i-1]
+    for i in 0..2*toFill-1: # fill with 0s
+      result.address_v6[dualColonGroup*2+i] = 0
+  elif dualColonGroup != -1:
+    raise newException(ValueError,
+      "Invalid IP Address. The address consists of too many groups")
+
+proc parseIpAddress*(address_str: string): IpAddress =
+  ## Parses an IP address
+  ## Raises EInvalidValue on error
+  if address_str == nil:
+    raise newException(ValueError, "IP Address string is nil")
+  if address_str.contains(':'):
+    return parseIPv6Address(address_str)
+  else:
+    return parseIPv4Address(address_str)
+
+proc isIpAddress*(address_str: string): bool {.tags: [].} =
+  ## Checks if a string is an IP address
+  ## Returns true if it is, false otherwise
+  try:
+    discard parseIpAddress(address_str)
+  except ValueError:
+    return false
+  return true
+
 when defineSsl:
   CRYPTO_malloc_init()
   SslLibraryInit()
@@ -438,9 +612,12 @@ when defineSsl:
       raiseSSLError()
 
   proc wrapConnectedSocket*(ctx: SSLContext, socket: Socket,
-                            handshake: SslHandshakeType) =
+                            handshake: SslHandshakeType,
+                            hostname: string = nil) =
     ## Wraps a connected socket in an SSL context. This function effectively
     ## turns ``socket`` into an SSL socket.
+    ## ``hostname`` should be specified so that the client knows which hostname
+    ## the server certificate should be validated against.
     ##
     ## This should be called on a connected socket, and will perform
     ## an SSL handshake immediately.
@@ -450,6 +627,10 @@ when defineSsl:
     wrapSocket(ctx, socket)
     case handshake
     of handshakeAsClient:
+      if not hostname.isNil and not isIpAddress(hostname):
+        # Discard result in case OpenSSL version doesn't support SNI, or we're
+        # not using TLSv1+
+        discard SSL_set_tlsext_host_name(socket.sslHandle, hostname)
       let ret = SSLConnect(socket.sslHandle)
       socketError(socket, ret)
     of handshakeAsServer:
@@ -1302,181 +1483,6 @@ proc `$`*(address: IpAddress): string =
             mask = mask shr 4
           printedLastGroup = true
 
-proc parseIPv4Address(address_str: string): IpAddress =
-  ## Parses IPv4 adresses
-  ## Raises EInvalidValue on errors
-  var
-    byteCount = 0
-    currentByte:uint16 = 0
-    seperatorValid = false
-
-  result.family = IpAddressFamily.IPv4
-
-  for i in 0 .. high(address_str):
-    if address_str[i] in strutils.Digits: # Character is a number
-      currentByte = currentByte * 10 +
-        cast[uint16](ord(address_str[i]) - ord('0'))
-      if currentByte > 255'u16:
-        raise newException(ValueError,
-          "Invalid IP Address. Value is out of range")
-      seperatorValid = true
-    elif address_str[i] == '.': # IPv4 address separator
-      if not seperatorValid or byteCount >= 3:
-        raise newException(ValueError,
-          "Invalid IP Address. The address consists of too many groups")
-      result.address_v4[byteCount] = cast[uint8](currentByte)
-      currentByte = 0
-      byteCount.inc
-      seperatorValid = false
-    else:
-      raise newException(ValueError,
-        "Invalid IP Address. Address contains an invalid character")
-
-  if byteCount != 3 or not seperatorValid:
-    raise newException(ValueError, "Invalid IP Address")
-  result.address_v4[byteCount] = cast[uint8](currentByte)
-
-proc parseIPv6Address(address_str: string): IpAddress =
-  ## Parses IPv6 adresses
-  ## Raises EInvalidValue on errors
-  result.family = IpAddressFamily.IPv6
-  if address_str.len < 2:
-    raise newException(ValueError, "Invalid IP Address")
-
-  var
-    groupCount = 0
-    currentGroupStart = 0
-    currentShort:uint32 = 0
-    seperatorValid = true
-    dualColonGroup = -1
-    lastWasColon = false
-    v4StartPos = -1
-    byteCount = 0
-
-  for i,c in address_str:
-    if c == ':':
-      if not seperatorValid:
-        raise newException(ValueError,
-          "Invalid IP Address. Address contains an invalid seperator")
-      if lastWasColon:
-        if dualColonGroup != -1:
-          raise newException(ValueError,
-            "Invalid IP Address. Address contains more than one \"::\" seperator")
-        dualColonGroup = groupCount
-        seperatorValid = false
-      elif i != 0 and i != high(address_str):
-        if groupCount >= 8:
-          raise newException(ValueError,
-            "Invalid IP Address. The address consists of too many groups")
-        result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8)
-        result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF)
-        currentShort = 0
-        groupCount.inc()
-        if dualColonGroup != -1: seperatorValid = false
-      elif i == 0: # only valid if address starts with ::
-        if address_str[1] != ':':
-          raise newException(ValueError,
-            "Invalid IP Address. Address may not start with \":\"")
-      else: # i == high(address_str) - only valid if address ends with ::
-        if address_str[high(address_str)-1] != ':':
-          raise newException(ValueError,
-            "Invalid IP Address. Address may not end with \":\"")
-      lastWasColon = true
-      currentGroupStart = i + 1
-    elif c == '.': # Switch to parse IPv4 mode
-      if i < 3 or not seperatorValid or groupCount >= 7:
-        raise newException(ValueError, "Invalid IP Address")
-      v4StartPos = currentGroupStart
-      currentShort = 0
-      seperatorValid = false
-      break
-    elif c in strutils.HexDigits:
-      if c in strutils.Digits: # Normal digit
-        currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('0'))
-      elif c >= 'a' and c <= 'f': # Lower case hex
-        currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('a')) + 10
-      else: # Upper case hex
-        currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('A')) + 10
-      if currentShort > 65535'u32:
-        raise newException(ValueError,
-          "Invalid IP Address. Value is out of range")
-      lastWasColon = false
-      seperatorValid = true
-    else:
-      raise newException(ValueError,
-        "Invalid IP Address. Address contains an invalid character")
-
-
-  if v4StartPos == -1: # Don't parse v4. Copy the remaining v6 stuff
-    if seperatorValid: # Copy remaining data
-      if groupCount >= 8:
-        raise newException(ValueError,
-          "Invalid IP Address. The address consists of too many groups")
-      result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8)
-      result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF)
-      groupCount.inc()
-  else: # Must parse IPv4 address
-    for i,c in address_str[v4StartPos..high(address_str)]:
-      if c in strutils.Digits: # Character is a number
-        currentShort = currentShort * 10 + cast[uint32](ord(c) - ord('0'))
-        if currentShort > 255'u32:
-          raise newException(ValueError,
-            "Invalid IP Address. Value is out of range")
-        seperatorValid = true
-      elif c == '.': # IPv4 address separator
-        if not seperatorValid or byteCount >= 3:
-          raise newException(ValueError, "Invalid IP Address")
-        result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort)
-        currentShort = 0
-        byteCount.inc()
-        seperatorValid = false
-      else: # Invalid character
-        raise newException(ValueError,
-          "Invalid IP Address. Address contains an invalid character")
-
-    if byteCount != 3 or not seperatorValid:
-      raise newException(ValueError, "Invalid IP Address")
-    result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort)
-    groupCount += 2
-
-  # Shift and fill zeros in case of ::
-  if groupCount > 8:
-    raise newException(ValueError,
-      "Invalid IP Address. The address consists of too many groups")
-  elif groupCount < 8: # must fill
-    if dualColonGroup == -1:
-      raise newException(ValueError,
-        "Invalid IP Address. The address consists of too few groups")
-    var toFill = 8 - groupCount # The number of groups to fill
-    var toShift = groupCount - dualColonGroup # Nr of known groups after ::
-    for i in 0..2*toShift-1: # shift
-      result.address_v6[15-i] = result.address_v6[groupCount*2-i-1]
-    for i in 0..2*toFill-1: # fill with 0s
-      result.address_v6[dualColonGroup*2+i] = 0
-  elif dualColonGroup != -1:
-    raise newException(ValueError,
-      "Invalid IP Address. The address consists of too many groups")
-
-
-proc parseIpAddress*(address_str: string): IpAddress =
-  ## Parses an IP address
-  ## Raises EInvalidValue on error
-  if address_str == nil:
-    raise newException(ValueError, "IP Address string is nil")
-  if address_str.contains(':'):
-    return parseIPv6Address(address_str)
-  else:
-    return parseIPv4Address(address_str)
-
-proc isIpAddress*(address_str: string): bool {.tags: [].} =
-  ## Checks if a string is an IP address
-  ## Returns true if it is, false otherwise
-  try:
-    discard parseIpAddress(address_str)
-  except ValueError:
-    return false
-  return true
-
 proc dial*(address: string, port: Port,
            protocol = IPPROTO_TCP, buffered = true): Socket
            {.tags: [ReadIOEffect, WriteIOEffect].} =