summary refs log tree commit diff stats
path: root/lib/pure/net.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/net.nim')
-rw-r--r--lib/pure/net.nim33
1 files changed, 10 insertions, 23 deletions
diff --git a/lib/pure/net.nim b/lib/pure/net.nim
index 863a8a6f4..5e10f2291 100644
--- a/lib/pure/net.nim
+++ b/lib/pure/net.nim
@@ -90,8 +90,8 @@ when defineSsl:
 
     SslContext* = ref object
       context*: SslCtx
-      extraInternalIndex: int
       referencedData: HashSet[int]
+      extraInternal: SslContextExtraInternal
 
     SslAcceptResult* = enum
       AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess
@@ -103,6 +103,10 @@ when defineSsl:
 
     SslServerGetPskFunc* = proc(identity: string): string
 
+    SslContextExtraInternal = ref object of RootRef
+      serverGetPskFunc: SslServerGetPskFunc
+      clientGetPskFunc: SslClientGetPskFunc
+
   {.deprecated: [ESSL: SSLError, TSSLCVerifyMode: SSLCVerifyMode,
     TSSLProtVersion: SSLProtVersion, PSSLContext: SSLContext,
     TSSLAcceptResult: SSLAcceptResult].}
@@ -240,11 +244,6 @@ when defineSsl:
   ErrLoadBioStrings()
   OpenSSL_add_all_algorithms()
 
-  type
-    SslContextExtraInternal = ref object of RootRef
-      serverGetPskFunc: SslServerGetPskFunc
-      clientGetPskFunc: SslClientGetPskFunc
-
   proc raiseSSLError*(s = "") =
     ## Raises a new SSL error.
     if s != "":
@@ -257,12 +256,6 @@ when defineSsl:
     var errStr = ErrErrorString(err, nil)
     raise newException(SSLError, $errStr)
 
-  proc getExtraDataIndex*(ctx: SSLContext): int =
-    ## Retrieves unique index for storing extra data in SSLContext.
-    result = SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil).int
-    if result < 0:
-      raiseSSLError()
-
   proc getExtraData*(ctx: SSLContext, index: int): RootRef =
     ## Retrieves arbitrary data stored inside SSLContext.
     if index notin ctx.referencedData:
@@ -347,15 +340,11 @@ when defineSsl:
     discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY)
     newCTX.loadCertificates(certFile, keyFile)
 
-    result = SSLContext(context: newCTX, extraInternalIndex: 0,
-        referencedData: initSet[int]())
-    result.extraInternalIndex = getExtraDataIndex(result)
-
-    let extraInternal = new(SslContextExtraInternal)
-    result.setExtraData(result.extraInternalIndex, extraInternal)
+    result = SSLContext(context: newCTX, referencedData: initSet[int](),
+      extraInternal: new(SslContextExtraInternal))
 
   proc getExtraInternal(ctx: SSLContext): SslContextExtraInternal =
-    return SslContextExtraInternal(ctx.getExtraData(ctx.extraInternalIndex))
+    return ctx.extraInternal
 
   proc destroyContext*(ctx: SSLContext) =
     ## Free memory referenced by SSLContext.
@@ -379,7 +368,7 @@ when defineSsl:
 
   proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar;
     max_psk_len: cuint): cuint {.cdecl.} =
-    let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0)
+    let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX)
     let hintString = if hint == nil: nil else: $hint
     let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString)
     if psk.len.cuint > max_psk_len:
@@ -398,8 +387,6 @@ when defineSsl:
     ##
     ## Only used in PSK ciphersuites.
     ctx.getExtraInternal().clientGetPskFunc = fun
-    assert ctx.extraInternalIndex == 0,
-          "The pskClientCallback assumes the extraInternalIndex is 0"
     ctx.context.SSL_CTX_set_psk_client_callback(
         if fun == nil: nil else: pskClientCallback)
 
@@ -407,7 +394,7 @@ when defineSsl:
     return ctx.getExtraInternal().serverGetPskFunc
 
   proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.} =
-    let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0)
+    let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX)
     let pskString = (ctx.serverGetPskFunc)($identity)
     if psk.len.cint > max_psk_len:
       return 0