summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorDominik Picheta <dominikpicheta@googlemail.com>2015-03-10 11:08:21 +0000
committerDominik Picheta <dominikpicheta@googlemail.com>2015-03-10 11:08:21 +0000
commit3ea3aa633d92e9a9c3f4668727c194cfae3ce7c4 (patch)
tree0ec04a10efa9f11888030a2d7009e51282a279c4
parent796a588b1cd6bbfa5d49b521c9c8d52ff8a3e4fb (diff)
parent7cffd290bf5d20c7b9f191f222ae3f4ea7952523 (diff)
downloadNim-3ea3aa633d92e9a9c3f4668727c194cfae3ce7c4.tar.gz
Merge pull request #2279 from nathan-hoad/sni-support-for-openssl
Add SNI support to client and server sockets.
-rw-r--r--lib/pure/net.nim49
-rw-r--r--lib/wrappers/openssl.nim44
-rw-r--r--tests/stdlib/tnet.nim47
3 files changed, 122 insertions, 18 deletions
diff --git a/lib/pure/net.nim b/lib/pure/net.nim
index bed751542..ffbc6e320 100644
--- a/lib/pure/net.nim
+++ b/lib/pure/net.nim
@@ -81,6 +81,23 @@ type
     TReadLineResult: ReadLineResult, TSOBool: SOBool, PSocket: Socket,
     TSocketImpl: SocketImpl].}
 
+type
+  IpAddressFamily* {.pure.} = enum ## Describes the type of an IP address
+    IPv6, ## IPv6 address
+    IPv4  ## IPv4 address
+
+  TIpAddress* = object ## stores an arbitrary IP address    
+    case family*: IpAddressFamily ## the type of the IP address (IPv4 or IPv6)
+    of IpAddressFamily.IPv6:
+      address_v6*: array[0..15, uint8] ## Contains the IP address in bytes in
+                                       ## case of IPv6
+    of IpAddressFamily.IPv4:
+      address_v4*: array[0..3, uint8] ## Contains the IP address in bytes in
+                                      ## case of IPv4
+
+proc isIpAddress*(address_str: string): bool {.tags: [].}
+proc parseIpAddress*(address_str: string): TIpAddress
+
 proc isDisconnectionError*(flags: set[SocketFlag],
     lastError: OSErrorCode): bool =
   ## Determines whether ``lastError`` is a disconnection error. Only does this
