summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorNathan Hoad <nathan@getoffmalawn.com>2015-03-07 00:42:14 +1100
committerNathan Hoad <nathan@getoffmalawn.com>2015-03-07 00:48:32 +1100
commitd27f40d9b1cb2436f435e40cf952dbd19ed6d463 (patch)
treef6c9d28a2cef6fa2aff5527ea2266bfa6ca25eaa
parentb870744d5de2c5682e07add0bbce9d5584ea5892 (diff)
downloadNim-d27f40d9b1cb2436f435e40cf952dbd19ed6d463.tar.gz
Add SNI support to client and server sockets.
-rw-r--r--lib/pure/net.nim39
-rw-r--r--lib/wrappers/openssl.nim28
2 files changed, 50 insertions, 17 deletions
diff --git a/lib/pure/net.nim b/lib/pure/net.nim
index bed751542..f7fcea06d 100644
--- a/lib/pure/net.nim
+++ b/lib/pure/net.nim
@@ -81,6 +81,22 @@ type
     TReadLineResult: ReadLineResult, TSOBool: SOBool, PSocket: Socket,
     TSocketImpl: SocketImpl].}
 
+type
+  IpAddressFamily* {.pure.} = enum ## Describes the type of an IP address
+    IPv6, ## IPv6 address
+    IPv4  ## IPv4 address
+
+  TIpAddress* = object ## stores an arbitrary IP address    
+    case family*: IpAddressFamily ## the type of the IP address (IPv4 or IPv6)
+    of IpAddressFamily.IPv6:
+      address_v6*: array[0..15, uint8] ## Contains the IP address in bytes in
+                                       ## case of IPv6
+    of IpAddressFamily.IPv4:
+      address_v4*: array[0..3, uint8] ## Contains the IP address in bytes in
+                                      ## case of IPv4
+
+proc parseIpAddress*(address_str: string): TIpAddress
+
 proc isDisconnectionError*(flags: set[SocketFlag],
     lastError: OSErrorCode): bool =
   ## Determines whether ``lastError`` is a disconnection error. Only does this
@@ -487,7 +503,7 @@ proc setSockOpt*(socket: Socket, opt: SOBool, value: bool, level = SOL_SOCKET) {
   setSockOptInt(socket.fd, cint(level), toCInt(opt), valuei)
 
 proc connect*(socket: Socket, address: string, port = Port(0), 
-              af: Domain = AF_INET) {.tags: [ReadIOEffect].} =
+              af: Domain = AF_INET) {.tags: [ReadIOEffect, RootEffect].} =
   ## Connects socket to ``address``:``port``. ``Address`` can be an IP address or a
   ## host name. If ``address`` is a host name, this function will try each IP
   ## of that host name. ``htons`` is already performed on ``port`` so you must
