From 3ebf27ddd24c04e87e33bfb6f8617d81c9fc1946 Mon Sep 17 00:00:00 2001 From: Michał Zieliński Date: Thu, 22 Oct 2015 23:51:52 +0200 Subject: net.nim: support storing arbitrary data inside SSLContext --- examples/ssl/extradata.nim | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 examples/ssl/extradata.nim (limited to 'examples') diff --git a/examples/ssl/extradata.nim b/examples/ssl/extradata.nim new file mode 100644 index 000000000..f86dc57f2 --- /dev/null +++ b/examples/ssl/extradata.nim @@ -0,0 +1,14 @@ +# Stores extra data inside the SSL context. +import net + +# Our unique index for storing foos +let fooIndex = getSslContextExtraDataIndex() +# And another unique index for storing foos +let barIndex = getSslContextExtraDataIndex() +echo "got indexes ", fooIndex, " ", barIndex + +let ctx = newContext() +assert ctx.getExtraData(fooIndex) == nil +let foo: int = 5 +ctx.setExtraData(fooIndex, cast[pointer](foo)) +assert cast[int](ctx.getExtraData(fooIndex)) == foo -- cgit 1.4.1-2-gfad0 From ba61a8d00a65948fc0b3a1c100a20cca711fdd0f Mon Sep 17 00:00:00 2001 From: Michał Zieliński Date: Sat, 24 Oct 2015 08:53:18 +0200 Subject: net.nim: support for TLS-PSK ciphersuites --- examples/ssl/pskclient.nim | 15 +++++++ examples/ssl/pskserver.nim | 20 +++++++++ lib/pure/net.nim | 102 +++++++++++++++++++++++++++++++++++++++------ lib/wrappers/openssl.nim | 20 +++++++++ 4 files changed, 144 insertions(+), 13 deletions(-) create mode 100644 examples/ssl/pskclient.nim create mode 100644 examples/ssl/pskserver.nim (limited to 'examples') diff --git a/examples/ssl/pskclient.nim b/examples/ssl/pskclient.nim new file mode 100644 index 000000000..7c93bbb61 --- /dev/null +++ b/examples/ssl/pskclient.nim @@ -0,0 +1,15 @@ +# Create connection encrypted using preshared key (TLS-PSK). +import net + +static: assert defined(ssl) + +let sock = newSocket() +sock.connect("localhost", Port(8800)) + +proc clientFunc(identityHint: string): tuple[identity: string, psk: string] = + echo "identity hint ", identityHint.repr + return ("foo", "psk-of-foo") + +let context = newContext(cipherList="PSK-AES256-CBC-SHA") +context.clientGetPskFunc = clientFunc +context.wrapConnectedSocket(sock, handshakeAsClient) diff --git a/examples/ssl/pskserver.nim b/examples/ssl/pskserver.nim new file mode 100644 index 000000000..859eaa875 --- /dev/null +++ b/examples/ssl/pskserver.nim @@ -0,0 +1,20 @@ +# Accept connection encrypted using preshared key (TLS-PSK). +import net + +static: assert defined(ssl) + +let sock = newSocket() +sock.bindAddr(Port(8800)) +sock.listen() + +let context = newContext(cipherList="PSK-AES256-CBC-SHA") +context.pskIdentityHint = "hello" +context.serverGetPskFunc = proc(identity: string): string = "psk-of-" & identity + +while true: + var client = new(Socket) + sock.accept(client) + sock.setSockOpt(OptReuseAddr, true) + echo "accepted connection" + context.wrapConnectedSocket(client, handshakeAsServer) + echo "got connection with identity ", client.getPskIdentity() diff --git a/lib/pure/net.nim b/lib/pure/net.nim index 5498ebb7d..4bdfede42 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -38,6 +38,10 @@ when defined(ssl): SslHandshakeType* = enum handshakeAsClient, handshakeAsServer + SslClientGetPskFunc* = proc(hint: string): tuple[identity: string, psk: string] + + SslServerGetPskFunc* = proc(identity: string): string + {.deprecated: [ESSL: SSLError, TSSLCVerifyMode: SSLCVerifyMode, TSSLProtVersion: SSLProtVersion, PSSLContext: SSLContext, TSSLAcceptResult: SSLAcceptResult].} @@ -168,6 +172,10 @@ when defined(ssl): ErrLoadBioStrings() OpenSSL_add_all_algorithms() + type SslContextExtraInternal = ref object + serverGetPskFunc: SslServerGetPskFunc + clientGetPskFunc: SslClientGetPskFunc + proc raiseSSLError*(s = "") = ## Raises a new SSL error. if s != "": @@ -180,6 +188,22 @@ when defined(ssl): var errStr = ErrErrorString(err, nil) raise newException(SSLError, $errStr) + proc getSslContextExtraDataIndex*(): cint = + ## Retrieves unique index for storing extra data in SSLContext. + return SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil) + + proc setExtraData*(ctx: SSLContext, index: cint, data: pointer) = + ## Stores arbitrary data inside SSLContext. The unique `index` + ## should be retrieved using getSslContextExtraDataIndex. + if SslCtx(ctx).SSL_CTX_set_ex_data(index, data) == -1: + raiseSSLError() + + proc getExtraData*(ctx: SSLContext, index: cint): pointer = + ## Retrieves arbitrary data stored inside SSLContext. + return SslCtx(ctx).SSL_CTX_get_ex_data(index) + + let extraInternalIndex = getSslContextExtraDataIndex() + # http://simplestcodings.blogspot.co.uk/2010/08/secure-server-client-using-openssl-in-c.html proc loadCertificates(ctx: SSL_CTX, certFile, keyFile: string) = if certFile != "" and not existsFile(certFile): @@ -202,7 +226,7 @@ when defined(ssl): raiseSSLError("Verification of private key file failed.") proc newContext*(protVersion = protSSLv23, verifyMode = CVerifyPeer, - certFile = "", keyFile = ""): SSLContext = + certFile = "", keyFile = "", cipherList = "ALL"): SSLContext = ## Creates an SSL context. ## ## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1 @@ -229,7 +253,7 @@ when defined(ssl): of protTLSv1: newCTX = SSL_CTX_new(TLSv1_method()) - if newCTX.SSLCTXSetCipherList("ALL") != 1: + if newCTX.SSLCTXSetCipherList(cipherList) != 1: raiseSSLError() case verifyMode of CVerifyPeer: @@ -241,21 +265,73 @@ when defined(ssl): discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY) newCTX.loadCertificates(certFile, keyFile) - return SSLContext(newCTX) - proc getSslContextExtraDataIndex*(): cint = - ## Retrieves unique index for storing extra data in SSLContext. - return SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil) + result = SSLContext(newCTX) + # this is never freed, but SSLContext can't be freed anyway yet + let extraInternal = new(SslContextExtraInternal) + GC_ref(extraInternal) + result.setExtraData(extraInternalIndex, cast[pointer](extraInternal)) - proc setExtraData*(ctx: SSLContext, index: cint, data: pointer) = - ## Stores arbitrary data inside SSLContext. The unique `index` - ## should be retrieved using getSslContextExtraDataIndex. - if SslCtx(ctx).SSL_CTX_set_ex_data(index, data) == -1: + proc getExtraInternal(ctx: SSLContext): SslContextExtraInternal = + return cast[SslContextExtraInternal](ctx.getExtraData(extraInternalIndex)) + + proc `pskIdentityHint=`*(ctx: SSLContext, hint: string) = + ## Sets the identity hint passed to server. + ## + ## Only used in PSK ciphersuites. + if SSLCTX(ctx).SSL_CTX_use_psk_identity_hint(hint) <= 0: raiseSSLError() - proc getExtraData*(ctx: SSLContext, index: cint): pointer = - ## Retrieves arbitrary data stored inside SSLContext. - return SslCtx(ctx).SSL_CTX_get_ex_data(index) + proc clientGetPskFunc*(ctx: SSLContext): SslClientGetPskFunc = + return ctx.getExtraInternal().clientGetPskFunc + + proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar; + max_psk_len: cuint): cuint {.cdecl.} = + let ctx = SSLContext(ssl.SSL_get_SSL_CTX) + let hintString = if hint == nil: nil else: $hint + let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString) + if psk.len.cuint > max_psk_len: + return 0 + if identityString.len.cuint >= max_identity_len: + return 0 + + copyMem(identity, identityString.cstring, pskString.len + 1) # with the last zero byte + copyMem(psk, pskString.cstring, pskString.len) + + return pskString.len.cuint + + proc `clientGetPskFunc=`*(ctx: SSLContext, fun: SslClientGetPskFunc) = + ## Sets function that returns the client identity and the PSK based on identity + ## hint from the server. + ## + ## Only used in PSK ciphersuites. + ctx.getExtraInternal().clientGetPskFunc = fun + SslCtx(ctx).SSL_CTX_set_psk_client_callback(if fun == nil: nil else: pskClientCallback) + + proc serverGetPskFunc*(ctx: SSLContext): SslServerGetPskFunc = + return ctx.getExtraInternal().serverGetPskFunc + + proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.} = + let ctx = SSLContext(ssl.SSL_get_SSL_CTX) + let pskString = (ctx.serverGetPskFunc)($identity) + if psk.len.cint > max_psk_len: + return 0 + copyMem(psk, pskString.cstring, pskString.len) + + return pskString.len.cuint + + proc `serverGetPskFunc=`*(ctx: SSLContext, fun: SslServerGetPskFunc) = + ## Sets function that returns PSK based on the client identity. + ## + ## Only used in PSK ciphersuites. + ctx.getExtraInternal().serverGetPskFunc = fun + SslCtx(ctx).SSL_CTX_set_psk_server_callback(if fun == nil: nil + else: pskServerCallback) + + proc getPskIdentity*(socket: Socket): string = + ## Gets the PSK identity provided by the client. + assert socket.isSSL + return $(socket.sslHandle.SSL_get_psk_identity) proc wrapSocket*(ctx: SSLContext, socket: Socket) = ## Wraps a socket in an SSL context. This function effectively turns diff --git a/lib/wrappers/openssl.nim b/lib/wrappers/openssl.nim index 9f24ca58d..7ede0f12c 100644 --- a/lib/wrappers/openssl.nim +++ b/lib/wrappers/openssl.nim @@ -197,6 +197,7 @@ proc TLSv1_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.} proc SSL_new*(context: SslCtx): SslPtr{.cdecl, dynlib: DLLSSLName, importc.} proc SSL_free*(ssl: SslPtr){.cdecl, dynlib: DLLSSLName, importc.} +proc SSL_get_SSL_CTX*(ssl: SslPtr): SslCtx {.cdecl, dynlib: DLLSSLName, importc.} proc SSL_CTX_new*(meth: PSSL_METHOD): SslCtx{.cdecl, dynlib: DLLSSLName, importc.} proc SSL_CTX_load_verify_locations*(ctx: SslCtx, CAfile: cstring, @@ -318,6 +319,25 @@ proc SSL_CTX_set_tlsext_servername_arg*(ctx: SslCtx, arg: pointer): int = ## Set the pointer to be used in the callback registered to ``SSL_CTX_set_tlsext_servername_callback``. result = SSL_CTX_ctrl(ctx, SSL_CTRL_SET_TLSEXT_SERVERNAME_ARG, 0, arg) +type + PskClientCallback* = proc (ssl: SslPtr; + hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar; + max_psk_len: cuint): cuint {.cdecl.} + + PskServerCallback* = proc (ssl: SslPtr; + identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.} + +proc SSL_CTX_set_psk_client_callback*(ctx: SslCtx; callback: PskClientCallback) {.cdecl, dynlib: DLLSSLName, importc.} + ## Set callback called when OpenSSL needs PSK (for client). + +proc SSL_CTX_set_psk_server_callback*(ctx: SslCtx; callback: PskServerCallback) {.cdecl, dynlib: DLLSSLName, importc.} + ## Set callback called when OpenSSL needs PSK (for server). + +proc SSL_CTX_use_psk_identity_hint*(ctx: SslCtx; hint: cstring): cint {.cdecl, dynlib: DLLSSLName, importc.} + ## Set PSK identity hint to use. + +proc SSL_get_psk_identity*(ssl: SslPtr): cstring {.cdecl, dynlib: DLLSSLName, importc.} + ## Get PSK identity. proc bioNew*(b: PBIO_METHOD): BIO{.cdecl, dynlib: DLLUtilName, importc: "BIO_new".} proc bioFreeAll*(b: BIO){.cdecl, dynlib: DLLUtilName, importc: "BIO_free_all".} -- cgit 1.4.1-2-gfad0 From 3ecf33fa6acc87b204ac0240b597d5d91d0a78f7 Mon Sep 17 00:00:00 2001 From: Michał Zieliński Date: Sat, 24 Oct 2015 22:48:33 +0200 Subject: net.nim: destroyContext for destroying SSLContext --- examples/ssl/pskclient.nim | 1 + lib/pure/net.nim | 8 +++++++- 2 files changed, 8 insertions(+), 1 deletion(-) (limited to 'examples') diff --git a/examples/ssl/pskclient.nim b/examples/ssl/pskclient.nim index 7c93bbb61..c83f27fbc 100644 --- a/examples/ssl/pskclient.nim +++ b/examples/ssl/pskclient.nim @@ -13,3 +13,4 @@ proc clientFunc(identityHint: string): tuple[identity: string, psk: string] = let context = newContext(cipherList="PSK-AES256-CBC-SHA") context.clientGetPskFunc = clientFunc context.wrapConnectedSocket(sock, handshakeAsClient) +context.destroyContext() diff --git a/lib/pure/net.nim b/lib/pure/net.nim index 4bdfede42..368ff6e87 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -267,7 +267,6 @@ when defined(ssl): newCTX.loadCertificates(certFile, keyFile) result = SSLContext(newCTX) - # this is never freed, but SSLContext can't be freed anyway yet let extraInternal = new(SslContextExtraInternal) GC_ref(extraInternal) result.setExtraData(extraInternalIndex, cast[pointer](extraInternal)) @@ -275,6 +274,13 @@ when defined(ssl): proc getExtraInternal(ctx: SSLContext): SslContextExtraInternal = return cast[SslContextExtraInternal](ctx.getExtraData(extraInternalIndex)) + proc destroyContext*(ctx: SSLContext) = + ## Free memory referenced by SSLContext. + let extraInternal = ctx.getExtraInternal() + if extraInternal != nil: + GC_unref(extraInternal) + SSLCTX(ctx).SSL_CTX_free() + proc `pskIdentityHint=`*(ctx: SSLContext, hint: string) = ## Sets the identity hint passed to server. ## -- cgit 1.4.1-2-gfad0 From da308be2d7b25683a3073187f699f95f81ac149e Mon Sep 17 00:00:00 2001 From: Michał Zieliński Date: Wed, 28 Oct 2015 19:55:04 +0100 Subject: net.nim: add support for Unix sockets --- examples/unix_socket/client.nim | 6 ++++++ examples/unix_socket/server.nim | 14 ++++++++++++++ lib/posix/posix.nim | 13 +++++++++++++ lib/pure/nativesockets.nim | 2 +- lib/pure/net.nim | 24 +++++++++++++++++++++++- 5 files changed, 57 insertions(+), 2 deletions(-) create mode 100644 examples/unix_socket/client.nim create mode 100644 examples/unix_socket/server.nim (limited to 'examples') diff --git a/examples/unix_socket/client.nim b/examples/unix_socket/client.nim new file mode 100644 index 000000000..f4283d64d --- /dev/null +++ b/examples/unix_socket/client.nim @@ -0,0 +1,6 @@ +import net + +let sock = newSocket(AF_UNIX, SOCK_STREAM, IPPROTO_IP) + +sock.connectUnix("sock") +sock.send("hello\n") diff --git a/examples/unix_socket/server.nim b/examples/unix_socket/server.nim new file mode 100644 index 000000000..e798bbb48 --- /dev/null +++ b/examples/unix_socket/server.nim @@ -0,0 +1,14 @@ +import net + +let sock = newSocket(AF_UNIX, SOCK_STREAM, IPPROTO_IP) +sock.bindUnix("sock") +sock.listen() + +while true: + var client = new(Socket) + sock.accept(client) + var output = "" + output.setLen 32 + client.readLine(output) + echo "got ", output + client.close() diff --git a/lib/posix/posix.nim b/lib/posix/posix.nim index 5f1dfcfcd..cae469fe8 100644 --- a/lib/posix/posix.nim +++ b/lib/posix/posix.nim @@ -439,6 +439,14 @@ when hasSpawnH: Tposix_spawn_file_actions* {.importc: "posix_spawn_file_actions_t", header: "", final, pure.} = object +when defined(linux): + # from sys/un.h + const Sockaddr_un_path_length* = 108 +else: + # according to http://pubs.opengroup.org/onlinepubs/009604499/basedefs/sys/un.h.html + # this is >=92 + const Sockaddr_un_path_length* = 92 + type Socklen* {.importc: "socklen_t", header: "".} = cuint TSa_Family* {.importc: "sa_family_t", header: "".} = cint @@ -448,6 +456,11 @@ type sa_family*: TSa_Family ## Address family. sa_data*: array [0..255, char] ## Socket address (variable-length data). + Sockaddr_un* {.importc: "struct sockaddr_un", header: "", + pure, final.} = object ## struct sockaddr_un + sun_family*: TSa_Family ## Address family. + sun_path*: array [0..Sockaddr_un_path_length-1, char] ## Socket path + Sockaddr_storage* {.importc: "struct sockaddr_storage", header: "", pure, final.} = object ## struct sockaddr_storage diff --git a/lib/pure/nativesockets.nim b/lib/pure/nativesockets.nim index c9e067a3e..e75555115 100644 --- a/lib/pure/nativesockets.nim +++ b/lib/pure/nativesockets.nim @@ -27,7 +27,7 @@ else: import posix export fcntl, F_GETFL, O_NONBLOCK, F_SETFL, EAGAIN, EWOULDBLOCK, MSG_NOSIGNAL, EINTR, EINPROGRESS, ECONNRESET, EPIPE, ENETRESET - export Sockaddr_storage + export Sockaddr_storage, Sockaddr_un, Sockaddr_un_path_length export SocketHandle, Sockaddr_in, Addrinfo, INADDR_ANY, SockAddr, SockLen, Sockaddr_in6, diff --git a/lib/pure/net.nim b/lib/pure/net.nim index d1016011e..8bc08c433 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -10,8 +10,9 @@ ## This module implements a high-level cross-platform sockets interface. {.deadCodeElim: on.} -import nativesockets, os, strutils, unsigned, parseutils, times +import nativesockets, os, strutils, parseutils, times export Port, `$`, `==` +export Domain, SockType, Protocol const useWinVersion = defined(Windows) or defined(nimdoc) @@ -582,6 +583,27 @@ proc connect*(socket: Socket, address: string, let ret = SSLConnect(socket.sslHandle) socketError(socket, ret) +when defined(posix): + proc makeUnixAddr(path: string): Sockaddr_un = + result.sun_family = AF_UNIX.toInt + if path.len >= Sockaddr_un_path_length: + raise newException(ValueError, "socket path too long") + copyMem(addr result.sun_path, path.cstring, path.len + 1) + + proc connectUnix*(socket: Socket, path: string) = + ## Connects to Unix socket on `path`. + var socketAddr = makeUnixAddr(path) + if socket.fd.connect(cast[ptr SockAddr](addr socketAddr), + sizeof(socketAddr).Socklen) != 0'i32: + raiseOSError(osLastError()) + + proc bindUnix*(socket: Socket, path: string) = + ## Binds Unix socket to `path`. + var socketAddr = makeUnixAddr(path) + if socket.fd.bindAddr(cast[ptr SockAddr](addr socketAddr), + sizeof(socketAddr).Socklen) != 0'i32: + raiseOSError(osLastError()) + when defined(ssl): proc handshake*(socket: Socket): bool {.tags: [ReadIOEffect, WriteIOEffect].} = ## This proc needs to be called on a socket after it connects. This is -- cgit 1.4.1-2-gfad0 From 5390c25b60e79f87aca339f7428575066b0b2d08 Mon Sep 17 00:00:00 2001 From: Dominik Picheta Date: Fri, 3 Jun 2016 13:22:18 +0100 Subject: Modified #3472 to make its API more idiomatic. --- examples/ssl/extradata.nim | 26 +++++++++++---- lib/pure/net.nim | 80 ++++++++++++++++++++++++++++++---------------- 2 files changed, 71 insertions(+), 35 deletions(-) (limited to 'examples') diff --git a/examples/ssl/extradata.nim b/examples/ssl/extradata.nim index f86dc57f2..1e3b89b02 100644 --- a/examples/ssl/extradata.nim +++ b/examples/ssl/extradata.nim @@ -1,14 +1,26 @@ # Stores extra data inside the SSL context. import net +let ctx = newContext() + # Our unique index for storing foos -let fooIndex = getSslContextExtraDataIndex() +let fooIndex = ctx.getExtraDataIndex() # And another unique index for storing foos -let barIndex = getSslContextExtraDataIndex() +let barIndex = ctx.getExtraDataIndex() echo "got indexes ", fooIndex, " ", barIndex -let ctx = newContext() -assert ctx.getExtraData(fooIndex) == nil -let foo: int = 5 -ctx.setExtraData(fooIndex, cast[pointer](foo)) -assert cast[int](ctx.getExtraData(fooIndex)) == foo +try: + discard ctx.getExtraData(fooIndex) + assert false +except IndexError: + echo("Success") + +type + FooRef = ref object of RootRef + foo: int + +let foo = FooRef(foo: 5) +ctx.setExtraData(fooIndex, foo) +doAssert ctx.getExtraData(fooIndex).FooRef == foo + +ctx.destroyContext() diff --git a/lib/pure/net.nim b/lib/pure/net.nim index 85d4245b2..d6ec31481 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -66,7 +66,7 @@ ## {.deadCodeElim: on.} -import nativesockets, os, strutils, parseutils, times +import nativesockets, os, strutils, parseutils, times, sets export Port, `$`, `==` export Domain, SockType, Protocol @@ -88,7 +88,10 @@ when defineSsl: SslProtVersion* = enum protSSLv2, protSSLv3, protTLSv1, protSSLv23 - SslContext* = distinct SslCtx + SslContext* = ref object + context: SslCtx + extraInternalIndex: int + referencedData: HashSet[int] SslAcceptResult* = enum AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess @@ -229,9 +232,10 @@ when defineSsl: ErrLoadBioStrings() OpenSSL_add_all_algorithms() - type SslContextExtraInternal = ref object - serverGetPskFunc: SslServerGetPskFunc - clientGetPskFunc: SslClientGetPskFunc + type + SslContextExtraInternal = ref object of RootRef + serverGetPskFunc: SslServerGetPskFunc + clientGetPskFunc: SslClientGetPskFunc proc raiseSSLError*(s = "") = ## Raises a new SSL error. @@ -245,21 +249,33 @@ when defineSsl: var errStr = ErrErrorString(err, nil) raise newException(SSLError, $errStr) - proc getSslContextExtraDataIndex*(): cint = + proc getExtraDataIndex*(ctx: SSLContext): int = ## Retrieves unique index for storing extra data in SSLContext. - return SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil) + result = SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil).int + if result < 0: + raiseSSLError() + + proc getExtraData*(ctx: SSLContext, index: int): RootRef = + ## Retrieves arbitrary data stored inside SSLContext. + if index notin ctx.referencedData: + raise newException(IndexError, "No data with that index.") + let res = ctx.context.SSL_CTX_get_ex_data(index.cint) + if cast[int](res) == 0: + raiseSSLError() + return cast[RootRef](res) - proc setExtraData*(ctx: SSLContext, index: cint, data: pointer) = + proc setExtraData*(ctx: SSLContext, index: int, data: RootRef) = ## Stores arbitrary data inside SSLContext. The unique `index` ## should be retrieved using getSslContextExtraDataIndex. - if SslCtx(ctx).SSL_CTX_set_ex_data(index, data) == -1: - raiseSSLError() + if index in ctx.referencedData: + GC_unref(getExtraData(ctx, index)) - proc getExtraData*(ctx: SSLContext, index: cint): pointer = - ## Retrieves arbitrary data stored inside SSLContext. - return SslCtx(ctx).SSL_CTX_get_ex_data(index) + if ctx.context.SSL_CTX_set_ex_data(index.cint, cast[pointer](data)) == -1: + raiseSSLError() - let extraInternalIndex = getSslContextExtraDataIndex() + if index notin ctx.referencedData: + ctx.referencedData.incl(index) + GC_ref(data) # http://simplestcodings.blogspot.co.uk/2010/08/secure-server-client-using-openssl-in-c.html proc loadCertificates(ctx: SSL_CTX, certFile, keyFile: string) = @@ -323,26 +339,33 @@ when defineSsl: discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY) newCTX.loadCertificates(certFile, keyFile) - result = SSLContext(newCTX) + result = SSLContext(context: newCTX, extraInternalIndex: 0, + referencedData: initSet[int]()) + result.extraInternalIndex = getExtraDataIndex(result) + # The PSK callback functions assume the internal index is 0. + assert result.extraInternalIndex == 0 + let extraInternal = new(SslContextExtraInternal) - GC_ref(extraInternal) - result.setExtraData(extraInternalIndex, cast[pointer](extraInternal)) + result.setExtraData(result.extraInternalIndex, extraInternal) proc getExtraInternal(ctx: SSLContext): SslContextExtraInternal = - return cast[SslContextExtraInternal](ctx.getExtraData(extraInternalIndex)) + return SslContextExtraInternal(ctx.getExtraData(ctx.extraInternalIndex)) proc destroyContext*(ctx: SSLContext) = ## Free memory referenced by SSLContext. - let extraInternal = ctx.getExtraInternal() - if extraInternal != nil: - GC_unref(extraInternal) - SSLCTX(ctx).SSL_CTX_free() + + # We assume here that OpenSSL's internal indexes increase by 1 each time. + # That means we can assume that the next internal index is the length of + # extra data indexes. + for i in ctx.referencedData: + GC_unref(getExtraData(ctx, i).RootRef) + ctx.context.SSL_CTX_free() proc `pskIdentityHint=`*(ctx: SSLContext, hint: string) = ## Sets the identity hint passed to server. ## ## Only used in PSK ciphersuites. - if SSLCTX(ctx).SSL_CTX_use_psk_identity_hint(hint) <= 0: + if ctx.context.SSL_CTX_use_psk_identity_hint(hint) <= 0: raiseSSLError() proc clientGetPskFunc*(ctx: SSLContext): SslClientGetPskFunc = @@ -350,7 +373,7 @@ when defineSsl: proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar; max_psk_len: cuint): cuint {.cdecl.} = - let ctx = SSLContext(ssl.SSL_get_SSL_CTX) + let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0) let hintString = if hint == nil: nil else: $hint let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString) if psk.len.cuint > max_psk_len: @@ -369,13 +392,14 @@ when defineSsl: ## ## Only used in PSK ciphersuites. ctx.getExtraInternal().clientGetPskFunc = fun - SslCtx(ctx).SSL_CTX_set_psk_client_callback(if fun == nil: nil else: pskClientCallback) + ctx.context.SSL_CTX_set_psk_client_callback( + if fun == nil: nil else: pskClientCallback) proc serverGetPskFunc*(ctx: SSLContext): SslServerGetPskFunc = return ctx.getExtraInternal().serverGetPskFunc proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.} = - let ctx = SSLContext(ssl.SSL_get_SSL_CTX) + let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0) let pskString = (ctx.serverGetPskFunc)($identity) if psk.len.cint > max_psk_len: return 0 @@ -388,7 +412,7 @@ when defineSsl: ## ## Only used in PSK ciphersuites. ctx.getExtraInternal().serverGetPskFunc = fun - SslCtx(ctx).SSL_CTX_set_psk_server_callback(if fun == nil: nil + ctx.context.SSL_CTX_set_psk_server_callback(if fun == nil: nil else: pskServerCallback) proc getPskIdentity*(socket: Socket): string = @@ -409,7 +433,7 @@ when defineSsl: assert (not socket.isSSL) socket.isSSL = true socket.sslContext = ctx - socket.sslHandle = SSLNew(SSLCTX(socket.sslContext)) + socket.sslHandle = SSLNew(socket.sslContext.context) socket.sslNoHandshake = false socket.sslHasPeekChar = false if socket.sslHandle == nil: -- cgit 1.4.1-2-gfad0