@@ -511,6 +528,12 @@ proc connect*(socket: Socket, address: string, port = Port(0),
   
   when defined(ssl):
     if socket.isSSL:
+      # RFC3546 for SNI specifies that IP addresses are not allowed.
+      if not isIpAddress(address):
+        # Discard result in case OpenSSL version doesn't support SNI, or we're
+        # not using TLSv1+
+        discard SSL_set_tlsext_host_name(socket.sslHandle, address)
+
       let ret = SSLConnect(socket.sslHandle)
       socketError(socket, ret)
 
@@ -969,20 +992,6 @@ proc isSsl*(socket: Socket): bool =
 proc getFd*(socket: Socket): SocketHandle = return socket.fd
   ## Returns the socket's file descriptor
 
-type
-  IpAddressFamily* {.pure.} = enum ## Describes the type of an IP address
-    IPv6, ## IPv6 address
-    IPv4  ## IPv4 address
-
-  TIpAddress* = object ## stores an arbitrary IP address    
-    case family*: IpAddressFamily ## the type of the IP address (IPv4 or IPv6)
-    of IpAddressFamily.IPv6:
-      address_v6*: array[0..15, uint8] ## Contains the IP address in bytes in
-                                       ## case of IPv6
-    of IpAddressFamily.IPv4:
-      address_v4*: array[0..3, uint8] ## Contains the IP address in bytes in
-                                      ## case of IPv4
-
 proc IPv4_any*(): TIpAddress =
   ## Returns the IPv4 any address, which can be used to listen on all available
   ## network adapters
@@ -1241,7 +1250,7 @@ proc parseIPv6Address(address_str: string): TIpAddress =
     raise newException(ValueError,
       "Invalid IP Address. The address consists of too many groups")
 
-proc parseIpAddress*(address_str: string): TIpAddress =
+proc parseIpAddress(address_str: string): TIpAddress =
   ## Parses an IP address
   ## Raises EInvalidValue on error
   if address_str == nil:
@@ -1250,3 +1259,13 @@ proc parseIpAddress*(address_str: string): TIpAddress =
     return parseIPv6Address(address_str)
   else:
     return parseIPv4Address(address_str)
+
+
+proc isIpAddress(address_str: string): bool =
+  ## Checks if a string is an IP address
+  ## Returns true if it is, false otherwise
+  try:
+    discard parseIpAddress(address_str)
+  except ValueError:
+    return false
+  return true
diff --git a/lib/wrappers/openssl.nim b/lib/wrappers/openssl.nim
index 29fe3a921..bca7b3a40 100644
--- a/lib/wrappers/openssl.nim
+++ b/lib/wrappers/openssl.nim
@@ -50,7 +50,7 @@ when useWinVersion:
   from winlean import SocketHandle
 else:
   const
-    versions = "(.10|.1.0.1|.1.0.0|.0.9.9|.0.9.8|.0.9.7|.0.9.6|.0.9.5|.0.9.4)"
+    versions = "(.10|.1.0.1|.1.0.0|.0.9.9|.0.9.8)"
   when defined(macosx):
     const
       DLLSSLName = "libssl" & versions & ".dylib"
@@ -141,6 +141,14 @@ const
   SSL_CTRL_GET_MAX_CERT_LIST* = 50
   SSL_CTRL_SET_MAX_CERT_LIST* = 51 #* Allow SSL_write(..., n) to return r with 0 < r < n (i.e. report success
                                    # * when just a single record has been written): *
+  SSL_CTRL_SET_TLSEXT_SERVERNAME_CB = 53
+  SSL_CTRL_SET_TLSEXT_SERVERNAME_ARG = 54
+  SSL_CTRL_SET_TLSEXT_HOSTNAME = 55
+  TLSEXT_NAMETYPE_host_name* = 0
+  SSL_TLSEXT_ERR_OK* = 0
+  SSL_TLSEXT_ERR_ALERT_WARNING* = 1
+  SSL_TLSEXT_ERR_ALERT_FATAL* = 2
+  SSL_TLSEXT_ERR_NOACK* = 3
   SSL_MODE_ENABLE_PARTIAL_WRITE* = 1 #* Make it possible to retry SSL_write() with changed buffer location
                                      # * (buffer contents must stay the same!); this is not the default to avoid
                                      # * the misconception that non-blocking SSL_write() behaves like
@@ -296,9 +304,41 @@ proc CRYPTO_malloc_init*() =
 proc SSL_CTX_ctrl*(ctx: SslCtx, cmd: cInt, larg: int, parg: pointer): int{.
   cdecl, dynlib: DLLSSLName, importc.}
 
+proc SSL_CTX_callback_ctrl(ctx: SslCtx, typ: cInt, fp: PFunction): int{.
+  cdecl, dynlib: DLLSSLName, importc.}
+
 proc SSLCTXSetMode*(ctx: SslCtx, mode: int): int =
   result = SSL_CTX_ctrl(ctx, SSL_CTRL_MODE, mode, nil)
 
+proc SSL_ctrl*(ssl: SslPtr, cmd: cInt, larg: int, parg: pointer): int{.
+  cdecl, dynlib: DLLSSLName, importc.}
+
+proc SSL_set_tlsext_host_name*(ssl: SslPtr, name: cstring): int =
+  result = SSL_ctrl(ssl, SSL_CTRL_SET_TLSEXT_HOSTNAME, TLSEXT_NAMETYPE_host_name, name)
+  ## Set the SNI server name extension to be used in a client hello.
+  ## Returns 1 if SNI was set, 0 if current SSL configuration doesn't support SNI.
+
+
+proc SSL_get_servername*(ssl: SslPtr, typ: cInt = TLSEXT_NAMETYPE_host_name): cstring {.cdecl, dynlib: DLLSSLName, importc.}
+  ## Retrieve the server name requested in the client hello. This can be used
+  ## in the callback set in `SSL_CTX_set_tlsext_servername_callback` to
+  ## implement virtual hosting. May return `nil`.
+
+proc SSL_CTX_set_tlsext_servername_callback*(ctx: SslCtx, cb: proc(ssl: SslPtr, cb_id: int, arg: pointer): int {.cdecl.}): int =
+  ## Set the callback to be used on listening SSL connections when the client hello is received.
+  ##
+  ## The callback should return one of:
+  ## * SSL_TLSEXT_ERR_OK
+  ## * SSL_TLSEXT_ERR_ALERT_WARNING
+  ## * SSL_TLSEXT_ERR_ALERT_FATAL
+  ## * SSL_TLSEXT_ERR_NOACK
+  result = SSL_CTX_callback_ctrl(ctx, SSL_CTRL_SET_TLSEXT_SERVERNAME_CB, cast[PFunction](cb))
+
+proc SSL_CTX_set_tlsext_servername_arg*(ctx: SslCtx, arg: pointer): int =
+  ## Set the pointer to be used in the callback registered to ``SSL_CTX_set_tlsext_servername_callback``.
+  result = SSL_CTX_ctrl(ctx, SSL_CTRL_SET_TLSEXT_SERVERNAME_ARG, 0, arg)
+
+
 proc bioNew*(b: PBIO_METHOD): BIO{.cdecl, dynlib: DLLUtilName, importc: "BIO_new".}
 proc bioFreeAll*(b: BIO){.cdecl, dynlib: DLLUtilName, importc: "BIO_free_all".}
 proc bioSMem*(): PBIO_METHOD{.cdecl, dynlib: DLLUtilName, importc: "BIO_s_mem".}
@@ -341,8 +381,6 @@ else:
       dynlib: DLLSSLName, importc.}
 
   proc SslSetFd*(s: PSSL, fd: cInt): cInt{.cdecl, dynlib: DLLSSLName, importc.}
