summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorMichał Zieliński <michal@zielinscy.org.pl>2015-10-24 08:53:18 +0200
committerMichał Zieliński <michal@zielinscy.org.pl>2015-10-24 22:17:31 +0200
commitba61a8d00a65948fc0b3a1c100a20cca711fdd0f (patch)
treee160d0a04edec991e2dbb02dd1724a5302ed7347
parent3ebf27ddd24c04e87e33bfb6f8617d81c9fc1946 (diff)
downloadNim-ba61a8d00a65948fc0b3a1c100a20cca711fdd0f.tar.gz
net.nim: support for TLS-PSK ciphersuites
-rw-r--r--examples/ssl/pskclient.nim15
-rw-r--r--examples/ssl/pskserver.nim20
-rw-r--r--lib/pure/net.nim102
-rw-r--r--lib/wrappers/openssl.nim20
4 files changed, 144 insertions, 13 deletions
diff --git a/examples/ssl/pskclient.nim b/examples/ssl/pskclient.nim
new file mode 100644
index 000000000..7c93bbb61
--- /dev/null
+++ b/examples/ssl/pskclient.nim
@@ -0,0 +1,15 @@
+# Create connection encrypted using preshared key (TLS-PSK).
+import net
+
+static: assert defined(ssl)
+
+let sock = newSocket()
+sock.connect("localhost", Port(8800))
+
+proc clientFunc(identityHint: string): tuple[identity: string, psk: string] =
+  echo "identity hint ", identityHint.repr
+  return ("foo", "psk-of-foo")
+
+let context = newContext(cipherList="PSK-AES256-CBC-SHA")
+context.clientGetPskFunc = clientFunc
+context.wrapConnectedSocket(sock, handshakeAsClient)
diff --git a/examples/ssl/pskserver.nim b/examples/ssl/pskserver.nim
new file mode 100644
index 000000000..859eaa875
--- /dev/null
+++ b/examples/ssl/pskserver.nim
@@ -0,0 +1,20 @@
+# Accept connection encrypted using preshared key (TLS-PSK).
+import net
+
+static: assert defined(ssl)
+
+let sock = newSocket()
+sock.bindAddr(Port(8800))
+sock.listen()
+
+let context = newContext(cipherList="PSK-AES256-CBC-SHA")
+context.pskIdentityHint = "hello"
+context.serverGetPskFunc = proc(identity: string): string = "psk-of-" & identity
+
+while true:
+  var client = new(Socket)
+  sock.accept(client)
+  sock.setSockOpt(OptReuseAddr, true)
+  echo "accepted connection"
+  context.wrapConnectedSocket(client, handshakeAsServer)
+  echo "got connection with identity ", client.getPskIdentity()
diff --git a/lib/pure/net.nim b/lib/pure/net.nim
index 5498ebb7d..4bdfede42 100644
--- a/lib/pure/net.nim
+++ b/lib/pure/net.nim
@@ -38,6 +38,10 @@ when defined(ssl):
     SslHandshakeType* = enum
       handshakeAsClient, handshakeAsServer
 
+    SslClientGetPskFunc* = proc(hint: string): tuple[identity: string, psk: string]
+
+    SslServerGetPskFunc* = proc(identity: string): string
+
   {.deprecated: [ESSL: SSLError, TSSLCVerifyMode: SSLCVerifyMode,
     TSSLProtVersion: SSLProtVersion, PSSLContext: SSLContext,
     TSSLAcceptResult: SSLAcceptResult].}
@@ -168,6 +172,10 @@ when defined(ssl):
   ErrLoadBioStrings()
   OpenSSL_add_all_algorithms()
 
+  type SslContextExtraInternal = ref object
+    serverGetPskFunc: SslServerGetPskFunc
+    clientGetPskFunc: SslClientGetPskFunc
+
   proc raiseSSLError*(s = "") =
     ## Raises a new SSL error.
     if s != "":
@@ -180,6 +188,22 @@ when defined(ssl):
     var errStr = ErrErrorString(err, nil)
     raise newException(SSLError, $errStr)
 