@@ -511,6 +527,13 @@ proc connect*(socket: Socket, address: string, port = Port(0),
   
   when defined(ssl):
     if socket.isSSL:
+      try:
+        discard parseIpAddress(address)
+      except ValueError:
+        # Discard result in case OpenSSL version doesn't support SNI, or we're
+        # not using TLSv1+
+        discard SSL_set_tlsext_host_name(socket.sslHandle, address)
+
       let ret = SSLConnect(socket.sslHandle)
       socketError(socket, ret)
 
@@ -969,20 +992,6 @@ proc isSsl*(socket: Socket): bool =
 proc getFd*(socket: Socket): SocketHandle = return socket.fd
   ## Returns the socket's file descriptor
 
-type
-  IpAddressFamily* {.pure.} = enum ## Describes the type of an IP address
-    IPv6, ## IPv6 address
-    IPv4  ## IPv4 address
-
-  TIpAddress* = object ## stores an arbitrary IP address    
-    case family*: IpAddressFamily ## the type of the IP address (IPv4 or IPv6)
-    of IpAddressFamily.IPv6:
-      address_v6*: array[0..15, uint8] ## Contains the IP address in bytes in
-                                       ## case of IPv6
-    of IpAddressFamily.IPv4:
-      address_v4*: array[0..3, uint8] ## Contains the IP address in bytes in
-                                      ## case of IPv4
-
 proc IPv4_any*(): TIpAddress =
   ## Returns the IPv4 any address, which can be used to listen on all available
   ## network adapters
diff --git a/lib/wrappers/openssl.nim b/lib/wrappers/openssl.nim
index 29fe3a921..dfc08a2bd 100644
--- a/lib/wrappers/openssl.nim
+++ b/lib/wrappers/openssl.nim
@@ -141,6 +141,14 @@ const
   SSL_CTRL_GET_MAX_CERT_LIST* = 50
   SSL_CTRL_SET_MAX_CERT_LIST* = 51 #* Allow SSL_write(..., n) to return r with 0 < r < n (i.e. report success
                                    # * when just a single record has been written): *
+  SSL_CTRL_SET_TLSEXT_SERVERNAME_CB = 53
+  SSL_CTRL_SET_TLSEXT_SERVERNAME_ARG = 54
+  SSL_CTRL_SET_TLSEXT_HOSTNAME = 55
+  TLSEXT_NAMETYPE_host_name* = 0
+  SSL_TLSEXT_ERR_OK* = 0
+  SSL_TLSEXT_ERR_ALERT_WARNING* = 1
+  SSL_TLSEXT_ERR_ALERT_FATAL* = 2
+  SSL_TLSEXT_ERR_NOACK* = 3
   SSL_MODE_ENABLE_PARTIAL_WRITE* = 1 #* Make it possible to retry SSL_write() with changed buffer location
                                      # * (buffer contents must stay the same!); this is not the default to avoid
                                      # * the misconception that non-blocking SSL_write() behaves like
@@ -296,9 +304,27 @@ proc CRYPTO_malloc_init*() =
 proc SSL_CTX_ctrl*(ctx: SslCtx, cmd: cInt, larg: int, parg: pointer): int{.
   cdecl, dynlib: DLLSSLName, importc.}
 
+proc SSL_CTX_callback_ctrl(ctx: SslCtx, typ: cInt, fp: PFunction): int{.
+  cdecl, dynlib: DLLSSLName, importc.}
+
 proc SSLCTXSetMode*(ctx: SslCtx, mode: int): int =
   result = SSL_CTX_ctrl(ctx, SSL_CTRL_MODE, mode, nil)
 
+proc SSL_ctrl*(ssl: SslPtr, cmd: cInt, larg: int, parg: pointer): int{.
+  cdecl, dynlib: DLLSSLName, importc.}
+
+proc SSL_set_tlsext_host_name*(ssl: SslPtr, name: cstring): int =
+  result = SSL_ctrl(ssl, SSL_CTRL_SET_TLSEXT_HOSTNAME, TLSEXT_NAMETYPE_host_name, name)
+
+proc SSL_get_servername*(ssl: SslPtr, typ: cInt = TLSEXT_NAMETYPE_host_name): cstring {.cdecl, dynlib: DLLSSLName, importc.}
+
+proc SSL_CTX_set_tlsext_servername_callback*(ctx: SslCtx, cb: PFunction): int =
+  result = SSL_CTX_callback_ctrl(ctx, SSL_CTRL_SET_TLSEXT_SERVERNAME_CB, cb)
+
+proc SSL_CTX_set_tlsext_servername_arg*(ctx: SslCtx, arg: pointer): int =
+  result = SSL_CTX_ctrl(ctx, SSL_CTRL_SET_TLSEXT_SERVERNAME_ARG, 0, arg)
+
+
 proc bioNew*(b: PBIO_METHOD): BIO{.cdecl, dynlib: DLLUtilName, importc: "BIO_new".}
 proc bioFreeAll*(b: BIO){.cdecl, dynlib: DLLUtilName, importc: "BIO_free_all".}
 proc bioSMem*(): PBIO_METHOD{.cdecl, dynlib: DLLUtilName, importc: "BIO_s_mem".}
@@ -341,8 +367,6 @@ else:
       dynlib: DLLSSLName, importc.}
 
   proc SslSetFd*(s: PSSL, fd: cInt): cInt{.cdecl, dynlib: DLLSSLName, importc.}
-  proc SslCtrl*(ssl: PSSL, cmd: cInt, larg: int, parg: Pointer): int{.cdecl, 
-      dynlib: DLLSSLName, importc.}
   proc SslCTXCtrl*(ctx: PSSL_CTX, cmd: cInt, larg: int, parg: Pointer): int{.
       cdecl, dynlib: DLLSSLName, importc.}