diff options
-rw-r--r-- | compiler/cgen.nim | 1 | ||||
-rw-r--r-- | compiler/commands.nim | 2 | ||||
-rw-r--r-- | compiler/options.nim | 11 | ||||
-rw-r--r-- | lib/pure/dynlib.nim | 32 | ||||
-rw-r--r-- | lib/pure/net.nim | 33 | ||||
-rw-r--r-- | lib/wrappers/openssl.nim | 81 |
6 files changed, 109 insertions, 51 deletions
diff --git a/compiler/cgen.nim b/compiler/cgen.nim index 6e18c8389..3217b86e4 100644 --- a/compiler/cgen.nim +++ b/compiler/cgen.nim @@ -17,6 +17,7 @@ import lowerings, semparallel from modulegraphs import ModuleGraph +from dynlib import libCandidates import strutils except `%` # collides with ropes.`%` diff --git a/compiler/commands.nim b/compiler/commands.nim index 590c4871d..f85e53511 100644 --- a/compiler/commands.nim +++ b/compiler/commands.nim @@ -226,6 +226,8 @@ proc testCompileOptionArg*(switch, arg: string, info: TLineInfo): bool = of "staticlib": result = contains(gGlobalOptions, optGenStaticLib) and not contains(gGlobalOptions, optGenGuiApp) else: localError(info, errGuiConsoleOrLibExpectedButXFound, arg) + of "dynliboverride": + result = isDynlibOverride(arg) else: invalidCmdLineOption(passCmd1, switch, info) proc testCompileOption*(switch: string, info: TLineInfo): bool = diff --git a/compiler/options.nim b/compiler/options.nim index b04f6a963..04ed2412e 100644 --- a/compiler/options.nim +++ b/compiler/options.nim @@ -372,17 +372,6 @@ proc findModule*(modulename, currentModule: string): string = result = findFile(m) patchModule() -proc libCandidates*(s: string, dest: var seq[string]) = - var le = strutils.find(s, '(') - var ri = strutils.find(s, ')', le+1) - if le >= 0 and ri > le: - var prefix = substr(s, 0, le - 1) - var suffix = substr(s, ri + 1) - for middle in split(substr(s, le + 1, ri - 1), '|'): - libCandidates(prefix & middle & suffix, dest) - else: - add(dest, s) - proc canonDynlibName(s: string): string = let start = if s.startsWith("lib"): 3 else: 0 let ende = strutils.find(s, {'(', ')', '.'}) diff --git a/lib/pure/dynlib.nim b/lib/pure/dynlib.nim index 906a9d23e..fda41dadb 100644 --- a/lib/pure/dynlib.nim +++ b/lib/pure/dynlib.nim @@ -11,20 +11,22 @@ ## libraries. On POSIX this uses the ``dlsym`` mechanism, on ## Windows ``LoadLibrary``. +import strutils + type LibHandle* = pointer ## a handle to a dynamically loaded library {.deprecated: [TLibHandle: LibHandle].} -proc loadLib*(path: string, global_symbols=false): LibHandle +proc loadLib*(path: string, global_symbols=false): LibHandle {.gcsafe.} ## loads a library from `path`. Returns nil if the library could not ## be loaded. -proc loadLib*(): LibHandle +proc loadLib*(): LibHandle {.gcsafe.} ## gets the handle from the current executable. Returns nil if the ## library could not be loaded. -proc unloadLib*(lib: LibHandle) +proc unloadLib*(lib: LibHandle) {.gcsafe.} ## unloads the library `lib` proc raiseInvalidLibrary*(name: cstring) {.noinline, noreturn.} = @@ -34,7 +36,7 @@ proc raiseInvalidLibrary*(name: cstring) {.noinline, noreturn.} = e.msg = "could not find symbol: " & $name raise e -proc symAddr*(lib: LibHandle, name: cstring): pointer +proc symAddr*(lib: LibHandle, name: cstring): pointer {.gcsafe.} ## retrieves the address of a procedure/variable from `lib`. Returns nil ## if the symbol could not be found. @@ -44,6 +46,28 @@ proc checkedSymAddr*(lib: LibHandle, name: cstring): pointer = result = symAddr(lib, name) if result == nil: raiseInvalidLibrary(name) +proc libCandidates*(s: string, dest: var seq[string]) = + ## given a library name pattern `s` write possible library names to `dest`. + var le = strutils.find(s, '(') + var ri = strutils.find(s, ')', le+1) + if le >= 0 and ri > le: + var prefix = substr(s, 0, le - 1) + var suffix = substr(s, ri + 1) + for middle in split(substr(s, le + 1, ri - 1), '|'): + libCandidates(prefix & middle & suffix, dest) + else: + add(dest, s) + +proc loadLibPattern*(pattern: string, global_symbols=false): LibHandle = + ## loads a library with name matching `pattern`, similar to what `dlimport` + ## pragma does. Returns nil if the library could not be loaded. + ## Warning: this proc uses the GC and so cannot be used to load the GC. + var candidates = newSeq[string]() + libCandidates(pattern, candidates) + for c in candidates: + result = loadLib(c, global_symbols) + if not result.isNil: break + when defined(posix): # # ========================================================================= diff --git a/lib/pure/net.nim b/lib/pure/net.nim index 863a8a6f4..5e10f2291 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -90,8 +90,8 @@ when defineSsl: SslContext* = ref object context*: SslCtx - extraInternalIndex: int referencedData: HashSet[int] + extraInternal: SslContextExtraInternal SslAcceptResult* = enum AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess @@ -103,6 +103,10 @@ when defineSsl: SslServerGetPskFunc* = proc(identity: string): string + SslContextExtraInternal = ref object of RootRef + serverGetPskFunc: SslServerGetPskFunc + clientGetPskFunc: SslClientGetPskFunc + {.deprecated: [ESSL: SSLError, TSSLCVerifyMode: SSLCVerifyMode, TSSLProtVersion: SSLProtVersion, PSSLContext: SSLContext, TSSLAcceptResult: SSLAcceptResult].} @@ -240,11 +244,6 @@ when defineSsl: ErrLoadBioStrings() OpenSSL_add_all_algorithms() - type - SslContextExtraInternal = ref object of RootRef - serverGetPskFunc: SslServerGetPskFunc - clientGetPskFunc: SslClientGetPskFunc - proc raiseSSLError*(s = "") = ## Raises a new SSL error. if s != "": @@ -257,12 +256,6 @@ when defineSsl: var errStr = ErrErrorString(err, nil) raise newException(SSLError, $errStr) - proc getExtraDataIndex*(ctx: SSLContext): int = - ## Retrieves unique index for storing extra data in SSLContext. - result = SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil).int - if result < 0: - raiseSSLError() - proc getExtraData*(ctx: SSLContext, index: int): RootRef = ## Retrieves arbitrary data stored inside SSLContext. if index notin ctx.referencedData: @@ -347,15 +340,11 @@ when defineSsl: discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY) newCTX.loadCertificates(certFile, keyFile) - result = SSLContext(context: newCTX, extraInternalIndex: 0, - referencedData: initSet[int]()) - result.extraInternalIndex = getExtraDataIndex(result) - - let extraInternal = new(SslContextExtraInternal) - result.setExtraData(result.extraInternalIndex, extraInternal) + result = SSLContext(context: newCTX, referencedData: initSet[int](), + extraInternal: new(SslContextExtraInternal)) proc getExtraInternal(ctx: SSLContext): SslContextExtraInternal = - return SslContextExtraInternal(ctx.getExtraData(ctx.extraInternalIndex)) + return ctx.extraInternal proc destroyContext*(ctx: SSLContext) = ## Free memory referenced by SSLContext. @@ -379,7 +368,7 @@ when defineSsl: proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar; max_psk_len: cuint): cuint {.cdecl.} = - let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0) + let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX) let hintString = if hint == nil: nil else: $hint let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString) if psk.len.cuint > max_psk_len: @@ -398,8 +387,6 @@ when defineSsl: ## ## Only used in PSK ciphersuites. ctx.getExtraInternal().clientGetPskFunc = fun - assert ctx.extraInternalIndex == 0, - "The pskClientCallback assumes the extraInternalIndex is 0" ctx.context.SSL_CTX_set_psk_client_callback( if fun == nil: nil else: pskClientCallback) @@ -407,7 +394,7 @@ when defineSsl: return ctx.getExtraInternal().serverGetPskFunc proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.} = - let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0) + let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX) let pskString = (ctx.serverGetPskFunc)($identity) if psk.len.cint > max_psk_len: return 0 diff --git a/lib/wrappers/openssl.nim b/lib/wrappers/openssl.nim index 241ad17ae..1bd02eaf0 100644 --- a/lib/wrappers/openssl.nim +++ b/lib/wrappers/openssl.nim @@ -37,6 +37,8 @@ else: DLLUtilName = "libcrypto.so" & versions from posix import SocketHandle +import dynlib + type SslStruct {.final, pure.} = object SslPtr* = ptr SslStruct @@ -185,16 +187,74 @@ const BIO_C_DO_STATE_MACHINE = 101 BIO_C_GET_SSL = 110 -proc SSL_library_init*(): cInt{.cdecl, dynlib: DLLSSLName, importc, discardable.} -proc SSL_load_error_strings*(){.cdecl, dynlib: DLLSSLName, importc.} -proc ERR_load_BIO_strings*(){.cdecl, dynlib: DLLUtilName, importc.} - -proc SSLv23_client_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.} -proc SSLv23_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.} -proc SSLv2_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.} -proc SSLv3_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.} proc TLSv1_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.} +when compileOption("dynlibOverride", "ssl"): + proc SSL_library_init*(): cint {.cdecl, dynlib: DLLSSLName, importc, discardable.} + proc SSL_load_error_strings*() {.cdecl, dynlib: DLLSSLName, importc.} + proc SSLv23_client_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.} + + proc SSLv23_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.} + proc SSLv2_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.} + proc SSLv3_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.} + + template OpenSSL_add_all_algorithms*() = discard +else: + # Here we're trying to stay compatible with openssl 1.0.* and 1.1.*. Some + # symbols are loaded dynamically and we don't use them if not found. + proc thisModule(): LibHandle {.inline.} = + var thisMod {.global.}: LibHandle + if thisMod.isNil: thisMod = loadLib() + result = thisMod + + proc sslModule(): LibHandle {.inline.} = + var sslMod {.global.}: LibHandle + if sslMod.isNil: sslMod = loadLibPattern(DLLSSLName) + result = sslMod + + proc sslSym(name: string): pointer = + var dl = thisModule() + if not dl.isNil: + result = symAddr(dl, name) + if result.isNil: + dl = sslModule() + if not dl.isNil: + result = symAddr(dl, name) + + proc SSL_library_init*(): cint {.discardable.} = + let theProc = cast[proc(): cint {.cdecl.}](sslSym("SSL_library_init")) + if not theProc.isNil: result = theProc() + + proc SSL_load_error_strings*() = + let theProc = cast[proc() {.cdecl.}](sslSym("SSL_load_error_strings")) + if not theProc.isNil: theProc() + + proc SSLv23_client_method*(): PSSL_METHOD = + let theProc = cast[proc(): PSSL_METHOD {.cdecl, gcsafe.}](sslSym("SSLv23_client_method")) + if not theProc.isNil: result = theProc() + else: result = TLSv1_method() + + proc SSLv23_method*(): PSSL_METHOD = + let theProc = cast[proc(): PSSL_METHOD {.cdecl, gcsafe.}](sslSym("SSLv23_method")) + if not theProc.isNil: result = theProc() + else: result = TLSv1_method() + + proc SSLv2_method*(): PSSL_METHOD = + let theProc = cast[proc(): PSSL_METHOD {.cdecl, gcsafe.}](sslSym("SSLv2_method")) + if not theProc.isNil: result = theProc() + else: result = TLSv1_method() + + proc SSLv3_method*(): PSSL_METHOD = + let theProc = cast[proc(): PSSL_METHOD {.cdecl, gcsafe.}](sslSym("SSLv3_method")) + if not theProc.isNil: result = theProc() + else: result = TLSv1_method() + + proc OpenSSL_add_all_algorithms*() = + let theProc = cast[proc() {.cdecl.}](sslSym("OPENSSL_add_all_algorithms_conf")) + if not theProc.isNil: theProc() + +proc ERR_load_BIO_strings*(){.cdecl, dynlib: DLLUtilName, importc.} + proc SSL_new*(context: SslCtx): SslPtr{.cdecl, dynlib: DLLSSLName, importc.} proc SSL_free*(ssl: SslPtr){.cdecl, dynlib: DLLSSLName, importc.} proc SSL_get_SSL_CTX*(ssl: SslPtr): SslCtx {.cdecl, dynlib: DLLSSLName, importc.} @@ -261,11 +321,6 @@ proc ERR_error_string*(e: cInt, buf: cstring): cstring{.cdecl, proc ERR_get_error*(): cInt{.cdecl, dynlib: DLLUtilName, importc.} proc ERR_peek_last_error*(): cInt{.cdecl, dynlib: DLLUtilName, importc.} -when defined(android): - template OpenSSL_add_all_algorithms*() = discard -else: - proc OpenSSL_add_all_algorithms*(){.cdecl, dynlib: DLLUtilName, importc: "OPENSSL_add_all_algorithms_conf".} - proc OPENSSL_config*(configName: cstring){.cdecl, dynlib: DLLSSLName, importc.} when not useWinVersion and not defined(macosx) and not defined(android): |