+  proc getSslContextExtraDataIndex*(): cint =
+    ## Retrieves unique index for storing extra data in SSLContext.
+    return SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil)
+
+  proc setExtraData*(ctx: SSLContext, index: cint, data: pointer) =
+    ## Stores arbitrary data inside SSLContext. The unique `index`
+    ## should be retrieved using getSslContextExtraDataIndex.
+    if SslCtx(ctx).SSL_CTX_set_ex_data(index, data) == -1:
+      raiseSSLError()
+
+  proc getExtraData*(ctx: SSLContext, index: cint): pointer =
+    ## Retrieves arbitrary data stored inside SSLContext.
+    return SslCtx(ctx).SSL_CTX_get_ex_data(index)
+
+  let extraInternalIndex = getSslContextExtraDataIndex()
+
   # http://simplestcodings.blogspot.co.uk/2010/08/secure-server-client-using-openssl-in-c.html
   proc loadCertificates(ctx: SSL_CTX, certFile, keyFile: string) =
     if certFile != "" and not existsFile(certFile):
@@ -202,7 +226,7 @@ when defined(ssl):
         raiseSSLError("Verification of private key file failed.")
 
   proc newContext*(protVersion = protSSLv23, verifyMode = CVerifyPeer,
-                   certFile = "", keyFile = ""): SSLContext =
+                   certFile = "", keyFile = "", cipherList = "ALL"): SSLContext =
     ## Creates an SSL context.
     ##
     ## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1
@@ -229,7 +253,7 @@ when defined(ssl):
     of protTLSv1:
       newCTX = SSL_CTX_new(TLSv1_method())
 
-    if newCTX.SSLCTXSetCipherList("ALL") != 1:
+    if newCTX.SSLCTXSetCipherList(cipherList) != 1:
       raiseSSLError()
     case verifyMode
     of CVerifyPeer:
@@ -241,21 +265,73 @@ when defined(ssl):
 
     discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY)
     newCTX.loadCertificates(certFile, keyFile)
-    return SSLContext(newCTX)
 
-  proc getSslContextExtraDataIndex*(): cint =
-    ## Retrieves unique index for storing extra data in SSLContext.
-    return SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil)
+    result = SSLContext(newCTX)
+    # this is never freed, but SSLContext can't be freed anyway yet
+    let extraInternal = new(SslContextExtraInternal)
+    GC_ref(extraInternal)
+    result.setExtraData(extraInternalIndex, cast[pointer](extraInternal))
 
-  proc setExtraData*(ctx: SSLContext, index: cint, data: pointer) =
-    ## Stores arbitrary data inside SSLContext. The unique `index`
-    ## should be retrieved using getSslContextExtraDataIndex.
-    if SslCtx(ctx).SSL_CTX_set_ex_data(index, data) == -1:
+  proc getExtraInternal(ctx: SSLContext): SslContextExtraInternal =
+    return cast[SslContextExtraInternal](ctx.getExtraData(extraInternalIndex))
+
+  proc `pskIdentityHint=`*(ctx: SSLContext, hint: string) =
+    ## Sets the identity hint passed to server.
+    ##
+    ## Only used in PSK ciphersuites.
+    if SSLCTX(ctx).SSL_CTX_use_psk_identity_hint(hint) <= 0:
       raiseSSLError()
 