-  proc SslCtrl*(ssl: PSSL, cmd: cInt, larg: int, parg: Pointer): int{.cdecl, 
-      dynlib: DLLSSLName, importc.}
   proc SslCTXCtrl*(ctx: PSSL_CTX, cmd: cInt, larg: int, parg: Pointer): int{.
       cdecl, dynlib: DLLSSLName, importc.}
 
diff --git a/tests/stdlib/tnet.nim b/tests/stdlib/tnet.nim
new file mode 100644
index 000000000..e8ada05e7
--- /dev/null
+++ b/tests/stdlib/tnet.nim
@@ -0,0 +1,47 @@
+import net
+import unittest
+
+suite "isIpAddress tests":
+  test "127.0.0.1 is valid":
+    check isIpAddress("127.0.0.1") == true
+
+  test "ipv6 localhost is valid":
+    check isIpAddress("::1") == true
+
+  test "fqdn is not an ip address":
+    check isIpAddress("example.com") == false
+
+  test "random string is not an ipaddress":
+    check isIpAddress("foo bar") == false
+
+  test "5127.0.0.1 is invalid":
+    check isIpAddress("5127.0.0.1") == false
+
+  test "ipv6 is valid":
+    check isIpAddress("2001:cdba:0000:0000:0000:0000:3257:9652") == true
+
+  test "invalid ipv6":
+    check isIpAddress("gggg:cdba:0000:0000:0000:0000:3257:9652") == false
+
+
+suite "parseIpAddress tests":
+  test "127.0.0.1 is valid":
+    discard parseIpAddress("127.0.0.1")
+
+  test "ipv6 localhost is valid":
+    discard parseIpAddress("::1")
+
+  test "fqdn is not an ip address":
+    expect(ValueError):
+      discard parseIpAddress("example.com")
+
+  test "random string is not an ipaddress":
+    expect(ValueError):
+      discard parseIpAddress("foo bar")
+
+  test "ipv6 is valid":
+    discard parseIpAddress("2001:cdba:0000:0000:0000:0000:3257:9652")
+
+  test "invalid ipv6":
+    expect(ValueError):
+      discard parseIpAddress("gggg:cdba:0000:0000:0000:0000:3257:9652")