diff options
-rw-r--r-- | examples/ssl/extradata.nim | 14 | ||||
-rw-r--r-- | examples/ssl/pskclient.nim | 16 | ||||
-rw-r--r-- | examples/ssl/pskserver.nim | 20 | ||||
-rw-r--r-- | lib/pure/net.nim | 102 | ||||
-rw-r--r-- | lib/wrappers/openssl.nim | 24 |
5 files changed, 173 insertions, 3 deletions
diff --git a/examples/ssl/extradata.nim b/examples/ssl/extradata.nim new file mode 100644 index 000000000..f86dc57f2 --- /dev/null +++ b/examples/ssl/extradata.nim @@ -0,0 +1,14 @@ +# Stores extra data inside the SSL context. +import net + +# Our unique index for storing foos +let fooIndex = getSslContextExtraDataIndex() +# And another unique index for storing foos +let barIndex = getSslContextExtraDataIndex() +echo "got indexes ", fooIndex, " ", barIndex + +let ctx = newContext() +assert ctx.getExtraData(fooIndex) == nil +let foo: int = 5 +ctx.setExtraData(fooIndex, cast[pointer](foo)) +assert cast[int](ctx.getExtraData(fooIndex)) == foo diff --git a/examples/ssl/pskclient.nim b/examples/ssl/pskclient.nim new file mode 100644 index 000000000..c83f27fbc --- /dev/null +++ b/examples/ssl/pskclient.nim @@ -0,0 +1,16 @@ +# Create connection encrypted using preshared key (TLS-PSK). +import net + +static: assert defined(ssl) + +let sock = newSocket() +sock.connect("localhost", Port(8800)) + +proc clientFunc(identityHint: string): tuple[identity: string, psk: string] = + echo "identity hint ", identityHint.repr + return ("foo", "psk-of-foo") + +let context = newContext(cipherList="PSK-AES256-CBC-SHA") +context.clientGetPskFunc = clientFunc +context.wrapConnectedSocket(sock, handshakeAsClient) +context.destroyContext() diff --git a/examples/ssl/pskserver.nim b/examples/ssl/pskserver.nim new file mode 100644 index 000000000..859eaa875 --- /dev/null +++ b/examples/ssl/pskserver.nim @@ -0,0 +1,20 @@ +# Accept connection encrypted using preshared key (TLS-PSK). +import net + +static: assert defined(ssl) + +let sock = newSocket() +sock.bindAddr(Port(8800)) +sock.listen() + +let context = newContext(cipherList="PSK-AES256-CBC-SHA") +context.pskIdentityHint = "hello" +context.serverGetPskFunc = proc(identity: string): string = "psk-of-" & identity + +while true: + var client = new(Socket) + sock.accept(client) + sock.setSockOpt(OptReuseAddr, true) + echo "accepted connection" + context.wrapConnectedSocket(client, handshakeAsServer) + echo "got connection with identity ", client.getPskIdentity() diff --git a/lib/pure/net.nim b/lib/pure/net.nim index cb8cea720..85d4245b2 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -96,6 +96,10 @@ when defineSsl: SslHandshakeType* = enum handshakeAsClient, handshakeAsServer + SslClientGetPskFunc* = proc(hint: string): tuple[identity: string, psk: string] + + SslServerGetPskFunc* = proc(identity: string): string + {.deprecated: [ESSL: SSLError, TSSLCVerifyMode: SSLCVerifyMode, TSSLProtVersion: SSLProtVersion, PSSLContext: SSLContext, TSSLAcceptResult: SSLAcceptResult].} @@ -225,6 +229,10 @@ when defineSsl: ErrLoadBioStrings() OpenSSL_add_all_algorithms() + type SslContextExtraInternal = ref object + serverGetPskFunc: SslServerGetPskFunc + clientGetPskFunc: SslClientGetPskFunc + proc raiseSSLError*(s = "") = ## Raises a new SSL error. if s != "": @@ -237,6 +245,22 @@ when defineSsl: var errStr = ErrErrorString(err, nil) raise newException(SSLError, $errStr) + proc getSslContextExtraDataIndex*(): cint = + ## Retrieves unique index for storing extra data in SSLContext. + return SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil) + + proc setExtraData*(ctx: SSLContext, index: cint, data: pointer) = + ## Stores arbitrary data inside SSLContext. The unique `index` + ## should be retrieved using getSslContextExtraDataIndex. + if SslCtx(ctx).SSL_CTX_set_ex_data(index, data) == -1: + raiseSSLError() + + proc getExtraData*(ctx: SSLContext, index: cint): pointer = + ## Retrieves arbitrary data stored inside SSLContext. + return SslCtx(ctx).SSL_CTX_get_ex_data(index) + + let extraInternalIndex = getSslContextExtraDataIndex() + # http://simplestcodings.blogspot.co.uk/2010/08/secure-server-client-using-openssl-in-c.html proc loadCertificates(ctx: SSL_CTX, certFile, keyFile: string) = if certFile != "" and not existsFile(certFile): @@ -259,7 +283,7 @@ when defineSsl: raiseSSLError("Verification of private key file failed.") proc newContext*(protVersion = protSSLv23, verifyMode = CVerifyPeer, - certFile = "", keyFile = ""): SSLContext = + certFile = "", keyFile = "", cipherList = "ALL"): SSLContext = ## Creates an SSL context. ## ## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1 @@ -286,7 +310,7 @@ when defineSsl: of protTLSv1: newCTX = SSL_CTX_new(TLSv1_method()) - if newCTX.SSLCTXSetCipherList("ALL") != 1: + if newCTX.SSLCTXSetCipherList(cipherList) != 1: raiseSSLError() case verifyMode of CVerifyPeer: @@ -298,7 +322,79 @@ when defineSsl: discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY) newCTX.loadCertificates(certFile, keyFile) - return SSLContext(newCTX) + + result = SSLContext(newCTX) + let extraInternal = new(SslContextExtraInternal) + GC_ref(extraInternal) + result.setExtraData(extraInternalIndex, cast[pointer](extraInternal)) + + proc getExtraInternal(ctx: SSLContext): SslContextExtraInternal = + return cast[SslContextExtraInternal](ctx.getExtraData(extraInternalIndex)) + + proc destroyContext*(ctx: SSLContext) = + ## Free memory referenced by SSLContext. + let extraInternal = ctx.getExtraInternal() + if extraInternal != nil: + GC_unref(extraInternal) + SSLCTX(ctx).SSL_CTX_free() + + proc `pskIdentityHint=`*(ctx: SSLContext, hint: string) = + ## Sets the identity hint passed to server. + ## + ## Only used in PSK ciphersuites. + if SSLCTX(ctx).SSL_CTX_use_psk_identity_hint(hint) <= 0: + raiseSSLError() + + proc clientGetPskFunc*(ctx: SSLContext): SslClientGetPskFunc = + return ctx.getExtraInternal().clientGetPskFunc + + proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar; + max_psk_len: cuint): cuint {.cdecl.} = + let ctx = SSLContext(ssl.SSL_get_SSL_CTX) + let hintString = if hint == nil: nil else: $hint + let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString) + if psk.len.cuint > max_psk_len: + return 0 + if identityString.len.cuint >= max_identity_len: + return 0 + + copyMem(identity, identityString.cstring, pskString.len + 1) # with the last zero byte + copyMem(psk, pskString.cstring, pskString.len) + + return pskString.len.cuint + + proc `clientGetPskFunc=`*(ctx: SSLContext, fun: SslClientGetPskFunc) = + ## Sets function that returns the client identity and the PSK based on identity + ## hint from the server. + ## + ## Only used in PSK ciphersuites. + ctx.getExtraInternal().clientGetPskFunc = fun + SslCtx(ctx).SSL_CTX_set_psk_client_callback(if fun == nil: nil else: pskClientCallback) + + proc serverGetPskFunc*(ctx: SSLContext): SslServerGetPskFunc = + return ctx.getExtraInternal().serverGetPskFunc + + proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.} = + let ctx = SSLContext(ssl.SSL_get_SSL_CTX) + let pskString = (ctx.serverGetPskFunc)($identity) + if psk.len.cint > max_psk_len: + return 0 + copyMem(psk, pskString.cstring, pskString.len) + + return pskString.len.cuint + + proc `serverGetPskFunc=`*(ctx: SSLContext, fun: SslServerGetPskFunc) = + ## Sets function that returns PSK based on the client identity. + ## + ## Only used in PSK ciphersuites. + ctx.getExtraInternal().serverGetPskFunc = fun + SslCtx(ctx).SSL_CTX_set_psk_server_callback(if fun == nil: nil + else: pskServerCallback) + + proc getPskIdentity*(socket: Socket): string = + ## Gets the PSK identity provided by the client. + assert socket.isSSL + return $(socket.sslHandle.SSL_get_psk_identity) proc wrapSocket*(ctx: SSLContext, socket: Socket) = ## Wraps a socket in an SSL context. This function effectively turns diff --git a/lib/wrappers/openssl.nim b/lib/wrappers/openssl.nim index 635d52a64..9dad7e489 100644 --- a/lib/wrappers/openssl.nim +++ b/lib/wrappers/openssl.nim @@ -197,6 +197,7 @@ proc TLSv1_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.} proc SSL_new*(context: SslCtx): SslPtr{.cdecl, dynlib: DLLSSLName, importc.} proc SSL_free*(ssl: SslPtr){.cdecl, dynlib: DLLSSLName, importc.} +proc SSL_get_SSL_CTX*(ssl: SslPtr): SslCtx {.cdecl, dynlib: DLLSSLName, importc.} proc SSL_CTX_new*(meth: PSSL_METHOD): SslCtx{.cdecl, dynlib: DLLSSLName, importc.} proc SSL_CTX_load_verify_locations*(ctx: SslCtx, CAfile: cstring, @@ -216,6 +217,10 @@ proc SSL_CTX_use_PrivateKey_file*(ctx: SslCtx, proc SSL_CTX_check_private_key*(ctx: SslCtx): cInt{.cdecl, dynlib: DLLSSLName, importc.} +proc SSL_CTX_get_ex_new_index*(argl: clong, argp: pointer, new_func: pointer, dup_func: pointer, free_func: pointer): cint {.cdecl, dynlib: DLLSSLName, importc.} +proc SSL_CTX_set_ex_data*(ssl: SslCtx, idx: cint, arg: pointer): cint {.cdecl, dynlib: DLLSSLName, importc.} +proc SSL_CTX_get_ex_data*(ssl: SslCtx, idx: cint): pointer {.cdecl, dynlib: DLLSSLName, importc.} + proc SSL_set_fd*(ssl: SslPtr, fd: SocketHandle): cint{.cdecl, dynlib: DLLSSLName, importc.} proc SSL_shutdown*(ssl: SslPtr): cInt{.cdecl, dynlib: DLLSSLName, importc.} @@ -314,6 +319,25 @@ proc SSL_CTX_set_tlsext_servername_arg*(ctx: SslCtx, arg: pointer): int = ## Set the pointer to be used in the callback registered to ``SSL_CTX_set_tlsext_servername_callback``. result = SSL_CTX_ctrl(ctx, SSL_CTRL_SET_TLSEXT_SERVERNAME_ARG, 0, arg) +type + PskClientCallback* = proc (ssl: SslPtr; + hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar; + max_psk_len: cuint): cuint {.cdecl.} + + PskServerCallback* = proc (ssl: SslPtr; + identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.} + +proc SSL_CTX_set_psk_client_callback*(ctx: SslCtx; callback: PskClientCallback) {.cdecl, dynlib: DLLSSLName, importc.} + ## Set callback called when OpenSSL needs PSK (for client). + +proc SSL_CTX_set_psk_server_callback*(ctx: SslCtx; callback: PskServerCallback) {.cdecl, dynlib: DLLSSLName, importc.} + ## Set callback called when OpenSSL needs PSK (for server). + +proc SSL_CTX_use_psk_identity_hint*(ctx: SslCtx; hint: cstring): cint {.cdecl, dynlib: DLLSSLName, importc.} + ## Set PSK identity hint to use. + +proc SSL_get_psk_identity*(ssl: SslPtr): cstring {.cdecl, dynlib: DLLSSLName, importc.} + ## Get PSK identity. proc bioNew*(b: PBIO_METHOD): BIO{.cdecl, dynlib: DLLUtilName, importc: "BIO_new".} proc bioFreeAll*(b: BIO){.cdecl, dynlib: DLLUtilName, importc: "BIO_free_all".} |