-  proc getExtraData*(ctx: SSLContext, index: cint): pointer =
-    ## Retrieves arbitrary data stored inside SSLContext.
-    return SslCtx(ctx).SSL_CTX_get_ex_data(index)
+  proc clientGetPskFunc*(ctx: SSLContext): SslClientGetPskFunc =
+    return ctx.getExtraInternal().clientGetPskFunc
+
+  proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar;
+    max_psk_len: cuint): cuint {.cdecl.} =
+    let ctx = SSLContext(ssl.SSL_get_SSL_CTX)
+    let hintString = if hint == nil: nil else: $hint
+    let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString)
+    if psk.len.cuint > max_psk_len:
+      return 0
+    if identityString.len.cuint >= max_identity_len:
+      return 0
+
+    copyMem(identity, identityString.cstring, pskString.len + 1) # with the last zero byte
+    copyMem(psk, pskString.cstring, pskString.len)
+
+    return pskString.len.cuint
+
+  proc `clientGetPskFunc=`*(ctx: SSLContext, fun: SslClientGetPskFunc) =
+    ## Sets function that returns the client identity and the PSK based on identity
+    ## hint from the server.
+    ##
+    ## Only used in PSK ciphersuites.
+    ctx.getExtraInternal().clientGetPskFunc = fun
+    SslCtx(ctx).SSL_CTX_set_psk_client_callback(if fun == nil: nil else: pskClientCallback)
+
+  proc serverGetPskFunc*(ctx: SSLContext): SslServerGetPskFunc =
+    return ctx.getExtraInternal().serverGetPskFunc
+
+  proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.} =
+    let ctx = SSLContext(ssl.SSL_get_SSL_CTX)
+    let pskString = (ctx.serverGetPskFunc)($identity)
+    if psk.len.cint > max_psk_len:
+      return 0
+    copyMem(psk, pskString.cstring, pskString.len)
+
+    return pskString.len.cuint
+
+  proc `serverGetPskFunc=`*(ctx: SSLContext, fun: SslServerGetPskFunc) =
+    ## Sets function that returns PSK based on the client identity.
+    ##
+    ## Only used in PSK ciphersuites.
+    ctx.getExtraInternal().serverGetPskFunc = fun
+    SslCtx(ctx).SSL_CTX_set_psk_server_callback(if fun == nil: nil
+                                                else: pskServerCallback)
+
+  proc getPskIdentity*(socket: Socket): string =
+    ## Gets the PSK identity provided by the client.
+    assert socket.isSSL
+    return $(socket.sslHandle.SSL_get_psk_identity)
 
   proc wrapSocket*(ctx: SSLContext, socket: Socket) =
     ## Wraps a socket in an SSL context. This function effectively turns
diff --git a/lib/wrappers/openssl.nim b/lib/wrappers/openssl.nim
index 9f24ca58d..7ede0f12c 100644
--- a/lib/wrappers/openssl.nim
+++ b/lib/wrappers/openssl.nim
@@ -197,6 +197,7 @@ proc TLSv1_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.}
 
 proc SSL_new*(context: SslCtx): SslPtr{.cdecl, dynlib: DLLSSLName, importc.}
 proc SSL_free*(ssl: SslPtr){.cdecl, dynlib: DLLSSLName, importc.}
+proc SSL_get_SSL_CTX*(ssl: SslPtr): SslCtx {.cdecl, dynlib: DLLSSLName, importc.}
 proc SSL_CTX_new*(meth: PSSL_METHOD): SslCtx{.cdecl,
     dynlib: DLLSSLName, importc.}
 proc SSL_CTX_load_verify_locations*(ctx: SslCtx, CAfile: cstring,
@@ -318,6 +319,25 @@ proc SSL_CTX_set_tlsext_servername_arg*(ctx: SslCtx, arg: pointer): int =
   ## Set the pointer to be used in the callback registered to ``SSL_CTX_set_tlsext_servername_callback``.
   result = SSL_CTX_ctrl(ctx, SSL_CTRL_SET_TLSEXT_SERVERNAME_ARG, 0, arg)
 
+type
+  PskClientCallback* = proc (ssl: SslPtr;
+    hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar;
+    max_psk_len: cuint): cuint {.cdecl.}
+
+  PskServerCallback* = proc (ssl: SslPtr;
+    identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.}
+
+proc SSL_CTX_set_psk_client_callback*(ctx: SslCtx; callback: PskClientCallback) {.cdecl, dynlib: DLLSSLName, importc.}
+  ## Set callback called when OpenSSL needs PSK (for client).
+
+proc SSL_CTX_set_psk_server_callback*(ctx: SslCtx; callback: PskServerCallback) {.cdecl, dynlib: DLLSSLName, importc.}
+  ## Set callback called when OpenSSL needs PSK (for server).
+
+proc SSL_CTX_use_psk_identity_hint*(ctx: SslCtx; hint: cstring): cint {.cdecl, dynlib: DLLSSLName, importc.}
+  ## Set PSK identity hint to use.
+
+proc SSL_get_psk_identity*(ssl: SslPtr): cstring {.cdecl, dynlib: DLLSSLName, importc.}
+  ## Get PSK identity.
 
 proc bioNew*(b: PBIO_METHOD): BIO{.cdecl, dynlib: DLLUtilName, importc: "BIO_new".}
 proc bioFreeAll*(b: BIO){.cdecl, dynlib: DLLUtilName, importc: "BIO_free_all".}