diff options
author | Kartik K. Agaram <vc@akkartik.com> | 2021-11-21 15:55:52 -0800 |
---|---|---|
committer | Kartik K. Agaram <vc@akkartik.com> | 2021-11-21 15:55:52 -0800 |
commit | 5a484efe8c72a929382c96555a31129f8d2a55c8 (patch) | |
tree | 60f6b76e3c06dbc1bfb9fe9e978475256e8a8f6d /src/luasec | |
parent | 3b44b9827d5e9c6554c5600c45d832d4e6eb50f8 (diff) | |
download | teliva-5a484efe8c72a929382c96555a31129f8d2a55c8.tar.gz |
https now working!
Still extremely ugly: - I've inlined all the namespaces under ssl, so you need to know that context and config are related to ssl. - luasec comes with its own copy of luasocket. I haven't deduped that yet.
Diffstat (limited to 'src/luasec')
29 files changed, 5611 insertions, 0 deletions
diff --git a/src/luasec/Makefile b/src/luasec/Makefile new file mode 100644 index 0000000..d1696d2 --- /dev/null +++ b/src/luasec/Makefile @@ -0,0 +1,67 @@ +CMOD=ssl.so +LMOD=ssl.lua + +OBJS= \ + options.o \ + x509.o \ + context.o \ + ssl.o \ + config.o \ + ec.o + +WARN=-Wall -pedantic + +BSD_CFLAGS=-O2 -fPIC $(WARN) $(INCDIR) $(DEFS) +BSD_LDFLAGS=-O -fPIC -shared $(LIBDIR) + +LNX_CFLAGS=-O2 $(WARN) $(INCDIR) $(DEFS) +LNX_LDFLAGS=-O + +MAC_ENV=env MACOSX_DEPLOYMENT_TARGET='$(MACVER)' +MAC_CFLAGS=-O2 -fno-common $(WARN) $(INCDIR) $(DEFS) +MAC_LDFLAGS=-bundle -undefined dynamic_lookup $(LIBDIR) + +INSTALL = install +CC ?= cc +CCLD ?= $(MYENV) $(CC) +CFLAGS += $(MYCFLAGS) +LDFLAGS += $(MYLDFLAGS) +AR= ar rc +RANLIB= ranlib + +.PHONY: all clean install none linux bsd macosx luasocket + +all: luasocket + +install: $(CMOD) $(LMOD) + $(INSTALL) -d $(DESTDIR)$(LUAPATH)/ssl $(DESTDIR)$(LUACPATH) + $(INSTALL) $(CMOD) $(DESTDIR)$(LUACPATH) + $(INSTALL) -m644 $(LMOD) $(DESTDIR)$(LUAPATH) + $(INSTALL) -m644 https.lua $(DESTDIR)$(LUAPATH)/ssl + +linux: $(OBJS) + $(AR) ssl.a $(OBJS) + $(RANLIB) ssl.a + +bsd: + @$(MAKE) $(CMOD) MYCFLAGS="$(BSD_CFLAGS)" MYLDFLAGS="$(BSD_LDFLAGS)" EXTRA="$(EXTRA)" + +macosx: + @$(MAKE) $(CMOD) MYCFLAGS="$(MAC_CFLAGS)" MYLDFLAGS="$(MAC_LDFLAGS)" MYENV="$(MAC_ENV)" EXTRA="$(EXTRA)" + +luasocket: + @cd luasocket && $(MAKE) + +$(CMOD): $(EXTRA) $(OBJS) + $(CCLD) $(LDFLAGS) -o $@ $(OBJS) luasocket/libluasocket.a -lssl -lcrypto + +clean: + cd luasocket && $(MAKE) clean + rm -f $(OBJS) $(CMOD) + +options.o: options.h options.c +ec.o: ec.c ec.h +x509.o: x509.c x509.h compat.h +context.o: context.c context.h ec.h compat.h options.h +ssl.o: ssl.c ssl.h context.h x509.h compat.h +config.o: config.c ec.h options.h compat.h diff --git a/src/luasec/compat.h b/src/luasec/compat.h new file mode 100644 index 0000000..1c88de9 --- /dev/null +++ b/src/luasec/compat.h @@ -0,0 +1,57 @@ +/*-------------------------------------------------------------------------- + * LuaSec 1.0.2 + * + * Copyright (C) 2006-2021 Bruno Silvestre + * + *--------------------------------------------------------------------------*/ + +#ifndef LSEC_COMPAT_H +#define LSEC_COMPAT_H + +#include <openssl/ssl.h> + +//------------------------------------------------------------------------------ + +#if defined(_WIN32) +#define LSEC_API __declspec(dllexport) +#else +#define LSEC_API extern +#endif + +//------------------------------------------------------------------------------ + +#if (LUA_VERSION_NUM == 501) + +#define luaL_testudata(L, ud, tname) lsec_testudata(L, ud, tname) +#define setfuncs(L, R) luaL_register(L, NULL, R) +#define lua_rawlen(L, i) lua_objlen(L, i) + +#ifndef luaL_newlib +#define luaL_newlib(L, R) do { lua_newtable(L); luaL_register(L, NULL, R); } while(0) +#endif + +#else +#define setfuncs(L, R) luaL_setfuncs(L, R, 0) +#endif + +//------------------------------------------------------------------------------ + +#if (!defined(LIBRESSL_VERSION_NUMBER) && (OPENSSL_VERSION_NUMBER >= 0x1010000fL)) +#define LSEC_ENABLE_DANE +#endif + +//------------------------------------------------------------------------------ + +#if !((defined(LIBRESSL_VERSION_NUMBER) && (LIBRESSL_VERSION_NUMBER < 0x2070000fL)) || (OPENSSL_VERSION_NUMBER < 0x1010000fL)) +#define LSEC_API_OPENSSL_1_1_0 +#endif + +//------------------------------------------------------------------------------ + +#if !defined(LIBRESSL_VERSION_NUMBER) && ((OPENSSL_VERSION_NUMBER & 0xFFFFF000L) == 0x10101000L) +#define LSEC_OPENSSL_1_1_1 +#endif + +//------------------------------------------------------------------------------ + +#endif diff --git a/src/luasec/config.c b/src/luasec/config.c new file mode 100644 index 0000000..eef7120 --- /dev/null +++ b/src/luasec/config.c @@ -0,0 +1,97 @@ +/*-------------------------------------------------------------------------- + * LuaSec 1.0.2 + * + * Copyright (C) 2006-2021 Bruno Silvestre. + * + *--------------------------------------------------------------------------*/ + +#include "compat.h" +#include "options.h" +#include "ec.h" + +/** + * Registre the module. + */ +LSEC_API int luaopen_ssl_config(lua_State *L) +{ + lsec_ssl_option_t *opt; + + lua_newtable(L); + lua_pushvalue(L, -1); + lua_setglobal(L, "config"); + + // Options + lua_pushstring(L, "options"); + lua_newtable(L); + for (opt = lsec_get_ssl_options(); opt->name; opt++) { + lua_pushstring(L, opt->name); + lua_pushboolean(L, 1); + lua_rawset(L, -3); + } + lua_rawset(L, -3); + + // Protocols + lua_pushstring(L, "protocols"); + lua_newtable(L); + + lua_pushstring(L, "tlsv1"); + lua_pushboolean(L, 1); + lua_rawset(L, -3); + lua_pushstring(L, "tlsv1_1"); + lua_pushboolean(L, 1); + lua_rawset(L, -3); + lua_pushstring(L, "tlsv1_2"); + lua_pushboolean(L, 1); + lua_rawset(L, -3); +#ifdef TLS1_3_VERSION + lua_pushstring(L, "tlsv1_3"); + lua_pushboolean(L, 1); + lua_rawset(L, -3); +#endif + + lua_rawset(L, -3); + + // Algorithms + lua_pushstring(L, "algorithms"); + lua_newtable(L); + +#ifndef OPENSSL_NO_EC + lua_pushstring(L, "ec"); + lua_pushboolean(L, 1); + lua_rawset(L, -3); +#endif + lua_rawset(L, -3); + + // Curves + lua_pushstring(L, "curves"); + lsec_get_curves(L); + lua_rawset(L, -3); + + // Capabilities + lua_pushstring(L, "capabilities"); + lua_newtable(L); + + // ALPN + lua_pushstring(L, "alpn"); + lua_pushboolean(L, 1); + lua_rawset(L, -3); + +#ifdef LSEC_ENABLE_DANE + // DANE + lua_pushstring(L, "dane"); + lua_pushboolean(L, 1); + lua_rawset(L, -3); +#endif + +#ifndef OPENSSL_NO_EC + lua_pushstring(L, "curves_list"); + lua_pushboolean(L, 1); + lua_rawset(L, -3); + lua_pushstring(L, "ecdh_auto"); + lua_pushboolean(L, 1); + lua_rawset(L, -3); +#endif + lua_rawset(L, -3); + + return 1; +} diff --git a/src/luasec/context.c b/src/luasec/context.c new file mode 100644 index 0000000..5496fbf --- /dev/null +++ b/src/luasec/context.c @@ -0,0 +1,934 @@ +/*-------------------------------------------------------------------------- + * LuaSec 1.0.2 + * + * Copyright (C) 2014-2021 Kim Alvefur, Paul Aurich, Tobias Markmann, + * Matthew Wild. + * Copyright (C) 2006-2021 Bruno Silvestre. + * + *--------------------------------------------------------------------------*/ + +#include <string.h> + +#if defined(WIN32) +#include <windows.h> +#endif + +#include <openssl/ssl.h> +#include <openssl/err.h> +#include <openssl/x509.h> +#include <openssl/x509v3.h> +#include <openssl/dh.h> + +#include "../lua.h" +#include "../lauxlib.h" + +#include "compat.h" +#include "context.h" +#include "options.h" + +#ifndef OPENSSL_NO_EC +#include <openssl/ec.h> +#include "ec.h" +#endif + +/*--------------------------- Auxiliary Functions ----------------------------*/ + +/** + * Return the context. + */ +static p_context checkctx(lua_State *L, int idx) +{ + return (p_context)luaL_checkudata(L, idx, "SSL:Context"); +} + +static p_context testctx(lua_State *L, int idx) +{ + return (p_context)luaL_testudata(L, idx, "SSL:Context"); +} + +/** + * Prepare the SSL options flag. + */ +static int set_option_flag(const char *opt, unsigned long *flag) +{ + lsec_ssl_option_t *p; + for (p = lsec_get_ssl_options(); p->name; p++) { + if (!strcmp(opt, p->name)) { + *flag |= p->code; + return 1; + } + } + return 0; +} + +#ifndef LSEC_API_OPENSSL_1_1_0 +/** + * Find the protocol. + */ +static const SSL_METHOD* str2method(const char *method, int *vmin, int *vmax) +{ + (void)vmin; + (void)vmax; + if (!strcmp(method, "any")) return SSLv23_method(); + if (!strcmp(method, "sslv23")) return SSLv23_method(); // deprecated + if (!strcmp(method, "tlsv1")) return TLSv1_method(); + if (!strcmp(method, "tlsv1_1")) return TLSv1_1_method(); + if (!strcmp(method, "tlsv1_2")) return TLSv1_2_method(); + return NULL; +} + +#else + +/** + * Find the protocol. + */ +static const SSL_METHOD* str2method(const char *method, int *vmin, int *vmax) +{ + if (!strcmp(method, "any") || !strcmp(method, "sslv23")) { // 'sslv23' is deprecated + *vmin = 0; + *vmax = 0; + return TLS_method(); + } + else if (!strcmp(method, "tlsv1")) { + *vmin = TLS1_VERSION; + *vmax = TLS1_VERSION; + return TLS_method(); + } + else if (!strcmp(method, "tlsv1_1")) { + *vmin = TLS1_1_VERSION; + *vmax = TLS1_1_VERSION; + return TLS_method(); + } + else if (!strcmp(method, "tlsv1_2")) { + *vmin = TLS1_2_VERSION; + *vmax = TLS1_2_VERSION; + return TLS_method(); + } +#if defined(TLS1_3_VERSION) + else if (!strcmp(method, "tlsv1_3")) { + *vmin = TLS1_3_VERSION; + *vmax = TLS1_3_VERSION; + return TLS_method(); + } +#endif + return NULL; +} +#endif + +/** + * Prepare the SSL handshake verify flag. + */ +static int set_verify_flag(const char *str, int *flag) +{ + if (!strcmp(str, "none")) { + *flag |= SSL_VERIFY_NONE; + return 1; + } + if (!strcmp(str, "peer")) { + *flag |= SSL_VERIFY_PEER; + return 1; + } + if (!strcmp(str, "client_once")) { + *flag |= SSL_VERIFY_CLIENT_ONCE; + return 1; + } + if (!strcmp(str, "fail_if_no_peer_cert")) { + *flag |= SSL_VERIFY_FAIL_IF_NO_PEER_CERT; + return 1; + } + return 0; +} + +/** + * Password callback for reading the private key. + */ +static int passwd_cb(char *buf, int size, int flag, void *udata) +{ + lua_State *L = (lua_State*)udata; + switch (lua_type(L, 3)) { + case LUA_TFUNCTION: + lua_pushvalue(L, 3); + lua_call(L, 0, 1); + if (lua_type(L, -1) != LUA_TSTRING) { + lua_pop(L, 1); /* Remove the result from the stack */ + return 0; + } + /* fallback */ + case LUA_TSTRING: + strncpy(buf, lua_tostring(L, -1), size); + lua_pop(L, 1); /* Remove the result from the stack */ + buf[size-1] = '\0'; + return (int)strlen(buf); + } + return 0; +} + +/** + * Add an error related to a depth certificate of the chain. + */ +static void add_cert_error(lua_State *L, SSL *ssl, int err, int depth) +{ + luaL_getmetatable(L, "SSL:Verify:Registry"); + lua_pushlightuserdata(L, (void*)ssl); + lua_gettable(L, -2); + if (lua_isnil(L, -1)) { + lua_pop(L, 1); + /* Create an error table for this connection */ + lua_newtable(L); + lua_pushlightuserdata(L, (void*)ssl); + lua_pushvalue(L, -2); /* keep the table on stack */ + lua_settable(L, -4); + } + lua_rawgeti(L, -1, depth+1); + /* If the table doesn't exist, create it */ + if (lua_isnil(L, -1)) { + lua_pop(L, 1); /* remove 'nil' from stack */ + lua_newtable(L); + lua_pushvalue(L, -1); /* keep the table on stack */ + lua_rawseti(L, -3, depth+1); + } + lua_pushstring(L, X509_verify_cert_error_string(err)); + lua_rawseti(L, -2, lua_rawlen(L, -2) + 1); + /* Clear the stack */ + lua_pop(L, 3); +} + +/** + * Call Lua user function to get the DH key. + */ +static DH *dhparam_cb(SSL *ssl, int is_export, int keylength) +{ + BIO *bio; + lua_State *L; + SSL_CTX *ctx = SSL_get_SSL_CTX(ssl); + p_context pctx = (p_context)SSL_CTX_get_app_data(ctx); + + L = pctx->L; + + /* Get the callback */ + luaL_getmetatable(L, "SSL:DH:Registry"); + lua_pushlightuserdata(L, (void*)ctx); + lua_gettable(L, -2); + + /* Invoke the callback */ + lua_pushboolean(L, is_export); + lua_pushnumber(L, keylength); + lua_call(L, 2, 1); + + /* Load parameters from returned value */ + if (lua_type(L, -1) != LUA_TSTRING) { + lua_pop(L, 2); /* Remove values from stack */ + return NULL; + } + + bio = BIO_new_mem_buf((void*)lua_tostring(L, -1), lua_rawlen(L, -1)); + if (bio) { + pctx->dh_param = PEM_read_bio_DHparams(bio, NULL, NULL, NULL); + BIO_free(bio); + } + + lua_pop(L, 2); /* Remove values from stack */ + return pctx->dh_param; +} + +/** + * Set the "ignore purpose" before to start verifing the certificate chain. + */ +static int cert_verify_cb(X509_STORE_CTX *x509_ctx, void *ptr) +{ + int verify; + lua_State *L; + SSL_CTX *ctx = (SSL_CTX*)ptr; + p_context pctx = (p_context)SSL_CTX_get_app_data(ctx); + + L = pctx->L; + + /* Get verify flags */ + luaL_getmetatable(L, "SSL:Verify:Registry"); + lua_pushlightuserdata(L, (void*)ctx); + lua_gettable(L, -2); + verify = (int)lua_tonumber(L, -1); + + lua_pop(L, 2); /* Remove values from stack */ + + if (verify & LSEC_VERIFY_IGNORE_PURPOSE) { + /* Set parameters to ignore the server purpose */ + X509_VERIFY_PARAM *param = X509_STORE_CTX_get0_param(x509_ctx); + if (param) { + X509_VERIFY_PARAM_set_purpose(param, X509_PURPOSE_SSL_SERVER); + X509_VERIFY_PARAM_set_trust(param, X509_TRUST_SSL_SERVER); + } + } + /* Call OpenSSL standard verification function */ + return X509_verify_cert(x509_ctx); +} + +/** + * This callback implements the "continue on error" flag and log the errors. + */ +static int verify_cb(int preverify_ok, X509_STORE_CTX *x509_ctx) +{ + int err; + int verify; + SSL *ssl; + SSL_CTX *ctx; + p_context pctx; + lua_State *L; + + /* Short-circuit optimization */ + if (preverify_ok) + return 1; + + ssl = X509_STORE_CTX_get_ex_data(x509_ctx, + SSL_get_ex_data_X509_STORE_CTX_idx()); + ctx = SSL_get_SSL_CTX(ssl); + pctx = (p_context)SSL_CTX_get_app_data(ctx); + L = pctx->L; + + /* Get verify flags */ + luaL_getmetatable(L, "SSL:Verify:Registry"); + lua_pushlightuserdata(L, (void*)ctx); + lua_gettable(L, -2); + verify = (int)lua_tonumber(L, -1); + + lua_pop(L, 2); /* Remove values from stack */ + + err = X509_STORE_CTX_get_error(x509_ctx); + if (err != X509_V_OK) + add_cert_error(L, ssl, err, X509_STORE_CTX_get_error_depth(x509_ctx)); + + return (verify & LSEC_VERIFY_CONTINUE ? 1 : preverify_ok); +} + +/*------------------------------ Lua Functions -------------------------------*/ + +/** + * Create a SSL context. + */ +static int create(lua_State *L) +{ + p_context ctx; + const char *str_method; + const SSL_METHOD *method; + int vmin, vmax; + + str_method = luaL_checkstring(L, 1); + method = str2method(str_method, &vmin, &vmax); + if (!method) { + lua_pushnil(L); + lua_pushfstring(L, "invalid protocol (%s)", str_method); + return 2; + } + ctx = (p_context) lua_newuserdata(L, sizeof(t_context)); + if (!ctx) { + lua_pushnil(L); + lua_pushstring(L, "error creating context"); + return 2; + } + memset(ctx, 0, sizeof(t_context)); + ctx->context = SSL_CTX_new(method); + if (!ctx->context) { + lua_pushnil(L); + lua_pushfstring(L, "error creating context (%s)", + ERR_reason_error_string(ERR_get_error())); + return 2; + } +#ifdef LSEC_API_OPENSSL_1_1_0 + SSL_CTX_set_min_proto_version(ctx->context, vmin); + SSL_CTX_set_max_proto_version(ctx->context, vmax); +#endif + ctx->mode = LSEC_MODE_INVALID; + ctx->L = L; + luaL_getmetatable(L, "SSL:Context"); + lua_setmetatable(L, -2); + + /* No session support */ + SSL_CTX_set_session_cache_mode(ctx->context, SSL_SESS_CACHE_OFF); + /* Link LuaSec context with the OpenSSL context */ + SSL_CTX_set_app_data(ctx->context, ctx); + + return 1; +} + +/** + * Load the trusting certificates. + */ +static int load_locations(lua_State *L) +{ + SSL_CTX *ctx = lsec_checkcontext(L, 1); + const char *cafile = luaL_optstring(L, 2, NULL); + const char *capath = luaL_optstring(L, 3, NULL); + if (SSL_CTX_load_verify_locations(ctx, cafile, capath) != 1) { + lua_pushboolean(L, 0); + lua_pushfstring(L, "error loading CA locations (%s)", + ERR_reason_error_string(ERR_get_error())); + return 2; + } + lua_pushboolean(L, 1); + return 1; +} + +/** + * Load the certificate file. + */ +static int load_cert(lua_State *L) +{ + SSL_CTX *ctx = lsec_checkcontext(L, 1); + const char *filename = luaL_checkstring(L, 2); + if (SSL_CTX_use_certificate_chain_file(ctx, filename) != 1) { + lua_pushboolean(L, 0); + lua_pushfstring(L, "error loading certificate (%s)", + ERR_reason_error_string(ERR_get_error())); + return 2; + } + lua_pushboolean(L, 1); + return 1; +} + +/** + * Load the key file -- only in PEM format. + */ +static int load_key(lua_State *L) +{ + int ret = 1; + SSL_CTX *ctx = lsec_checkcontext(L, 1); + const char *filename = luaL_checkstring(L, 2); + switch (lua_type(L, 3)) { + case LUA_TSTRING: + case LUA_TFUNCTION: + SSL_CTX_set_default_passwd_cb(ctx, passwd_cb); + SSL_CTX_set_default_passwd_cb_userdata(ctx, L); + /* fallback */ + case LUA_TNIL: + if (SSL_CTX_use_PrivateKey_file(ctx, filename, SSL_FILETYPE_PEM) == 1) + lua_pushboolean(L, 1); + else { + ret = 2; + lua_pushboolean(L, 0); + lua_pushfstring(L, "error loading private key (%s)", + ERR_reason_error_string(ERR_get_error())); + } + SSL_CTX_set_default_passwd_cb(ctx, NULL); + SSL_CTX_set_default_passwd_cb_userdata(ctx, NULL); + break; + default: + lua_pushstring(L, "invalid callback value"); + lua_error(L); + } + return ret; +} + +/** + * Check that the certificate public key matches the private key + */ + +static int check_key(lua_State *L) +{ + SSL_CTX *ctx = lsec_checkcontext(L, 1); + lua_pushboolean(L, SSL_CTX_check_private_key(ctx)); + return 1; +} + +/** + * Set the cipher list. + */ +static int set_cipher(lua_State *L) +{ + SSL_CTX *ctx = lsec_checkcontext(L, 1); + const char *list = luaL_checkstring(L, 2); + if (SSL_CTX_set_cipher_list(ctx, list) != 1) { + lua_pushboolean(L, 0); + lua_pushfstring(L, "error setting cipher list (%s)", ERR_reason_error_string(ERR_get_error())); + return 2; + } + lua_pushboolean(L, 1); + return 1; +} + +/** + * Set the cipher suites. + */ +static int set_ciphersuites(lua_State *L) +{ +#if defined(TLS1_3_VERSION) + SSL_CTX *ctx = lsec_checkcontext(L, 1); + const char *list = luaL_checkstring(L, 2); + if (SSL_CTX_set_ciphersuites(ctx, list) != 1) { + lua_pushboolean(L, 0); + lua_pushfstring(L, "error setting cipher list (%s)", ERR_reason_error_string(ERR_get_error())); + return 2; + } +#endif + lua_pushboolean(L, 1); + return 1; +} + +/** + * Set the depth for certificate checking. + */ +static int set_depth(lua_State *L) +{ + SSL_CTX *ctx = lsec_checkcontext(L, 1); + SSL_CTX_set_verify_depth(ctx, (int)luaL_checkinteger(L, 2)); + lua_pushboolean(L, 1); + return 1; +} + +/** + * Set the handshake verify options. + */ +static int set_verify(lua_State *L) +{ + int i; + const char *str; + int flag = 0; + SSL_CTX *ctx = lsec_checkcontext(L, 1); + int max = lua_gettop(L); + for (i = 2; i <= max; i++) { + str = luaL_checkstring(L, i); + if (!set_verify_flag(str, &flag)) { + lua_pushboolean(L, 0); + lua_pushfstring(L, "invalid verify option (%s)", str); + return 2; + } + } + if (flag) SSL_CTX_set_verify(ctx, flag, NULL); + lua_pushboolean(L, 1); + return 1; +} + +/** + * Set the protocol options. + */ +static int set_options(lua_State *L) +{ + int i; + const char *str; + unsigned long flag = 0L; + SSL_CTX *ctx = lsec_checkcontext(L, 1); + int max = lua_gettop(L); + /* any option? */ + if (max > 1) { + for (i = 2; i <= max; i++) { + str = luaL_checkstring(L, i); + if (!set_option_flag(str, &flag)) { + lua_pushboolean(L, 0); + lua_pushfstring(L, "invalid option (%s)", str); + return 2; + } + } + SSL_CTX_set_options(ctx, flag); + } + lua_pushboolean(L, 1); + return 1; +} + +/** + * Set the context mode. + */ +static int set_mode(lua_State *L) +{ + p_context ctx = checkctx(L, 1); + const char *str = luaL_checkstring(L, 2); + if (!strcmp("server", str)) { + ctx->mode = LSEC_MODE_SERVER; + lua_pushboolean(L, 1); + return 1; + } + if (!strcmp("client", str)) { + ctx->mode = LSEC_MODE_CLIENT; + lua_pushboolean(L, 1); + return 1; + } + lua_pushboolean(L, 0); + lua_pushfstring(L, "invalid mode (%s)", str); + return 1; +} + +/** + * Configure DH parameters. + */ +static int set_dhparam(lua_State *L) +{ + SSL_CTX *ctx = lsec_checkcontext(L, 1); + SSL_CTX_set_tmp_dh_callback(ctx, dhparam_cb); + + /* Save callback */ + luaL_getmetatable(L, "SSL:DH:Registry"); + lua_pushlightuserdata(L, (void*)ctx); + lua_pushvalue(L, 2); + lua_settable(L, -3); + + return 0; +} + +#if !defined(OPENSSL_NO_EC) +/** + * Set elliptic curve. + */ +static int set_curve(lua_State *L) +{ + long ret; + EC_KEY *key = NULL; + SSL_CTX *ctx = lsec_checkcontext(L, 1); + const char *str = luaL_checkstring(L, 2); + + SSL_CTX_set_options(ctx, SSL_OP_SINGLE_ECDH_USE); + + key = lsec_find_ec_key(L, str); + + if (!key) { + lua_pushboolean(L, 0); + lua_pushfstring(L, "elliptic curve '%s' not supported", str); + return 2; + } + + ret = SSL_CTX_set_tmp_ecdh(ctx, key); + /* SSL_CTX_set_tmp_ecdh takes its own reference */ + EC_KEY_free(key); + + if (!ret) { + lua_pushboolean(L, 0); + lua_pushfstring(L, "error setting elliptic curve (%s)", + ERR_reason_error_string(ERR_get_error())); + return 2; + } + + lua_pushboolean(L, 1); + return 1; +} + +/** + * Set elliptic curves list. + */ +static int set_curves_list(lua_State *L) +{ + SSL_CTX *ctx = lsec_checkcontext(L, 1); + const char *str = luaL_checkstring(L, 2); + + SSL_CTX_set_options(ctx, SSL_OP_SINGLE_ECDH_USE); + + if (SSL_CTX_set1_curves_list(ctx, str) != 1) { + lua_pushboolean(L, 0); + lua_pushfstring(L, "unknown elliptic curve in \"%s\"", str); + return 2; + } + +#if defined(LIBRESSL_VERSION_NUMBER) || !defined(LSEC_API_OPENSSL_1_1_0) + (void)SSL_CTX_set_ecdh_auto(ctx, 1); +#endif + + lua_pushboolean(L, 1); + return 1; +} +#endif + +/** + * Set the protocols a client should send for ALPN. + */ +static int set_alpn(lua_State *L) +{ + long ret; + size_t len; + p_context ctx = checkctx(L, 1); + const char *str = luaL_checklstring(L, 2, &len); + + ret = SSL_CTX_set_alpn_protos(ctx->context, (const unsigned char*)str, len); + if (ret) { + lua_pushboolean(L, 0); + lua_pushfstring(L, "error setting ALPN (%s)", ERR_reason_error_string(ERR_get_error())); + return 2; + } + lua_pushboolean(L, 1); + return 1; +} + +/** + * This standard callback calls the server's callback in Lua sapce. + * The server has to return a list in wire-format strings. + * This function uses a helper function to match server and client lists. + */ +static int alpn_cb(SSL *s, const unsigned char **out, unsigned char *outlen, + const unsigned char *in, unsigned int inlen, void *arg) +{ + int ret; + size_t server_len; + const char *server; + p_context ctx = (p_context)arg; + lua_State *L = ctx->L; + + luaL_getmetatable(L, "SSL:ALPN:Registry"); + lua_pushlightuserdata(L, (void*)ctx->context); + lua_gettable(L, -2); + + lua_pushlstring(L, (const char*)in, inlen); + + lua_call(L, 1, 1); + + if (!lua_isstring(L, -1)) { + lua_pop(L, 2); + return SSL_TLSEXT_ERR_NOACK; + } + + // Protocol list from server in wire-format string + server = luaL_checklstring(L, -1, &server_len); + ret = SSL_select_next_proto((unsigned char**)out, outlen, (const unsigned char*)server, + server_len, in, inlen); + if (ret != OPENSSL_NPN_NEGOTIATED) { + lua_pop(L, 2); + return SSL_TLSEXT_ERR_NOACK; + } + + // Copy the result because lua_pop() can collect the pointer + ctx->alpn = malloc(*outlen); + memcpy(ctx->alpn, (void*)*out, *outlen); + *out = (const unsigned char*)ctx->alpn; + + lua_pop(L, 2); + + return SSL_TLSEXT_ERR_OK; +} + +/** + * Set a callback a server can use to select the next protocol with ALPN. + */ +static int set_alpn_cb(lua_State *L) +{ + p_context ctx = checkctx(L, 1); + + luaL_getmetatable(L, "SSL:ALPN:Registry"); + lua_pushlightuserdata(L, (void*)ctx->context); + lua_pushvalue(L, 2); + lua_settable(L, -3); + + SSL_CTX_set_alpn_select_cb(ctx->context, alpn_cb, ctx); + + lua_pushboolean(L, 1); + return 1; +} + +#if defined(LSEC_ENABLE_DANE) +/* + * DANE + */ +static int set_dane(lua_State *L) +{ + int ret; + SSL_CTX *ctx = lsec_checkcontext(L, 1); + ret = SSL_CTX_dane_enable(ctx); + lua_pushboolean(L, (ret > 0)); + return 1; +} +#endif + +/** + * Package functions + */ +static luaL_Reg funcs[] = { + {"create", create}, + {"locations", load_locations}, + {"loadcert", load_cert}, + {"loadkey", load_key}, + {"checkkey", check_key}, + {"setalpn", set_alpn}, + {"setalpncb", set_alpn_cb}, + {"setcipher", set_cipher}, + {"setciphersuites", set_ciphersuites}, + {"setdepth", set_depth}, + {"setdhparam", set_dhparam}, + {"setverify", set_verify}, + {"setoptions", set_options}, + {"setmode", set_mode}, +#if !defined(OPENSSL_NO_EC) + {"setcurve", set_curve}, + {"setcurveslist", set_curves_list}, +#endif +#if defined(LSEC_ENABLE_DANE) + {"setdane", set_dane}, +#endif + {NULL, NULL} +}; + +/*-------------------------------- Metamethods -------------------------------*/ + +/** + * Collect SSL context -- GC metamethod. + */ +static int meth_destroy(lua_State *L) +{ + p_context ctx = checkctx(L, 1); + if (ctx->context) { + /* Clear registries */ + luaL_getmetatable(L, "SSL:DH:Registry"); + lua_pushlightuserdata(L, (void*)ctx->context); + lua_pushnil(L); + lua_settable(L, -3); + luaL_getmetatable(L, "SSL:Verify:Registry"); + lua_pushlightuserdata(L, (void*)ctx->context); + lua_pushnil(L); + lua_settable(L, -3); + luaL_getmetatable(L, "SSL:ALPN:Registry"); + lua_pushlightuserdata(L, (void*)ctx->context); + lua_pushnil(L); + lua_settable(L, -3); + + SSL_CTX_free(ctx->context); + ctx->context = NULL; + } + return 0; +} + +/** + * Object information -- tostring metamethod. + */ +static int meth_tostring(lua_State *L) +{ + p_context ctx = checkctx(L, 1); + lua_pushfstring(L, "SSL context: %p", ctx); + return 1; +} + +/** + * Set extra flags for handshake verification. + */ +static int meth_set_verify_ext(lua_State *L) +{ + int i; + const char *str; + int crl_flag = 0; + int lsec_flag = 0; + SSL_CTX *ctx = lsec_checkcontext(L, 1); + int max = lua_gettop(L); + for (i = 2; i <= max; i++) { + str = luaL_checkstring(L, i); + if (!strcmp(str, "lsec_continue")) { + lsec_flag |= LSEC_VERIFY_CONTINUE; + } else if (!strcmp(str, "lsec_ignore_purpose")) { + lsec_flag |= LSEC_VERIFY_IGNORE_PURPOSE; + } else if (!strcmp(str, "crl_check")) { + crl_flag |= X509_V_FLAG_CRL_CHECK; + } else if (!strcmp(str, "crl_check_chain")) { + crl_flag |= X509_V_FLAG_CRL_CHECK_ALL; + } else { + lua_pushboolean(L, 0); + lua_pushfstring(L, "invalid verify option (%s)", str); + return 2; + } + } + /* Set callback? */ + if (lsec_flag) { + SSL_CTX_set_verify(ctx, SSL_CTX_get_verify_mode(ctx), verify_cb); + SSL_CTX_set_cert_verify_callback(ctx, cert_verify_cb, (void*)ctx); + /* Save flag */ + luaL_getmetatable(L, "SSL:Verify:Registry"); + lua_pushlightuserdata(L, (void*)ctx); + lua_pushnumber(L, lsec_flag); + lua_settable(L, -3); + } else { + SSL_CTX_set_verify(ctx, SSL_CTX_get_verify_mode(ctx), NULL); + SSL_CTX_set_cert_verify_callback(ctx, NULL, NULL); + /* Remove flag */ + luaL_getmetatable(L, "SSL:Verify:Registry"); + lua_pushlightuserdata(L, (void*)ctx); + lua_pushnil(L); + lua_settable(L, -3); + } + + /* X509 flag */ + X509_STORE_set_flags(SSL_CTX_get_cert_store(ctx), crl_flag); + + /* Ok */ + lua_pushboolean(L, 1); + return 1; +} + +/** + * Context metamethods. + */ +static luaL_Reg meta[] = { + {"__close", meth_destroy}, + {"__gc", meth_destroy}, + {"__tostring", meth_tostring}, + {NULL, NULL} +}; + +/** + * Index metamethods. + */ +static luaL_Reg meta_index[] = { + {"setverifyext", meth_set_verify_ext}, + {NULL, NULL} +}; + + +/*----------------------------- Public Functions ---------------------------*/ + +/** + * Retrieve the SSL context from the Lua stack. + */ +SSL_CTX* lsec_checkcontext(lua_State *L, int idx) +{ + p_context ctx = checkctx(L, idx); + return ctx->context; +} + +SSL_CTX* lsec_testcontext(lua_State *L, int idx) +{ + p_context ctx = testctx(L, idx); + return (ctx) ? ctx->context : NULL; +} + +/** + * Retrieve the mode from the context in the Lua stack. + */ +int lsec_getmode(lua_State *L, int idx) +{ + p_context ctx = checkctx(L, idx); + return ctx->mode; +} + +/*-- Compat - Lua 5.1 --*/ +#if (LUA_VERSION_NUM == 501) + +void *lsec_testudata (lua_State *L, int ud, const char *tname) { + void *p = lua_touserdata(L, ud); + if (p != NULL) { /* value is a userdata? */ + if (lua_getmetatable(L, ud)) { /* does it have a metatable? */ + luaL_getmetatable(L, tname); /* get correct metatable */ + if (!lua_rawequal(L, -1, -2)) /* not the same? */ + p = NULL; /* value is a userdata with wrong metatable */ + lua_pop(L, 2); /* remove both metatables */ + return p; + } + } + return NULL; /* value is not a userdata with a metatable */ +} + +#endif + +/*------------------------------ Initialization ------------------------------*/ + +/** + * Registre the module. + */ +LSEC_API int luaopen_ssl_context(lua_State *L) +{ + luaL_newmetatable(L, "SSL:DH:Registry"); /* Keep all DH callbacks */ + luaL_newmetatable(L, "SSL:ALPN:Registry"); /* Keep all ALPN callbacks */ + luaL_newmetatable(L, "SSL:Verify:Registry"); /* Keep all verify flags */ + luaL_newmetatable(L, "SSL:Context"); + setfuncs(L, meta); + + /* Create __index metamethods for context */ + luaL_newlib(L, meta_index); + lua_setfield(L, -2, "__index"); + + lsec_load_curves(L); + + /* Return the module */ + luaL_newlib(L, funcs); + lua_pushvalue(L, -1); + lua_setglobal(L, "context"); + + return 1; +} diff --git a/src/luasec/context.h b/src/luasec/context.h new file mode 100644 index 0000000..21202cb --- /dev/null +++ b/src/luasec/context.h @@ -0,0 +1,47 @@ +#ifndef LSEC_CONTEXT_H +#define LSEC_CONTEXT_H + +/*-------------------------------------------------------------------------- + * LuaSec 1.0.2 + * + * Copyright (C) 2006-2021 Bruno Silvestre + * + *--------------------------------------------------------------------------*/ + +#include "../lua.h" +#include <openssl/ssl.h> + +#include "compat.h" + +#define LSEC_MODE_INVALID 0 +#define LSEC_MODE_SERVER 1 +#define LSEC_MODE_CLIENT 2 + +#define LSEC_VERIFY_CONTINUE 1 +#define LSEC_VERIFY_IGNORE_PURPOSE 2 + +typedef struct t_context_ { + SSL_CTX *context; + lua_State *L; + DH *dh_param; + void *alpn; + int mode; +} t_context; +typedef t_context* p_context; + +/* Retrieve the SSL context from the Lua stack */ +SSL_CTX *lsec_checkcontext(lua_State *L, int idx); +SSL_CTX *lsec_testcontext(lua_State *L, int idx); + +/* Retrieve the mode from the context in the Lua stack */ +int lsec_getmode(lua_State *L, int idx); + +/* Registre the module. */ +LSEC_API int luaopen_ssl_context(lua_State *L); + +/* Compat - Lua 5.1 */ +#if (LUA_VERSION_NUM == 501) +void *lsec_testudata (lua_State *L, int ud, const char *tname); +#endif + +#endif diff --git a/src/luasec/ec.c b/src/luasec/ec.c new file mode 100644 index 0000000..73b09d7 --- /dev/null +++ b/src/luasec/ec.c @@ -0,0 +1,109 @@ +#include <openssl/objects.h> + +#include "ec.h" + +#ifndef OPENSSL_NO_EC + +EC_KEY *lsec_find_ec_key(lua_State *L, const char *str) +{ + int nid; + lua_pushstring(L, "SSL:EC:CURVES"); + lua_rawget(L, LUA_REGISTRYINDEX); + lua_pushstring(L, str); + lua_rawget(L, -2); + + if (!lua_isnumber(L, -1)) + return NULL; + + nid = (int)lua_tonumber(L, -1); + return EC_KEY_new_by_curve_name(nid); +} + +void lsec_load_curves(lua_State *L) +{ + size_t i; + size_t size; + const char *name; + EC_builtin_curve *curves = NULL; + + lua_pushstring(L, "SSL:EC:CURVES"); + lua_newtable(L); + + size = EC_get_builtin_curves(NULL, 0); + if (size > 0) { + curves = (EC_builtin_curve*)malloc(sizeof(EC_builtin_curve) * size); + EC_get_builtin_curves(curves, size); + for (i = 0; i < size; i++) { + name = OBJ_nid2sn(curves[i].nid); + if (name != NULL) { + lua_pushstring(L, name); + lua_pushnumber(L, curves[i].nid); + lua_rawset(L, -3); + } + switch (curves[i].nid) { + case NID_X9_62_prime256v1: + lua_pushstring(L, "P-256"); + lua_pushnumber(L, curves[i].nid); + lua_rawset(L, -3); + break; + case NID_secp384r1: + lua_pushstring(L, "P-384"); + lua_pushnumber(L, curves[i].nid); + lua_rawset(L, -3); + break; + case NID_secp521r1: + lua_pushstring(L, "P-521"); + lua_pushnumber(L, curves[i].nid); + lua_rawset(L, -3); + break; + } + } + free(curves); + } + + /* These are special so are manually added here */ +#ifdef NID_X25519 + lua_pushstring(L, "X25519"); + lua_pushnumber(L, NID_X25519); + lua_rawset(L, -3); +#endif + +#ifdef NID_X448 + lua_pushstring(L, "X448"); + lua_pushnumber(L, NID_X448); + lua_rawset(L, -3); +#endif + + lua_rawset(L, LUA_REGISTRYINDEX); +} + +void lsec_get_curves(lua_State *L) +{ + lua_newtable(L); + + lua_pushstring(L, "SSL:EC:CURVES"); + lua_rawget(L, LUA_REGISTRYINDEX); + + lua_pushnil(L); + while (lua_next(L, -2) != 0) { + lua_pop(L, 1); + lua_pushvalue(L, -1); + lua_pushboolean(L, 1); + lua_rawset(L, -5); + } + lua_pop(L, 1); +} + +#else + +void lsec_load_curves(lua_State *L) +{ + // do nothing +} + +void lsec_get_curves(lua_State *L) +{ + lua_newtable(L); +} + +#endif diff --git a/src/luasec/ec.h b/src/luasec/ec.h new file mode 100644 index 0000000..b37fa7f --- /dev/null +++ b/src/luasec/ec.h @@ -0,0 +1,22 @@ +/*-------------------------------------------------------------------------- + * LuaSec 1.0.2 + * + * Copyright (C) 2006-2021 Bruno Silvestre + * + *--------------------------------------------------------------------------*/ + +#ifndef LSEC_EC_H +#define LSEC_EC_H + +#include "../lua.h" + +#ifndef OPENSSL_NO_EC +#include <openssl/ec.h> + +EC_KEY *lsec_find_ec_key(lua_State *L, const char *str); +#endif + +void lsec_get_curves(lua_State *L); +void lsec_load_curves(lua_State *L); + +#endif diff --git a/src/luasec/https.lua b/src/luasec/https.lua new file mode 100644 index 0000000..ac0e479 --- /dev/null +++ b/src/luasec/https.lua @@ -0,0 +1,140 @@ +---------------------------------------------------------------------------- +-- LuaSec 1.0.2 +-- Copyright (C) 2009-2021 PUC-Rio +-- +-- Author: Pablo Musa +-- Author: Tomas Guisasola +--------------------------------------------------------------------------- + +local try = socket.try + +-- +-- Module +-- +local _M = { + _VERSION = "1.0.2", + _COPYRIGHT = "LuaSec 1.0.2 - Copyright (C) 2009-2021 PUC-Rio", + PORT = 443, + TIMEOUT = 60 +} + +-- TLS configuration +local cfg = { + protocol = "any", + options = {"all", "no_sslv2", "no_sslv3", "no_tlsv1"}, + verify = "none", +} + +-------------------------------------------------------------------- +-- Auxiliar Functions +-------------------------------------------------------------------- + +-- Insert default HTTPS port. +local function default_https_port(u) + return url.build(url.parse(u, {port = _M.PORT})) +end + +-- Convert an URL to a table according to Luasocket needs. +local function urlstring_totable(url, body, result_table) + url = { + url = default_https_port(url), + method = body and "POST" or "GET", + sink = ltn12.sink.table(result_table) + } + if body then + url.source = ltn12.source.string(body) + url.headers = { + ["content-length"] = #body, + ["content-type"] = "application/x-www-form-urlencoded", + } + end + return url +end + +-- Forward calls to the real connection object. +local function reg(conn) + local mt = getmetatable(conn.sock).__index + for name, method in pairs(mt) do + if type(method) == "function" then + conn[name] = function (self, ...) + return method(self.sock, ...) + end + end + end +end + +-- Return a function which performs the SSL/TLS connection. +local function tcp(params) + params = params or {} + -- Default settings + for k, v in pairs(cfg) do + params[k] = params[k] or v + end + -- Force client mode + params.mode = "client" + -- 'create' function for LuaSocket + return function () + local conn = {} + conn.sock = try(socket.tcp()) + local st = getmetatable(conn.sock).__index.settimeout + function conn:settimeout(...) + return st(self.sock, _M.TIMEOUT) + end + -- Replace TCP's connection function + function conn:connect(host, port) + try(self.sock:connect(host, port)) + self.sock = try(ssl.wrap(self.sock, params)) + self.sock:sni(host) + self.sock:settimeout(_M.TIMEOUT) + try(self.sock:dohandshake()) + reg(self, getmetatable(self.sock)) + return 1 + end + return conn + end +end + +-------------------------------------------------------------------- +-- Main Function +-------------------------------------------------------------------- + +-- Make a HTTP request over secure connection. This function receives +-- the same parameters of LuaSocket's HTTP module (except 'proxy' and +-- 'redirect') plus LuaSec parameters. +-- +-- @param url mandatory (string or table) +-- @param body optional (string) +-- @return (string if url == string or 1), code, headers, status +-- +local function request(url, body) + local result_table = {} + local stringrequest = type(url) == "string" + if stringrequest then + url = urlstring_totable(url, body, result_table) + else + url.url = default_https_port(url.url) + end + if http.PROXY or url.proxy then + return nil, "proxy not supported" + elseif url.redirect then + return nil, "redirect not supported" + elseif url.create then + return nil, "create function not permitted" + end + -- New 'create' function to establish a secure connection + url.create = tcp(url) + local res, code, headers, status = http.request(url) + if res and stringrequest then + return table.concat(result_table), code, headers, status + end + return res, code, headers, status +end + +-------------------------------------------------------------------------------- +-- Export module +-- + +_M.request = request +_M.tcp = tcp + +return _M diff --git a/src/luasec/luasocket/LICENSE b/src/luasec/luasocket/LICENSE new file mode 100644 index 0000000..eadb747 --- /dev/null +++ b/src/luasec/luasocket/LICENSE @@ -0,0 +1,21 @@ +LuaSocket 3.0-RC1 license +Copyright (C) 2004-2013 Diego Nehab + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/src/luasec/luasocket/Makefile b/src/luasec/luasocket/Makefile new file mode 100644 index 0000000..b700fb6 --- /dev/null +++ b/src/luasec/luasocket/Makefile @@ -0,0 +1,26 @@ +OBJS= \ + io.o \ + buffer.o \ + timeout.o \ + usocket.o + +CC ?= cc +CFLAGS += $(MYCFLAGS) -DLUASOCKET_DEBUG +AR ?= ar +RANLIB ?= ranlib + +.PHONY: all clean + +all: libluasocket.a + +libluasocket.a: $(OBJS) + $(AR) rcu $@ $(OBJS) + $(RANLIB) $@ + +clean: + rm -f $(OBJS) libluasocket.a + +buffer.o: buffer.c buffer.h io.h timeout.h +io.o: io.c io.h timeout.h +timeout.o: timeout.c timeout.h +usocket.o: usocket.c socket.h io.h timeout.h usocket.h diff --git a/src/luasec/luasocket/buffer.c b/src/luasec/luasocket/buffer.c new file mode 100644 index 0000000..33882fc --- /dev/null +++ b/src/luasec/luasocket/buffer.c @@ -0,0 +1,282 @@ +/*=========================================================================*\ +* Input/Output interface for Lua programs +* LuaSocket toolkit +\*=========================================================================*/ +#include "../../lua.h" +#include "../../lauxlib.h" + +#include "buffer.h" + +/*=========================================================================*\ +* Internal function prototypes +\*=========================================================================*/ +static int recvraw(p_buffer buf, size_t wanted, luaL_Buffer *b); +static int recvline(p_buffer buf, luaL_Buffer *b); +static int recvall(p_buffer buf, luaL_Buffer *b); +static int buffer_get(p_buffer buf, const char **data, size_t *count); +static void buffer_skip(p_buffer buf, size_t count); +static int sendraw(p_buffer buf, const char *data, size_t count, size_t *sent); + +/* min and max macros */ +#ifndef MIN +#define MIN(x, y) ((x) < (y) ? x : y) +#endif +#ifndef MAX +#define MAX(x, y) ((x) > (y) ? x : y) +#endif + +/*=========================================================================*\ +* Exported functions +\*=========================================================================*/ +/*-------------------------------------------------------------------------*\ +* Initializes module +\*-------------------------------------------------------------------------*/ +int buffer_open(lua_State *L) { + (void) L; + return 0; +} + +/*-------------------------------------------------------------------------*\ +* Initializes C structure +\*-------------------------------------------------------------------------*/ +void buffer_init(p_buffer buf, p_io io, p_timeout tm) { + buf->first = buf->last = 0; + buf->io = io; + buf->tm = tm; + buf->received = buf->sent = 0; + buf->birthday = timeout_gettime(); +} + +/*-------------------------------------------------------------------------*\ +* object:getstats() interface +\*-------------------------------------------------------------------------*/ +int buffer_meth_getstats(lua_State *L, p_buffer buf) { + lua_pushnumber(L, (lua_Number) buf->received); + lua_pushnumber(L, (lua_Number) buf->sent); + lua_pushnumber(L, timeout_gettime() - buf->birthday); + return 3; +} + +/*-------------------------------------------------------------------------*\ +* object:setstats() interface +\*-------------------------------------------------------------------------*/ +int buffer_meth_setstats(lua_State *L, p_buffer buf) { + buf->received = (long) luaL_optnumber(L, 2, (lua_Number) buf->received); + buf->sent = (long) luaL_optnumber(L, 3, (lua_Number) buf->sent); + if (lua_isnumber(L, 4)) buf->birthday = timeout_gettime() - lua_tonumber(L, 4); + lua_pushnumber(L, 1); + return 1; +} + +/*-------------------------------------------------------------------------*\ +* object:send() interface +\*-------------------------------------------------------------------------*/ +int buffer_meth_send(lua_State *L, p_buffer buf) { + int top = lua_gettop(L); + int err = IO_DONE; + size_t size = 0, sent = 0; + const char *data = luaL_checklstring(L, 2, &size); + long start = (long) luaL_optnumber(L, 3, 1); + long end = (long) luaL_optnumber(L, 4, -1); +#ifdef LUASOCKET_DEBUG + p_timeout tm = timeout_markstart(buf->tm); +#endif + if (start < 0) start = (long) (size+start+1); + if (end < 0) end = (long) (size+end+1); + if (start < 1) start = (long) 1; + if (end > (long) size) end = (long) size; + if (start <= end) err = sendraw(buf, data+start-1, end-start+1, &sent); + /* check if there was an error */ + if (err != IO_DONE) { + lua_pushnil(L); + lua_pushstring(L, buf->io->error(buf->io->ctx, err)); + lua_pushnumber(L, (lua_Number) (sent+start-1)); + } else { + lua_pushnumber(L, (lua_Number) (sent+start-1)); + lua_pushnil(L); + lua_pushnil(L); + } +#ifdef LUASOCKET_DEBUG + /* push time elapsed during operation as the last return value */ + lua_pushnumber(L, timeout_gettime() - timeout_getstart(tm)); +#endif + return lua_gettop(L) - top; +} + +/*-------------------------------------------------------------------------*\ +* object:receive() interface +\*-------------------------------------------------------------------------*/ +int buffer_meth_receive(lua_State *L, p_buffer buf) { + luaL_Buffer b; + size_t size; + const char *part; + int err = IO_DONE; + int top = lua_gettop(L); + if (top < 3) { + lua_settop(L, 3); + top = 3; + } + part = luaL_optlstring(L, 3, "", &size); +#ifdef LUASOCKET_DEBUG + p_timeout tm = timeout_markstart(buf->tm); +#endif + /* initialize buffer with optional extra prefix + * (useful for concatenating previous partial results) */ + luaL_buffinit(L, &b); + luaL_addlstring(&b, part, size); + /* receive new patterns */ + if (!lua_isnumber(L, 2)) { + const char *p= luaL_optstring(L, 2, "*l"); + if (p[0] == '*' && p[1] == 'l') err = recvline(buf, &b); + else if (p[0] == '*' && p[1] == 'a') err = recvall(buf, &b); + else luaL_argcheck(L, 0, 2, "invalid receive pattern"); + /* get a fixed number of bytes (minus what was already partially + * received) */ + } else { + double n = lua_tonumber(L, 2); + size_t wanted = (size_t) n; + luaL_argcheck(L, n >= 0, 2, "invalid receive pattern"); + if (size == 0 || wanted > size) + err = recvraw(buf, wanted-size, &b); + } + /* check if there was an error */ + if (err != IO_DONE) { + /* we can't push anything in the stack before pushing the + * contents of the buffer. this is the reason for the complication */ + luaL_pushresult(&b); + lua_pushstring(L, buf->io->error(buf->io->ctx, err)); + lua_pushvalue(L, -2); + lua_pushnil(L); + lua_replace(L, -4); + } else { + luaL_pushresult(&b); + lua_pushnil(L); + lua_pushnil(L); + } +#ifdef LUASOCKET_DEBUG + /* push time elapsed during operation as the last return value */ + lua_pushnumber(L, timeout_gettime() - timeout_getstart(tm)); +#endif + return lua_gettop(L) - top; +} + +/*-------------------------------------------------------------------------*\ +* Determines if there is any data in the read buffer +\*-------------------------------------------------------------------------*/ +int buffer_isempty(p_buffer buf) { + return buf->first >= buf->last; +} + +/*=========================================================================*\ +* Internal functions +\*=========================================================================*/ +/*-------------------------------------------------------------------------*\ +* Sends a block of data (unbuffered) +\*-------------------------------------------------------------------------*/ +#define STEPSIZE 8192 +static int sendraw(p_buffer buf, const char *data, size_t count, size_t *sent) { + p_io io = buf->io; + p_timeout tm = buf->tm; + size_t total = 0; + int err = IO_DONE; + while (total < count && err == IO_DONE) { + size_t done = 0; + size_t step = (count-total <= STEPSIZE)? count-total: STEPSIZE; + err = io->send(io->ctx, data+total, step, &done, tm); + total += done; + } + *sent = total; + buf->sent += total; + return err; +} + +/*-------------------------------------------------------------------------*\ +* Reads a fixed number of bytes (buffered) +\*-------------------------------------------------------------------------*/ +static int recvraw(p_buffer buf, size_t wanted, luaL_Buffer *b) { + int err = IO_DONE; + size_t total = 0; + while (err == IO_DONE) { + size_t count; const char *data; + err = buffer_get(buf, &data, &count); + count = MIN(count, wanted - total); + luaL_addlstring(b, data, count); + buffer_skip(buf, count); + total += count; + if (total >= wanted) break; + } + return err; +} + +/*-------------------------------------------------------------------------*\ +* Reads everything until the connection is closed (buffered) +\*-------------------------------------------------------------------------*/ +static int recvall(p_buffer buf, luaL_Buffer *b) { + int err = IO_DONE; + size_t total = 0; + while (err == IO_DONE) { + const char *data; size_t count; + err = buffer_get(buf, &data, &count); + total += count; + luaL_addlstring(b, data, count); + buffer_skip(buf, count); + } + if (err == IO_CLOSED) { + if (total > 0) return IO_DONE; + else return IO_CLOSED; + } else return err; +} + +/*-------------------------------------------------------------------------*\ +* Reads a line terminated by a CR LF pair or just by a LF. The CR and LF +* are not returned by the function and are discarded from the buffer +\*-------------------------------------------------------------------------*/ +static int recvline(p_buffer buf, luaL_Buffer *b) { + int err = IO_DONE; + while (err == IO_DONE) { + size_t count, pos; const char *data; + err = buffer_get(buf, &data, &count); + pos = 0; + while (pos < count && data[pos] != '\n') { + /* we ignore all \r's */ + if (data[pos] != '\r') luaL_addchar(b, data[pos]); + pos++; + } + if (pos < count) { /* found '\n' */ + buffer_skip(buf, pos+1); /* skip '\n' too */ + break; /* we are done */ + } else /* reached the end of the buffer */ + buffer_skip(buf, pos); + } + return err; +} + +/*-------------------------------------------------------------------------*\ +* Skips a given number of bytes from read buffer. No data is read from the +* transport layer +\*-------------------------------------------------------------------------*/ +static void buffer_skip(p_buffer buf, size_t count) { + buf->received += count; + buf->first += count; + if (buffer_isempty(buf)) + buf->first = buf->last = 0; +} + +/*-------------------------------------------------------------------------*\ +* Return any data available in buffer, or get more data from transport layer +* if buffer is empty +\*-------------------------------------------------------------------------*/ +static int buffer_get(p_buffer buf, const char **data, size_t *count) { + int err = IO_DONE; + p_io io = buf->io; + p_timeout tm = buf->tm; + if (buffer_isempty(buf)) { + size_t got; + err = io->recv(io->ctx, buf->data, BUF_SIZE, &got, tm); + buf->first = 0; + buf->last = got; + } + *count = buf->last - buf->first; + *data = buf->data + buf->first; + return err; +} diff --git a/src/luasec/luasocket/buffer.h b/src/luasec/luasocket/buffer.h new file mode 100644 index 0000000..a9cbd3a --- /dev/null +++ b/src/luasec/luasocket/buffer.h @@ -0,0 +1,45 @@ +#ifndef BUF_H +#define BUF_H +/*=========================================================================*\ +* Input/Output interface for Lua programs +* LuaSocket toolkit +* +* Line patterns require buffering. Reading one character at a time involves +* too many system calls and is very slow. This module implements the +* LuaSocket interface for input/output on connected objects, as seen by +* Lua programs. +* +* Input is buffered. Output is *not* buffered because there was no simple +* way of making sure the buffered output data would ever be sent. +* +* The module is built on top of the I/O abstraction defined in io.h and the +* timeout management is done with the timeout.h interface. +\*=========================================================================*/ +#include "../../lua.h" + +#include "io.h" +#include "timeout.h" + +/* buffer size in bytes */ +#define BUF_SIZE 8192 + +/* buffer control structure */ +typedef struct t_buffer_ { + double birthday; /* throttle support info: creation time, */ + size_t sent, received; /* bytes sent, and bytes received */ + p_io io; /* IO driver used for this buffer */ + p_timeout tm; /* timeout management for this buffer */ + size_t first, last; /* index of first and last bytes of stored data */ + char data[BUF_SIZE]; /* storage space for buffer data */ +} t_buffer; +typedef t_buffer *p_buffer; + +int buffer_open(lua_State *L); +void buffer_init(p_buffer buf, p_io io, p_timeout tm); +int buffer_meth_send(lua_State *L, p_buffer buf); +int buffer_meth_receive(lua_State *L, p_buffer buf); +int buffer_meth_getstats(lua_State *L, p_buffer buf); +int buffer_meth_setstats(lua_State *L, p_buffer buf); +int buffer_isempty(p_buffer buf); + +#endif /* BUF_H */ diff --git a/src/luasec/luasocket/io.c b/src/luasec/luasocket/io.c new file mode 100644 index 0000000..35f46f7 --- /dev/null +++ b/src/luasec/luasocket/io.c @@ -0,0 +1,30 @@ +/*=========================================================================*\ +* Input/Output abstraction +* LuaSocket toolkit +\*=========================================================================*/ +#include "io.h" + +/*=========================================================================*\ +* Exported functions +\*=========================================================================*/ +/*-------------------------------------------------------------------------*\ +* Initializes C structure +\*-------------------------------------------------------------------------*/ +void io_init(p_io io, p_send send, p_recv recv, p_error error, void *ctx) { + io->send = send; + io->recv = recv; + io->error = error; + io->ctx = ctx; +} + +/*-------------------------------------------------------------------------*\ +* I/O error strings +\*-------------------------------------------------------------------------*/ +const char *io_strerror(int err) { + switch (err) { + case IO_DONE: return NULL; + case IO_CLOSED: return "closed"; + case IO_TIMEOUT: return "timeout"; + default: return "unknown error"; + } +} diff --git a/src/luasec/luasocket/io.h b/src/luasec/luasocket/io.h new file mode 100644 index 0000000..b1f35ad --- /dev/null +++ b/src/luasec/luasocket/io.h @@ -0,0 +1,65 @@ +#ifndef IO_H +#define IO_H +/*=========================================================================*\ +* Input/Output abstraction +* LuaSocket toolkit +* +* This module defines the interface that LuaSocket expects from the +* transport layer for streamed input/output. The idea is that if any +* transport implements this interface, then the buffer.c functions +* automatically work on it. +* +* The module socket.h implements this interface, and thus the module tcp.h +* is very simple. +\*=========================================================================*/ +#include <stdio.h> +#include "../../lua.h" + +#include "timeout.h" + +/* IO error codes */ +enum { + IO_DONE = 0, /* operation completed successfully */ + IO_TIMEOUT = -1, /* operation timed out */ + IO_CLOSED = -2, /* the connection has been closed */ + IO_UNKNOWN = -3 +}; + +/* interface to error message function */ +typedef const char *(*p_error) ( + void *ctx, /* context needed by send */ + int err /* error code */ +); + +/* interface to send function */ +typedef int (*p_send) ( + void *ctx, /* context needed by send */ + const char *data, /* pointer to buffer with data to send */ + size_t count, /* number of bytes to send from buffer */ + size_t *sent, /* number of bytes sent uppon return */ + p_timeout tm /* timeout control */ +); + +/* interface to recv function */ +typedef int (*p_recv) ( + void *ctx, /* context needed by recv */ + char *data, /* pointer to buffer where data will be written */ + size_t count, /* number of bytes to receive into buffer */ + size_t *got, /* number of bytes received uppon return */ + p_timeout tm /* timeout control */ +); + +/* IO driver definition */ +typedef struct t_io_ { + void *ctx; /* context needed by send/recv */ + p_send send; /* send function pointer */ + p_recv recv; /* receive function pointer */ + p_error error; /* strerror function */ +} t_io; +typedef t_io *p_io; + +void io_init(p_io io, p_send send, p_recv recv, p_error error, void *ctx); +const char *io_strerror(int err); + +#endif /* IO_H */ + diff --git a/src/luasec/luasocket/socket.h b/src/luasec/luasocket/socket.h new file mode 100644 index 0000000..07c20fe --- /dev/null +++ b/src/luasec/luasocket/socket.h @@ -0,0 +1,78 @@ +#ifndef SOCKET_H +#define SOCKET_H +/*=========================================================================*\ +* Socket compatibilization module +* LuaSocket toolkit +* +* BSD Sockets and WinSock are similar, but there are a few irritating +* differences. Also, not all *nix platforms behave the same. This module +* (and the associated usocket.h and wsocket.h) factor these differences and +* creates a interface compatible with the io.h module. +\*=========================================================================*/ +#include "io.h" + +/*=========================================================================*\ +* Platform specific compatibilization +\*=========================================================================*/ +#ifdef _WIN32 +#include "wsocket.h" +#else +#include "usocket.h" +#endif + +/*=========================================================================*\ +* The connect and accept functions accept a timeout and their +* implementations are somewhat complicated. We chose to move +* the timeout control into this module for these functions in +* order to simplify the modules that use them. +\*=========================================================================*/ +#include "timeout.h" + +/* we are lazy... */ +typedef struct sockaddr SA; + +/*=========================================================================*\ +* Functions below implement a comfortable platform independent +* interface to sockets +\*=========================================================================*/ +int socket_open(void); +int socket_close(void); +void socket_destroy(p_socket ps); +void socket_shutdown(p_socket ps, int how); +int socket_sendto(p_socket ps, const char *data, size_t count, + size_t *sent, SA *addr, socklen_t addr_len, p_timeout tm); +int socket_recvfrom(p_socket ps, char *data, size_t count, + size_t *got, SA *addr, socklen_t *addr_len, p_timeout tm); + +void socket_setnonblocking(p_socket ps); +void socket_setblocking(p_socket ps); + +int socket_waitfd(p_socket ps, int sw, p_timeout tm); +int socket_select(t_socket n, fd_set *rfds, fd_set *wfds, fd_set *efds, + p_timeout tm); + +int socket_connect(p_socket ps, SA *addr, socklen_t addr_len, p_timeout tm); +int socket_create(p_socket ps, int domain, int type, int protocol); +int socket_bind(p_socket ps, SA *addr, socklen_t addr_len); +int socket_listen(p_socket ps, int backlog); +int socket_accept(p_socket ps, p_socket pa, SA *addr, + socklen_t *addr_len, p_timeout tm); + +const char *socket_hoststrerror(int err); +const char *socket_gaistrerror(int err); +const char *socket_strerror(int err); + +/* these are perfect to use with the io abstraction module + and the buffered input module */ +int socket_send(p_socket ps, const char *data, size_t count, + size_t *sent, p_timeout tm); +int socket_recv(p_socket ps, char *data, size_t count, size_t *got, p_timeout tm); +int socket_write(p_socket ps, const char *data, size_t count, + size_t *sent, p_timeout tm); +int socket_read(p_socket ps, char *data, size_t count, size_t *got, p_timeout tm); +const char *socket_ioerror(p_socket ps, int err); + +int socket_gethostbyaddr(const char *addr, socklen_t len, struct hostent **hp); +int socket_gethostbyname(const char *addr, struct hostent **hp); + +#endif /* SOCKET_H */ diff --git a/src/luasec/luasocket/timeout.c b/src/luasec/luasocket/timeout.c new file mode 100644 index 0000000..d1faded --- /dev/null +++ b/src/luasec/luasocket/timeout.c @@ -0,0 +1,220 @@ +/*=========================================================================*\ +* Timeout management functions +* LuaSocket toolkit +\*=========================================================================*/ +#include <stdio.h> +#include <limits.h> +#include <float.h> + +#include "../../lua.h" +#include "../../lauxlib.h" + +#include "timeout.h" + +#ifdef _WIN32 +#include <windows.h> +#else +#include <time.h> +#include <sys/time.h> +#endif + +/* min and max macros */ +#ifndef MIN +#define MIN(x, y) ((x) < (y) ? x : y) +#endif +#ifndef MAX +#define MAX(x, y) ((x) > (y) ? x : y) +#endif + +/*=========================================================================*\ +* Internal function prototypes +\*=========================================================================*/ +static int timeout_lua_gettime(lua_State *L); +static int timeout_lua_sleep(lua_State *L); + +static luaL_Reg func[] = { + { "gettime", timeout_lua_gettime }, + { "sleep", timeout_lua_sleep }, + { NULL, NULL } +}; + +/*=========================================================================*\ +* Exported functions. +\*=========================================================================*/ +/*-------------------------------------------------------------------------*\ +* Initialize structure +\*-------------------------------------------------------------------------*/ +void timeout_init(p_timeout tm, double block, double total) { + tm->block = block; + tm->total = total; +} + +/*-------------------------------------------------------------------------*\ +* Determines how much time we have left for the next system call, +* if the previous call was successful +* Input +* tm: timeout control structure +* Returns +* the number of ms left or -1 if there is no time limit +\*-------------------------------------------------------------------------*/ +double timeout_get(p_timeout tm) { + if (tm->block < 0.0 && tm->total < 0.0) { + return -1; + } else if (tm->block < 0.0) { + double t = tm->total - timeout_gettime() + tm->start; + return MAX(t, 0.0); + } else if (tm->total < 0.0) { + return tm->block; + } else { + double t = tm->total - timeout_gettime() + tm->start; + return MIN(tm->block, MAX(t, 0.0)); + } +} + +/*-------------------------------------------------------------------------*\ +* Returns time since start of operation +* Input +* tm: timeout control structure +* Returns +* start field of structure +\*-------------------------------------------------------------------------*/ +double timeout_getstart(p_timeout tm) { + return tm->start; +} + +/*-------------------------------------------------------------------------*\ +* Determines how much time we have left for the next system call, +* if the previous call was a failure +* Input +* tm: timeout control structure +* Returns +* the number of ms left or -1 if there is no time limit +\*-------------------------------------------------------------------------*/ +double timeout_getretry(p_timeout tm) { + if (tm->block < 0.0 && tm->total < 0.0) { + return -1; + } else if (tm->block < 0.0) { + double t = tm->total - timeout_gettime() + tm->start; + return MAX(t, 0.0); + } else if (tm->total < 0.0) { + double t = tm->block - timeout_gettime() + tm->start; + return MAX(t, 0.0); + } else { + double t = tm->total - timeout_gettime() + tm->start; + return MIN(tm->block, MAX(t, 0.0)); + } +} + +/*-------------------------------------------------------------------------*\ +* Marks the operation start time in structure +* Input +* tm: timeout control structure +\*-------------------------------------------------------------------------*/ +p_timeout timeout_markstart(p_timeout tm) { + tm->start = timeout_gettime(); + return tm; +} + +/*-------------------------------------------------------------------------*\ +* Gets time in s, relative to January 1, 1970 (UTC) +* Returns +* time in s. +\*-------------------------------------------------------------------------*/ +#ifdef _WIN32 +double timeout_gettime(void) { + FILETIME ft; + double t; + GetSystemTimeAsFileTime(&ft); + /* Windows file time (time since January 1, 1601 (UTC)) */ + t = ft.dwLowDateTime/1.0e7 + ft.dwHighDateTime*(4294967296.0/1.0e7); + /* convert to Unix Epoch time (time since January 1, 1970 (UTC)) */ + return (t - 11644473600.0); +} +#else +double timeout_gettime(void) { + struct timeval v; + gettimeofday(&v, (struct timezone *) NULL); + /* Unix Epoch time (time since January 1, 1970 (UTC)) */ + return v.tv_sec + v.tv_usec/1.0e6; +} +#endif + +/*-------------------------------------------------------------------------*\ +* Initializes module +\*-------------------------------------------------------------------------*/ +int timeout_open(lua_State *L) { +#if LUA_VERSION_NUM > 501 && !defined(LUA_COMPAT_MODULE) + luaL_setfuncs(L, func, 0); +#else + luaL_openlib(L, NULL, func, 0); +#endif + return 0; +} + +/*-------------------------------------------------------------------------*\ +* Sets timeout values for IO operations +* Lua Input: base, time [, mode] +* time: time out value in seconds +* mode: "b" for block timeout, "t" for total timeout. (default: b) +\*-------------------------------------------------------------------------*/ +int timeout_meth_settimeout(lua_State *L, p_timeout tm) { + double t = luaL_optnumber(L, 2, -1); + const char *mode = luaL_optstring(L, 3, "b"); + switch (*mode) { + case 'b': + tm->block = t; + break; + case 'r': case 't': + tm->total = t; + break; + default: + luaL_argcheck(L, 0, 3, "invalid timeout mode"); + break; + } + lua_pushnumber(L, 1); + return 1; +} + +/*=========================================================================*\ +* Test support functions +\*=========================================================================*/ +/*-------------------------------------------------------------------------*\ +* Returns the time the system has been up, in secconds. +\*-------------------------------------------------------------------------*/ +static int timeout_lua_gettime(lua_State *L) +{ + lua_pushnumber(L, timeout_gettime()); + return 1; +} + +/*-------------------------------------------------------------------------*\ +* Sleep for n seconds. +\*-------------------------------------------------------------------------*/ +#ifdef _WIN32 +int timeout_lua_sleep(lua_State *L) +{ + double n = luaL_checknumber(L, 1); + if (n < 0.0) n = 0.0; + if (n < DBL_MAX/1000.0) n *= 1000.0; + if (n > INT_MAX) n = INT_MAX; + Sleep((int)n); + return 0; +} +#else +int timeout_lua_sleep(lua_State *L) +{ + double n = luaL_checknumber(L, 1); + struct timespec t, r; + if (n < 0.0) n = 0.0; + if (n > INT_MAX) n = INT_MAX; + t.tv_sec = (int) n; + n -= t.tv_sec; + t.tv_nsec = (int) (n * 1000000000); + if (t.tv_nsec >= 1000000000) t.tv_nsec = 999999999; + while (nanosleep(&t, &r) != 0) { + t.tv_sec = r.tv_sec; + t.tv_nsec = r.tv_nsec; + } + return 0; +} +#endif diff --git a/src/luasec/luasocket/timeout.h b/src/luasec/luasocket/timeout.h new file mode 100644 index 0000000..69751da --- /dev/null +++ b/src/luasec/luasocket/timeout.h @@ -0,0 +1,28 @@ +#ifndef TIMEOUT_H +#define TIMEOUT_H +/*=========================================================================*\ +* Timeout management functions +* LuaSocket toolkit +\*=========================================================================*/ +#include "../../lua.h" + +/* timeout control structure */ +typedef struct t_timeout_ { + double block; /* maximum time for blocking calls */ + double total; /* total number of milliseconds for operation */ + double start; /* time of start of operation */ +} t_timeout; +typedef t_timeout *p_timeout; + +int timeout_open(lua_State *L); +void timeout_init(p_timeout tm, double block, double total); +double timeout_get(p_timeout tm); +double timeout_getretry(p_timeout tm); +p_timeout timeout_markstart(p_timeout tm); +double timeout_getstart(p_timeout tm); +double timeout_gettime(void); +int timeout_meth_settimeout(lua_State *L, p_timeout tm); + +#define timeout_iszero(tm) ((tm)->block == 0.0) + +#endif /* TIMEOUT_H */ diff --git a/src/luasec/luasocket/usocket.c b/src/luasec/luasocket/usocket.c new file mode 100644 index 0000000..775e6fd --- /dev/null +++ b/src/luasec/luasocket/usocket.c @@ -0,0 +1,439 @@ +/*=========================================================================*\ +* Socket compatibilization module for Unix +* LuaSocket toolkit +* +* The code is now interrupt-safe. +* The penalty of calling select to avoid busy-wait is only paid when +* the I/O call fail in the first place. +\*=========================================================================*/ +#include <string.h> +#include <signal.h> + +#include "socket.h" + +/*-------------------------------------------------------------------------*\ +* Wait for readable/writable/connected socket with timeout +\*-------------------------------------------------------------------------*/ +#ifndef SOCKET_SELECT +int socket_waitfd(p_socket ps, int sw, p_timeout tm) { + int ret; + struct pollfd pfd; + pfd.fd = *ps; + pfd.events = sw; + pfd.revents = 0; + if (timeout_iszero(tm)) return IO_TIMEOUT; /* optimize timeout == 0 case */ + do { + int t = (int)(timeout_getretry(tm)*1e3); + ret = poll(&pfd, 1, t >= 0? t: -1); + } while (ret == -1 && errno == EINTR); + if (ret == -1) return errno; + if (ret == 0) return IO_TIMEOUT; + if (sw == WAITFD_C && (pfd.revents & (POLLIN|POLLERR))) return IO_CLOSED; + return IO_DONE; +} +#else +int socket_waitfd(p_socket ps, int sw, p_timeout tm) { + int ret; + fd_set rfds, wfds, *rp, *wp; + struct timeval tv, *tp; + double t; + if (*ps >= FD_SETSIZE) return EINVAL; + if (timeout_iszero(tm)) return IO_TIMEOUT; /* optimize timeout == 0 case */ + do { + /* must set bits within loop, because select may have modified them */ + rp = wp = NULL; + if (sw & WAITFD_R) { FD_ZERO(&rfds); FD_SET(*ps, &rfds); rp = &rfds; } + if (sw & WAITFD_W) { FD_ZERO(&wfds); FD_SET(*ps, &wfds); wp = &wfds; } + t = timeout_getretry(tm); + tp = NULL; + if (t >= 0.0) { + tv.tv_sec = (int)t; + tv.tv_usec = (int)((t-tv.tv_sec)*1.0e6); + tp = &tv; + } + ret = select(*ps+1, rp, wp, NULL, tp); + } while (ret == -1 && errno == EINTR); + if (ret == -1) return errno; + if (ret == 0) return IO_TIMEOUT; + if (sw == WAITFD_C && FD_ISSET(*ps, &rfds)) return IO_CLOSED; + return IO_DONE; +} +#endif + + +/*-------------------------------------------------------------------------*\ +* Initializes module +\*-------------------------------------------------------------------------*/ +int socket_open(void) { + /* instals a handler to ignore sigpipe or it will crash us */ + signal(SIGPIPE, SIG_IGN); + return 1; +} + +/*-------------------------------------------------------------------------*\ +* Close module +\*-------------------------------------------------------------------------*/ +int socket_close(void) { + return 1; +} + +/*-------------------------------------------------------------------------*\ +* Close and inutilize socket +\*-------------------------------------------------------------------------*/ +void socket_destroy(p_socket ps) { + if (*ps != SOCKET_INVALID) { + socket_setblocking(ps); + close(*ps); + *ps = SOCKET_INVALID; + } +} + +/*-------------------------------------------------------------------------*\ +* Select with timeout control +\*-------------------------------------------------------------------------*/ +int socket_select(t_socket n, fd_set *rfds, fd_set *wfds, fd_set *efds, + p_timeout tm) { + int ret; + do { + struct timeval tv; + double t = timeout_getretry(tm); + tv.tv_sec = (int) t; + tv.tv_usec = (int) ((t - tv.tv_sec) * 1.0e6); + /* timeout = 0 means no wait */ + ret = select(n, rfds, wfds, efds, t >= 0.0 ? &tv: NULL); + } while (ret < 0 && errno == EINTR); + return ret; +} + +/*-------------------------------------------------------------------------*\ +* Creates and sets up a socket +\*-------------------------------------------------------------------------*/ +int socket_create(p_socket ps, int domain, int type, int protocol) { + *ps = socket(domain, type, protocol); + if (*ps != SOCKET_INVALID) return IO_DONE; + else return errno; +} + +/*-------------------------------------------------------------------------*\ +* Binds or returns error message +\*-------------------------------------------------------------------------*/ +int socket_bind(p_socket ps, SA *addr, socklen_t len) { + int err = IO_DONE; + socket_setblocking(ps); + if (bind(*ps, addr, len) < 0) err = errno; + socket_setnonblocking(ps); + return err; +} + +/*-------------------------------------------------------------------------*\ +* +\*-------------------------------------------------------------------------*/ +int socket_listen(p_socket ps, int backlog) { + int err = IO_DONE; + socket_setblocking(ps); + if (listen(*ps, backlog)) err = errno; + socket_setnonblocking(ps); + return err; +} + +/*-------------------------------------------------------------------------*\ +* +\*-------------------------------------------------------------------------*/ +void socket_shutdown(p_socket ps, int how) { + socket_setblocking(ps); + shutdown(*ps, how); + socket_setnonblocking(ps); +} + +/*-------------------------------------------------------------------------*\ +* Connects or returns error message +\*-------------------------------------------------------------------------*/ +int socket_connect(p_socket ps, SA *addr, socklen_t len, p_timeout tm) { + int err; + /* avoid calling on closed sockets */ + if (*ps == SOCKET_INVALID) return IO_CLOSED; + /* call connect until done or failed without being interrupted */ + do if (connect(*ps, addr, len) == 0) return IO_DONE; + while ((err = errno) == EINTR); + /* if connection failed immediately, return error code */ + if (err != EINPROGRESS && err != EAGAIN) return err; + /* zero timeout case optimization */ + if (timeout_iszero(tm)) return IO_TIMEOUT; + /* wait until we have the result of the connection attempt or timeout */ + err = socket_waitfd(ps, WAITFD_C, tm); + if (err == IO_CLOSED) { + if (recv(*ps, (char *) &err, 0, 0) == 0) return IO_DONE; + else return errno; + } else return err; +} + +/*-------------------------------------------------------------------------*\ +* Accept with timeout +\*-------------------------------------------------------------------------*/ +int socket_accept(p_socket ps, p_socket pa, SA *addr, socklen_t *len, p_timeout tm) { + if (*ps == SOCKET_INVALID) return IO_CLOSED; + for ( ;; ) { + int err; + if ((*pa = accept(*ps, addr, len)) != SOCKET_INVALID) return IO_DONE; + err = errno; + if (err == EINTR) continue; + if (err != EAGAIN && err != ECONNABORTED) return err; + if ((err = socket_waitfd(ps, WAITFD_R, tm)) != IO_DONE) return err; + } + /* can't reach here */ + return IO_UNKNOWN; +} + +/*-------------------------------------------------------------------------*\ +* Send with timeout +\*-------------------------------------------------------------------------*/ +int socket_send(p_socket ps, const char *data, size_t count, + size_t *sent, p_timeout tm) +{ + int err; + *sent = 0; + /* avoid making system calls on closed sockets */ + if (*ps == SOCKET_INVALID) return IO_CLOSED; + /* loop until we send something or we give up on error */ + for ( ;; ) { + long put = (long) send(*ps, data, count, 0); + /* if we sent anything, we are done */ + if (put >= 0) { + *sent = put; + return IO_DONE; + } + err = errno; + /* EPIPE means the connection was closed */ + if (err == EPIPE) return IO_CLOSED; + /* we call was interrupted, just try again */ + if (err == EINTR) continue; + /* if failed fatal reason, report error */ + if (err != EAGAIN) return err; + /* wait until we can send something or we timeout */ + if ((err = socket_waitfd(ps, WAITFD_W, tm)) != IO_DONE) return err; + } + /* can't reach here */ + return IO_UNKNOWN; +} + +/*-------------------------------------------------------------------------*\ +* Sendto with timeout +\*-------------------------------------------------------------------------*/ +int socket_sendto(p_socket ps, const char *data, size_t count, size_t *sent, + SA *addr, socklen_t len, p_timeout tm) +{ + int err; + *sent = 0; + if (*ps == SOCKET_INVALID) return IO_CLOSED; + for ( ;; ) { + long put = (long) sendto(*ps, data, count, 0, addr, len); + if (put >= 0) { + *sent = put; + return IO_DONE; + } + err = errno; + if (err == EPIPE) return IO_CLOSED; + if (err == EINTR) continue; + if (err != EAGAIN) return err; + if ((err = socket_waitfd(ps, WAITFD_W, tm)) != IO_DONE) return err; + } + return IO_UNKNOWN; +} + +/*-------------------------------------------------------------------------*\ +* Receive with timeout +\*-------------------------------------------------------------------------*/ +int socket_recv(p_socket ps, char *data, size_t count, size_t *got, p_timeout tm) { + int err; + *got = 0; + if (*ps == SOCKET_INVALID) return IO_CLOSED; + for ( ;; ) { + long taken = (long) recv(*ps, data, count, 0); + if (taken > 0) { + *got = taken; + return IO_DONE; + } + err = errno; + if (taken == 0) return IO_CLOSED; + if (err == EINTR) continue; + if (err != EAGAIN) return err; + if ((err = socket_waitfd(ps, WAITFD_R, tm)) != IO_DONE) return err; + } + return IO_UNKNOWN; +} + +/*-------------------------------------------------------------------------*\ +* Recvfrom with timeout +\*-------------------------------------------------------------------------*/ +int socket_recvfrom(p_socket ps, char *data, size_t count, size_t *got, + SA *addr, socklen_t *len, p_timeout tm) { + int err; + *got = 0; + if (*ps == SOCKET_INVALID) return IO_CLOSED; + for ( ;; ) { + long taken = (long) recvfrom(*ps, data, count, 0, addr, len); + if (taken > 0) { + *got = taken; + return IO_DONE; + } + err = errno; + if (taken == 0) return IO_CLOSED; + if (err == EINTR) continue; + if (err != EAGAIN) return err; + if ((err = socket_waitfd(ps, WAITFD_R, tm)) != IO_DONE) return err; + } + return IO_UNKNOWN; +} + + +/*-------------------------------------------------------------------------*\ +* Write with timeout +* +* socket_read and socket_write are cut-n-paste of socket_send and socket_recv, +* with send/recv replaced with write/read. We can't just use write/read +* in the socket version, because behaviour when size is zero is different. +\*-------------------------------------------------------------------------*/ +int socket_write(p_socket ps, const char *data, size_t count, + size_t *sent, p_timeout tm) +{ + int err; + *sent = 0; + /* avoid making system calls on closed sockets */ + if (*ps == SOCKET_INVALID) return IO_CLOSED; + /* loop until we send something or we give up on error */ + for ( ;; ) { + long put = (long) write(*ps, data, count); + /* if we sent anything, we are done */ + if (put >= 0) { + *sent = put; + return IO_DONE; + } + err = errno; + /* EPIPE means the connection was closed */ + if (err == EPIPE) return IO_CLOSED; + /* we call was interrupted, just try again */ + if (err == EINTR) continue; + /* if failed fatal reason, report error */ + if (err != EAGAIN) return err; + /* wait until we can send something or we timeout */ + if ((err = socket_waitfd(ps, WAITFD_W, tm)) != IO_DONE) return err; + } + /* can't reach here */ + return IO_UNKNOWN; +} + +/*-------------------------------------------------------------------------*\ +* Read with timeout +* See note for socket_write +\*-------------------------------------------------------------------------*/ +int socket_read(p_socket ps, char *data, size_t count, size_t *got, p_timeout tm) { + int err; + *got = 0; + if (*ps == SOCKET_INVALID) return IO_CLOSED; + for ( ;; ) { + long taken = (long) read(*ps, data, count); + if (taken > 0) { + *got = taken; + return IO_DONE; + } + err = errno; + if (taken == 0) return IO_CLOSED; + if (err == EINTR) continue; + if (err != EAGAIN) return err; + if ((err = socket_waitfd(ps, WAITFD_R, tm)) != IO_DONE) return err; + } + return IO_UNKNOWN; +} + +/*-------------------------------------------------------------------------*\ +* Put socket into blocking mode +\*-------------------------------------------------------------------------*/ +void socket_setblocking(p_socket ps) { + int flags = fcntl(*ps, F_GETFL, 0); + flags &= (~(O_NONBLOCK)); + fcntl(*ps, F_SETFL, flags); +} + +/*-------------------------------------------------------------------------*\ +* Put socket into non-blocking mode +\*-------------------------------------------------------------------------*/ +void socket_setnonblocking(p_socket ps) { + int flags = fcntl(*ps, F_GETFL, 0); + flags |= O_NONBLOCK; + fcntl(*ps, F_SETFL, flags); +} + +/*-------------------------------------------------------------------------*\ +* DNS helpers +\*-------------------------------------------------------------------------*/ +int socket_gethostbyaddr(const char *addr, socklen_t len, struct hostent **hp) { + *hp = gethostbyaddr(addr, len, AF_INET); + if (*hp) return IO_DONE; + else if (h_errno) return h_errno; + else if (errno) return errno; + else return IO_UNKNOWN; +} + +int socket_gethostbyname(const char *addr, struct hostent **hp) { + *hp = gethostbyname(addr); + if (*hp) return IO_DONE; + else if (h_errno) return h_errno; + else if (errno) return errno; + else return IO_UNKNOWN; +} + +/*-------------------------------------------------------------------------*\ +* Error translation functions +* Make sure important error messages are standard +\*-------------------------------------------------------------------------*/ +const char *socket_hoststrerror(int err) { + if (err <= 0) return io_strerror(err); + switch (err) { + case HOST_NOT_FOUND: return "host not found"; + default: return hstrerror(err); + } +} + +const char *socket_strerror(int err) { + if (err <= 0) return io_strerror(err); + switch (err) { + case EADDRINUSE: return "address already in use"; + case EISCONN: return "already connected"; + case EACCES: return "permission denied"; + case ECONNREFUSED: return "connection refused"; + case ECONNABORTED: return "closed"; + case ECONNRESET: return "closed"; + case ETIMEDOUT: return "timeout"; + default: return strerror(err); + } +} + +const char *socket_ioerror(p_socket ps, int err) { + (void) ps; + return socket_strerror(err); +} + +const char *socket_gaistrerror(int err) { + if (err == 0) return NULL; + switch (err) { + case EAI_AGAIN: return "temporary failure in name resolution"; + case EAI_BADFLAGS: return "invalid value for ai_flags"; +#ifdef EAI_BADHINTS + case EAI_BADHINTS: return "invalid value for hints"; +#endif + case EAI_FAIL: return "non-recoverable failure in name resolution"; + case EAI_FAMILY: return "ai_family not supported"; + case EAI_MEMORY: return "memory allocation failure"; + case EAI_NONAME: + return "host or service not provided, or not known"; + case EAI_OVERFLOW: return "argument buffer overflow"; +#ifdef EAI_PROTOCOL + case EAI_PROTOCOL: return "resolved protocol is unknown"; +#endif + case EAI_SERVICE: return "service not supported for socket type"; + case EAI_SOCKTYPE: return "ai_socktype not supported"; + case EAI_SYSTEM: return strerror(errno); + default: return gai_strerror(err); + } +} + diff --git a/src/luasec/luasocket/usocket.h b/src/luasec/luasocket/usocket.h new file mode 100644 index 0000000..ecbcd8e --- /dev/null +++ b/src/luasec/luasocket/usocket.h @@ -0,0 +1,70 @@ +#ifndef USOCKET_H +#define USOCKET_H +/*=========================================================================*\ +* Socket compatibilization module for Unix +* LuaSocket toolkit +\*=========================================================================*/ + +/*=========================================================================*\ +* BSD include files +\*=========================================================================*/ +/* error codes */ +#include <errno.h> +/* close function */ +#include <unistd.h> +/* fnctnl function and associated constants */ +#include <fcntl.h> +/* struct sockaddr */ +#include <sys/types.h> +/* socket function */ +#include <sys/socket.h> +/* struct timeval */ +#include <sys/time.h> +/* gethostbyname and gethostbyaddr functions */ +#include <netdb.h> +/* sigpipe handling */ +#include <signal.h> +/* IP stuff*/ +#include <netinet/in.h> +#include <arpa/inet.h> +/* TCP options (nagle algorithm disable) */ +#include <netinet/tcp.h> +#include <net/if.h> + +#ifndef SOCKET_SELECT +#include <sys/poll.h> +#define WAITFD_R POLLIN +#define WAITFD_W POLLOUT +#define WAITFD_C (POLLIN|POLLOUT) +#else +#define WAITFD_R 1 +#define WAITFD_W 2 +#define WAITFD_C (WAITFD_R|WAITFD_W) +#endif + +#ifndef SO_REUSEPORT +#define SO_REUSEPORT SO_REUSEADDR +#endif + +/* Some platforms use IPV6_JOIN_GROUP instead if + * IPV6_ADD_MEMBERSHIP. The semantics are same, though. */ +#ifndef IPV6_ADD_MEMBERSHIP +#ifdef IPV6_JOIN_GROUP +#define IPV6_ADD_MEMBERSHIP IPV6_JOIN_GROUP +#endif /* IPV6_JOIN_GROUP */ +#endif /* !IPV6_ADD_MEMBERSHIP */ + +/* Same with IPV6_DROP_MEMBERSHIP / IPV6_LEAVE_GROUP. */ +#ifndef IPV6_DROP_MEMBERSHIP +#ifdef IPV6_LEAVE_GROUP +#define IPV6_DROP_MEMBERSHIP IPV6_LEAVE_GROUP +#endif /* IPV6_LEAVE_GROUP */ +#endif /* !IPV6_DROP_MEMBERSHIP */ + +typedef int t_socket; +typedef t_socket *p_socket; +typedef struct sockaddr_storage t_sockaddr_storage; + +#define SOCKET_INVALID (-1) + +#endif /* USOCKET_H */ diff --git a/src/luasec/luasocket/wsocket.c b/src/luasec/luasocket/wsocket.c new file mode 100644 index 0000000..8c7640e --- /dev/null +++ b/src/luasec/luasocket/wsocket.c @@ -0,0 +1,429 @@ +/*=========================================================================*\ +* Socket compatibilization module for Win32 +* LuaSocket toolkit +* +* The penalty of calling select to avoid busy-wait is only paid when +* the I/O call fail in the first place. +\*=========================================================================*/ +#include <string.h> + +#include "socket.h" + +/* WinSock doesn't have a strerror... */ +static const char *wstrerror(int err); + +/*-------------------------------------------------------------------------*\ +* Initializes module +\*-------------------------------------------------------------------------*/ +int socket_open(void) { + WSADATA wsaData; + WORD wVersionRequested = MAKEWORD(2, 0); + int err = WSAStartup(wVersionRequested, &wsaData ); + if (err != 0) return 0; + if ((LOBYTE(wsaData.wVersion) != 2 || HIBYTE(wsaData.wVersion) != 0) && + (LOBYTE(wsaData.wVersion) != 1 || HIBYTE(wsaData.wVersion) != 1)) { + WSACleanup(); + return 0; + } + return 1; +} + +/*-------------------------------------------------------------------------*\ +* Close module +\*-------------------------------------------------------------------------*/ +int socket_close(void) { + WSACleanup(); + return 1; +} + +/*-------------------------------------------------------------------------*\ +* Wait for readable/writable/connected socket with timeout +\*-------------------------------------------------------------------------*/ +int socket_waitfd(p_socket ps, int sw, p_timeout tm) { + int ret; + fd_set rfds, wfds, efds, *rp = NULL, *wp = NULL, *ep = NULL; + struct timeval tv, *tp = NULL; + double t; + if (timeout_iszero(tm)) return IO_TIMEOUT; /* optimize timeout == 0 case */ + if (sw & WAITFD_R) { + FD_ZERO(&rfds); + FD_SET(*ps, &rfds); + rp = &rfds; + } + if (sw & WAITFD_W) { FD_ZERO(&wfds); FD_SET(*ps, &wfds); wp = &wfds; } + if (sw & WAITFD_C) { FD_ZERO(&efds); FD_SET(*ps, &efds); ep = &efds; } + if ((t = timeout_get(tm)) >= 0.0) { + tv.tv_sec = (int) t; + tv.tv_usec = (int) ((t-tv.tv_sec)*1.0e6); + tp = &tv; + } + ret = select(0, rp, wp, ep, tp); + if (ret == -1) return WSAGetLastError(); + if (ret == 0) return IO_TIMEOUT; + if (sw == WAITFD_C && FD_ISSET(*ps, &efds)) return IO_CLOSED; + return IO_DONE; +} + +/*-------------------------------------------------------------------------*\ +* Select with int timeout in ms +\*-------------------------------------------------------------------------*/ +int socket_select(t_socket n, fd_set *rfds, fd_set *wfds, fd_set *efds, + p_timeout tm) { + struct timeval tv; + double t = timeout_get(tm); + tv.tv_sec = (int) t; + tv.tv_usec = (int) ((t - tv.tv_sec) * 1.0e6); + if (n <= 0) { + Sleep((DWORD) (1000*t)); + return 0; + } else return select(0, rfds, wfds, efds, t >= 0.0? &tv: NULL); +} + +/*-------------------------------------------------------------------------*\ +* Close and inutilize socket +\*-------------------------------------------------------------------------*/ +void socket_destroy(p_socket ps) { + if (*ps != SOCKET_INVALID) { + socket_setblocking(ps); /* close can take a long time on WIN32 */ + closesocket(*ps); + *ps = SOCKET_INVALID; + } +} + +/*-------------------------------------------------------------------------*\ +* +\*-------------------------------------------------------------------------*/ +void socket_shutdown(p_socket ps, int how) { + socket_setblocking(ps); + shutdown(*ps, how); + socket_setnonblocking(ps); +} + +/*-------------------------------------------------------------------------*\ +* Creates and sets up a socket +\*-------------------------------------------------------------------------*/ +int socket_create(p_socket ps, int domain, int type, int protocol) { + *ps = socket(domain, type, protocol); + if (*ps != SOCKET_INVALID) return IO_DONE; + else return WSAGetLastError(); +} + +/*-------------------------------------------------------------------------*\ +* Connects or returns error message +\*-------------------------------------------------------------------------*/ +int socket_connect(p_socket ps, SA *addr, socklen_t len, p_timeout tm) { + int err; + /* don't call on closed socket */ + if (*ps == SOCKET_INVALID) return IO_CLOSED; + /* ask system to connect */ + if (connect(*ps, addr, len) == 0) return IO_DONE; + /* make sure the system is trying to connect */ + err = WSAGetLastError(); + if (err != WSAEWOULDBLOCK && err != WSAEINPROGRESS) return err; + /* zero timeout case optimization */ + if (timeout_iszero(tm)) return IO_TIMEOUT; + /* we wait until something happens */ + err = socket_waitfd(ps, WAITFD_C, tm); + if (err == IO_CLOSED) { + int len = sizeof(err); + /* give windows time to set the error (yes, disgusting) */ + Sleep(10); + /* find out why we failed */ + getsockopt(*ps, SOL_SOCKET, SO_ERROR, (char *)&err, &len); + /* we KNOW there was an error. if 'why' is 0, we will return + * "unknown error", but it's not really our fault */ + return err > 0? err: IO_UNKNOWN; + } else return err; + +} + +/*-------------------------------------------------------------------------*\ +* Binds or returns error message +\*-------------------------------------------------------------------------*/ +int socket_bind(p_socket ps, SA *addr, socklen_t len) { + int err = IO_DONE; + socket_setblocking(ps); + if (bind(*ps, addr, len) < 0) err = WSAGetLastError(); + socket_setnonblocking(ps); + return err; +} + +/*-------------------------------------------------------------------------*\ +* +\*-------------------------------------------------------------------------*/ +int socket_listen(p_socket ps, int backlog) { + int err = IO_DONE; + socket_setblocking(ps); + if (listen(*ps, backlog) < 0) err = WSAGetLastError(); + socket_setnonblocking(ps); + return err; +} + +/*-------------------------------------------------------------------------*\ +* Accept with timeout +\*-------------------------------------------------------------------------*/ +int socket_accept(p_socket ps, p_socket pa, SA *addr, socklen_t *len, + p_timeout tm) { + if (*ps == SOCKET_INVALID) return IO_CLOSED; + for ( ;; ) { + int err; + /* try to get client socket */ + if ((*pa = accept(*ps, addr, len)) != SOCKET_INVALID) return IO_DONE; + /* find out why we failed */ + err = WSAGetLastError(); + /* if we failed because there was no connectoin, keep trying */ + if (err != WSAEWOULDBLOCK && err != WSAECONNABORTED) return err; + /* call select to avoid busy wait */ + if ((err = socket_waitfd(ps, WAITFD_R, tm)) != IO_DONE) return err; + } +} + +/*-------------------------------------------------------------------------*\ +* Send with timeout +* On windows, if you try to send 10MB, the OS will buffer EVERYTHING +* this can take an awful lot of time and we will end up blocked. +* Therefore, whoever calls this function should not pass a huge buffer. +\*-------------------------------------------------------------------------*/ +int socket_send(p_socket ps, const char *data, size_t count, + size_t *sent, p_timeout tm) +{ + int err; + *sent = 0; + /* avoid making system calls on closed sockets */ + if (*ps == SOCKET_INVALID) return IO_CLOSED; + /* loop until we send something or we give up on error */ + for ( ;; ) { + /* try to send something */ + int put = send(*ps, data, (int) count, 0); + /* if we sent something, we are done */ + if (put > 0) { + *sent = put; + return IO_DONE; + } + /* deal with failure */ + err = WSAGetLastError(); + /* we can only proceed if there was no serious error */ + if (err != WSAEWOULDBLOCK) return err; + /* avoid busy wait */ + if ((err = socket_waitfd(ps, WAITFD_W, tm)) != IO_DONE) return err; + } +} + +/*-------------------------------------------------------------------------*\ +* Sendto with timeout +\*-------------------------------------------------------------------------*/ +int socket_sendto(p_socket ps, const char *data, size_t count, size_t *sent, + SA *addr, socklen_t len, p_timeout tm) +{ + int err; + *sent = 0; + if (*ps == SOCKET_INVALID) return IO_CLOSED; + for ( ;; ) { + int put = sendto(*ps, data, (int) count, 0, addr, len); + if (put > 0) { + *sent = put; + return IO_DONE; + } + err = WSAGetLastError(); + if (err != WSAEWOULDBLOCK) return err; + if ((err = socket_waitfd(ps, WAITFD_W, tm)) != IO_DONE) return err; + } +} + +/*-------------------------------------------------------------------------*\ +* Receive with timeout +\*-------------------------------------------------------------------------*/ +int socket_recv(p_socket ps, char *data, size_t count, size_t *got, + p_timeout tm) +{ + int err, prev = IO_DONE; + *got = 0; + if (*ps == SOCKET_INVALID) return IO_CLOSED; + for ( ;; ) { + int taken = recv(*ps, data, (int) count, 0); + if (taken > 0) { + *got = taken; + return IO_DONE; + } + if (taken == 0) return IO_CLOSED; + err = WSAGetLastError(); + /* On UDP, a connreset simply means the previous send failed. + * So we try again. + * On TCP, it means our socket is now useless, so the error passes. + * (We will loop again, exiting because the same error will happen) */ + if (err != WSAEWOULDBLOCK) { + if (err != WSAECONNRESET || prev == WSAECONNRESET) return err; + prev = err; + } + if ((err = socket_waitfd(ps, WAITFD_R, tm)) != IO_DONE) return err; + } +} + +/*-------------------------------------------------------------------------*\ +* Recvfrom with timeout +\*-------------------------------------------------------------------------*/ +int socket_recvfrom(p_socket ps, char *data, size_t count, size_t *got, + SA *addr, socklen_t *len, p_timeout tm) +{ + int err, prev = IO_DONE; + *got = 0; + if (*ps == SOCKET_INVALID) return IO_CLOSED; + for ( ;; ) { + int taken = recvfrom(*ps, data, (int) count, 0, addr, len); + if (taken > 0) { + *got = taken; + return IO_DONE; + } + if (taken == 0) return IO_CLOSED; + err = WSAGetLastError(); + /* On UDP, a connreset simply means the previous send failed. + * So we try again. + * On TCP, it means our socket is now useless, so the error passes. + * (We will loop again, exiting because the same error will happen) */ + if (err != WSAEWOULDBLOCK) { + if (err != WSAECONNRESET || prev == WSAECONNRESET) return err; + prev = err; + } + if ((err = socket_waitfd(ps, WAITFD_R, tm)) != IO_DONE) return err; + } +} + +/*-------------------------------------------------------------------------*\ +* Put socket into blocking mode +\*-------------------------------------------------------------------------*/ +void socket_setblocking(p_socket ps) { + u_long argp = 0; + ioctlsocket(*ps, FIONBIO, &argp); +} + +/*-------------------------------------------------------------------------*\ +* Put socket into non-blocking mode +\*-------------------------------------------------------------------------*/ +void socket_setnonblocking(p_socket ps) { + u_long argp = 1; + ioctlsocket(*ps, FIONBIO, &argp); +} + +/*-------------------------------------------------------------------------*\ +* DNS helpers +\*-------------------------------------------------------------------------*/ +int socket_gethostbyaddr(const char *addr, socklen_t len, struct hostent **hp) { + *hp = gethostbyaddr(addr, len, AF_INET); + if (*hp) return IO_DONE; + else return WSAGetLastError(); +} + +int socket_gethostbyname(const char *addr, struct hostent **hp) { + *hp = gethostbyname(addr); + if (*hp) return IO_DONE; + else return WSAGetLastError(); +} + +/*-------------------------------------------------------------------------*\ +* Error translation functions +\*-------------------------------------------------------------------------*/ +const char *socket_hoststrerror(int err) { + if (err <= 0) return io_strerror(err); + switch (err) { + case WSAHOST_NOT_FOUND: return "host not found"; + default: return wstrerror(err); + } +} + +const char *socket_strerror(int err) { + if (err <= 0) return io_strerror(err); + switch (err) { + case WSAEADDRINUSE: return "address already in use"; + case WSAECONNREFUSED: return "connection refused"; + case WSAEISCONN: return "already connected"; + case WSAEACCES: return "permission denied"; + case WSAECONNABORTED: return "closed"; + case WSAECONNRESET: return "closed"; + case WSAETIMEDOUT: return "timeout"; + default: return wstrerror(err); + } +} + +const char *socket_ioerror(p_socket ps, int err) { + (void) ps; + return socket_strerror(err); +} + +static const char *wstrerror(int err) { + switch (err) { + case WSAEINTR: return "Interrupted function call"; + case WSAEACCES: return "Permission denied"; + case WSAEFAULT: return "Bad address"; + case WSAEINVAL: return "Invalid argument"; + case WSAEMFILE: return "Too many open files"; + case WSAEWOULDBLOCK: return "Resource temporarily unavailable"; + case WSAEINPROGRESS: return "Operation now in progress"; + case WSAEALREADY: return "Operation already in progress"; + case WSAENOTSOCK: return "Socket operation on nonsocket"; + case WSAEDESTADDRREQ: return "Destination address required"; + case WSAEMSGSIZE: return "Message too long"; + case WSAEPROTOTYPE: return "Protocol wrong type for socket"; + case WSAENOPROTOOPT: return "Bad protocol option"; + case WSAEPROTONOSUPPORT: return "Protocol not supported"; + case WSAESOCKTNOSUPPORT: return "Socket type not supported"; + case WSAEOPNOTSUPP: return "Operation not supported"; + case WSAEPFNOSUPPORT: return "Protocol family not supported"; + case WSAEAFNOSUPPORT: + return "Address family not supported by protocol family"; + case WSAEADDRINUSE: return "Address already in use"; + case WSAEADDRNOTAVAIL: return "Cannot assign requested address"; + case WSAENETDOWN: return "Network is down"; + case WSAENETUNREACH: return "Network is unreachable"; + case WSAENETRESET: return "Network dropped connection on reset"; + case WSAECONNABORTED: return "Software caused connection abort"; + case WSAECONNRESET: return "Connection reset by peer"; + case WSAENOBUFS: return "No buffer space available"; + case WSAEISCONN: return "Socket is already connected"; + case WSAENOTCONN: return "Socket is not connected"; + case WSAESHUTDOWN: return "Cannot send after socket shutdown"; + case WSAETIMEDOUT: return "Connection timed out"; + case WSAECONNREFUSED: return "Connection refused"; + case WSAEHOSTDOWN: return "Host is down"; + case WSAEHOSTUNREACH: return "No route to host"; + case WSAEPROCLIM: return "Too many processes"; + case WSASYSNOTREADY: return "Network subsystem is unavailable"; + case WSAVERNOTSUPPORTED: return "Winsock.dll version out of range"; + case WSANOTINITIALISED: + return "Successful WSAStartup not yet performed"; + case WSAEDISCON: return "Graceful shutdown in progress"; + case WSAHOST_NOT_FOUND: return "Host not found"; + case WSATRY_AGAIN: return "Nonauthoritative host not found"; + case WSANO_RECOVERY: return "Nonrecoverable name lookup error"; + case WSANO_DATA: return "Valid name, no data record of requested type"; + default: return "Unknown error"; + } +} + +const char *socket_gaistrerror(int err) { + if (err == 0) return NULL; + switch (err) { + case EAI_AGAIN: return "temporary failure in name resolution"; + case EAI_BADFLAGS: return "invalid value for ai_flags"; +#ifdef EAI_BADHINTS + case EAI_BADHINTS: return "invalid value for hints"; +#endif + case EAI_FAIL: return "non-recoverable failure in name resolution"; + case EAI_FAMILY: return "ai_family not supported"; + case EAI_MEMORY: return "memory allocation failure"; + case EAI_NONAME: + return "host or service not provided, or not known"; +#ifdef EAI_OVERFLOW + case EAI_OVERFLOW: return "argument buffer overflow"; +#endif +#ifdef EAI_PROTOCOL + case EAI_PROTOCOL: return "resolved protocol is unknown"; +#endif + case EAI_SERVICE: return "service not supported for socket type"; + case EAI_SOCKTYPE: return "ai_socktype not supported"; +#ifdef EAI_SYSTEM + case EAI_SYSTEM: return strerror(errno); +#endif + default: return gai_strerror(err); + } +} + diff --git a/src/luasec/luasocket/wsocket.h b/src/luasec/luasocket/wsocket.h new file mode 100644 index 0000000..c5a4b1c --- /dev/null +++ b/src/luasec/luasocket/wsocket.h @@ -0,0 +1,38 @@ +#ifndef WSOCKET_H +#define WSOCKET_H +/*=========================================================================*\ +* Socket compatibilization module for Win32 +* LuaSocket toolkit +\*=========================================================================*/ + +/*=========================================================================*\ +* WinSock include files +\*=========================================================================*/ +#include <winsock2.h> +#include <ws2tcpip.h> + +typedef int socklen_t; +typedef SOCKADDR_STORAGE t_sockaddr_storage; +typedef SOCKET t_socket; +typedef t_socket *p_socket; + +#define WAITFD_R 1 +#define WAITFD_W 2 +#define WAITFD_E 4 +#define WAITFD_C (WAITFD_E|WAITFD_W) + +#ifndef IPV6_V6ONLY +#define IPV6_V6ONLY 27 +#endif + +#define SOCKET_INVALID (INVALID_SOCKET) + +#ifndef SO_REUSEPORT +#define SO_REUSEPORT SO_REUSEADDR +#endif + +#ifndef AI_NUMERICSERV +#define AI_NUMERICSERV (0) +#endif + +#endif /* WSOCKET_H */ diff --git a/src/luasec/options.c b/src/luasec/options.c new file mode 100644 index 0000000..d636f7d --- /dev/null +++ b/src/luasec/options.c @@ -0,0 +1,185 @@ +/*-------------------------------------------------------------------------- + * LuaSec 1.0.2 + * + * Copyright (C) 2006-2021 Bruno Silvestre + * + *--------------------------------------------------------------------------*/ + +#include <openssl/ssl.h> + +#include "options.h" + +/* If you need to generate these options again, see options.lua */ + + +/* + OpenSSL version: OpenSSL 3.0.0-beta2 +*/ + +static lsec_ssl_option_t ssl_options[] = { +#if defined(SSL_OP_ALL) + {"all", SSL_OP_ALL}, +#endif +#if defined(SSL_OP_ALLOW_CLIENT_RENEGOTIATION) + {"allow_client_renegotiation", SSL_OP_ALLOW_CLIENT_RENEGOTIATION}, +#endif +#if defined(SSL_OP_ALLOW_NO_DHE_KEX) + {"allow_no_dhe_kex", SSL_OP_ALLOW_NO_DHE_KEX}, +#endif +#if defined(SSL_OP_ALLOW_UNSAFE_LEGACY_RENEGOTIATION) + {"allow_unsafe_legacy_renegotiation", SSL_OP_ALLOW_UNSAFE_LEGACY_RENEGOTIATION}, +#endif +#if defined(SSL_OP_CIPHER_SERVER_PREFERENCE) + {"cipher_server_preference", SSL_OP_CIPHER_SERVER_PREFERENCE}, +#endif +#if defined(SSL_OP_CISCO_ANYCONNECT) + {"cisco_anyconnect", SSL_OP_CISCO_ANYCONNECT}, +#endif +#if defined(SSL_OP_CLEANSE_PLAINTEXT) + {"cleanse_plaintext", SSL_OP_CLEANSE_PLAINTEXT}, +#endif +#if defined(SSL_OP_COOKIE_EXCHANGE) + {"cookie_exchange", SSL_OP_COOKIE_EXCHANGE}, +#endif +#if defined(SSL_OP_CRYPTOPRO_TLSEXT_BUG) + {"cryptopro_tlsext_bug", SSL_OP_CRYPTOPRO_TLSEXT_BUG}, +#endif +#if defined(SSL_OP_DISABLE_TLSEXT_CA_NAMES) + {"disable_tlsext_ca_names", SSL_OP_DISABLE_TLSEXT_CA_NAMES}, +#endif +#if defined(SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS) + {"dont_insert_empty_fragments", SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS}, +#endif +#if defined(SSL_OP_ENABLE_KTLS) + {"enable_ktls", SSL_OP_ENABLE_KTLS}, +#endif +#if defined(SSL_OP_ENABLE_MIDDLEBOX_COMPAT) + {"enable_middlebox_compat", SSL_OP_ENABLE_MIDDLEBOX_COMPAT}, +#endif +#if defined(SSL_OP_EPHEMERAL_RSA) + {"ephemeral_rsa", SSL_OP_EPHEMERAL_RSA}, +#endif +#if defined(SSL_OP_IGNORE_UNEXPECTED_EOF) + {"ignore_unexpected_eof", SSL_OP_IGNORE_UNEXPECTED_EOF}, +#endif +#if defined(SSL_OP_LEGACY_SERVER_CONNECT) + {"legacy_server_connect", SSL_OP_LEGACY_SERVER_CONNECT}, +#endif +#if defined(SSL_OP_MICROSOFT_BIG_SSLV3_BUFFER) + {"microsoft_big_sslv3_buffer", SSL_OP_MICROSOFT_BIG_SSLV3_BUFFER}, +#endif +#if defined(SSL_OP_MICROSOFT_SESS_ID_BUG) + {"microsoft_sess_id_bug", SSL_OP_MICROSOFT_SESS_ID_BUG}, +#endif +#if defined(SSL_OP_MSIE_SSLV2_RSA_PADDING) + {"msie_sslv2_rsa_padding", SSL_OP_MSIE_SSLV2_RSA_PADDING}, +#endif +#if defined(SSL_OP_NETSCAPE_CA_DN_BUG) + {"netscape_ca_dn_bug", SSL_OP_NETSCAPE_CA_DN_BUG}, +#endif +#if defined(SSL_OP_NETSCAPE_CHALLENGE_BUG) + {"netscape_challenge_bug", SSL_OP_NETSCAPE_CHALLENGE_BUG}, +#endif +#if defined(SSL_OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG) + {"netscape_demo_cipher_change_bug", SSL_OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG}, +#endif +#if defined(SSL_OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG) + {"netscape_reuse_cipher_change_bug", SSL_OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG}, +#endif +#if defined(SSL_OP_NO_ANTI_REPLAY) + {"no_anti_replay", SSL_OP_NO_ANTI_REPLAY}, +#endif +#if defined(SSL_OP_NO_COMPRESSION) + {"no_compression", SSL_OP_NO_COMPRESSION}, +#endif +#if defined(SSL_OP_NO_DTLS_MASK) + {"no_dtls_mask", SSL_OP_NO_DTLS_MASK}, +#endif +#if defined(SSL_OP_NO_DTLSv1) + {"no_dtlsv1", SSL_OP_NO_DTLSv1}, +#endif +#if defined(SSL_OP_NO_DTLSv1_2) + {"no_dtlsv1_2", SSL_OP_NO_DTLSv1_2}, +#endif +#if defined(SSL_OP_NO_ENCRYPT_THEN_MAC) + {"no_encrypt_then_mac", SSL_OP_NO_ENCRYPT_THEN_MAC}, +#endif +#if defined(SSL_OP_NO_EXTENDED_MASTER_SECRET) + {"no_extended_master_secret", SSL_OP_NO_EXTENDED_MASTER_SECRET}, +#endif +#if defined(SSL_OP_NO_QUERY_MTU) + {"no_query_mtu", SSL_OP_NO_QUERY_MTU}, +#endif +#if defined(SSL_OP_NO_RENEGOTIATION) + {"no_renegotiation", SSL_OP_NO_RENEGOTIATION}, +#endif +#if defined(SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION) + {"no_session_resumption_on_renegotiation", SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION}, +#endif +#if defined(SSL_OP_NO_SSL_MASK) + {"no_ssl_mask", SSL_OP_NO_SSL_MASK}, +#endif +#if defined(SSL_OP_NO_SSLv2) + {"no_sslv2", SSL_OP_NO_SSLv2}, +#endif +#if defined(SSL_OP_NO_SSLv3) + {"no_sslv3", SSL_OP_NO_SSLv3}, +#endif +#if defined(SSL_OP_NO_TICKET) + {"no_ticket", SSL_OP_NO_TICKET}, +#endif +#if defined(SSL_OP_NO_TLSv1) + {"no_tlsv1", SSL_OP_NO_TLSv1}, +#endif +#if defined(SSL_OP_NO_TLSv1_1) + {"no_tlsv1_1", SSL_OP_NO_TLSv1_1}, +#endif +#if defined(SSL_OP_NO_TLSv1_2) + {"no_tlsv1_2", SSL_OP_NO_TLSv1_2}, +#endif +#if defined(SSL_OP_NO_TLSv1_3) + {"no_tlsv1_3", SSL_OP_NO_TLSv1_3}, +#endif +#if defined(SSL_OP_PKCS1_CHECK_1) + {"pkcs1_check_1", SSL_OP_PKCS1_CHECK_1}, +#endif +#if defined(SSL_OP_PKCS1_CHECK_2) + {"pkcs1_check_2", SSL_OP_PKCS1_CHECK_2}, +#endif +#if defined(SSL_OP_PRIORITIZE_CHACHA) + {"prioritize_chacha", SSL_OP_PRIORITIZE_CHACHA}, +#endif +#if defined(SSL_OP_SAFARI_ECDHE_ECDSA_BUG) + {"safari_ecdhe_ecdsa_bug", SSL_OP_SAFARI_ECDHE_ECDSA_BUG}, +#endif +#if defined(SSL_OP_SINGLE_DH_USE) + {"single_dh_use", SSL_OP_SINGLE_DH_USE}, +#endif +#if defined(SSL_OP_SINGLE_ECDH_USE) + {"single_ecdh_use", SSL_OP_SINGLE_ECDH_USE}, +#endif +#if defined(SSL_OP_SSLEAY_080_CLIENT_DH_BUG) + {"ssleay_080_client_dh_bug", SSL_OP_SSLEAY_080_CLIENT_DH_BUG}, +#endif +#if defined(SSL_OP_SSLREF2_REUSE_CERT_TYPE_BUG) + {"sslref2_reuse_cert_type_bug", SSL_OP_SSLREF2_REUSE_CERT_TYPE_BUG}, +#endif +#if defined(SSL_OP_TLSEXT_PADDING) + {"tlsext_padding", SSL_OP_TLSEXT_PADDING}, +#endif +#if defined(SSL_OP_TLS_BLOCK_PADDING_BUG) + {"tls_block_padding_bug", SSL_OP_TLS_BLOCK_PADDING_BUG}, +#endif +#if defined(SSL_OP_TLS_D5_BUG) + {"tls_d5_bug", SSL_OP_TLS_D5_BUG}, +#endif +#if defined(SSL_OP_TLS_ROLLBACK_BUG) + {"tls_rollback_bug", SSL_OP_TLS_ROLLBACK_BUG}, +#endif + {NULL, 0L} +}; + +LSEC_API lsec_ssl_option_t* lsec_get_ssl_options() { + return ssl_options; +} + diff --git a/src/luasec/options.h b/src/luasec/options.h new file mode 100644 index 0000000..cd8bcea --- /dev/null +++ b/src/luasec/options.h @@ -0,0 +1,22 @@ +#ifndef LSEC_OPTIONS_H +#define LSEC_OPTIONS_H + +/*-------------------------------------------------------------------------- + * LuaSec 1.0.2 + * + * Copyright (C) 2006-2021 Bruno Silvestre + * + *--------------------------------------------------------------------------*/ + +#include "compat.h" + +struct lsec_ssl_option_s { + const char *name; + unsigned long code; +}; + +typedef struct lsec_ssl_option_s lsec_ssl_option_t; + +LSEC_API lsec_ssl_option_t* lsec_get_ssl_options(); + +#endif diff --git a/src/luasec/options.lua b/src/luasec/options.lua new file mode 100644 index 0000000..d9a8801 --- /dev/null +++ b/src/luasec/options.lua @@ -0,0 +1,93 @@ +local function usage() + print("Usage:") + print("* Generate options of your system:") + print(" lua options.lua -g /path/to/ssl.h [version] > options.c") + print("* Examples:") + print(" lua options.lua -g /usr/include/openssl/ssl.h > options.c\n") + print(" lua options.lua -g /usr/include/openssl/ssl.h \"OpenSSL 1.1.1f\" > options.c\n") + + print("* List options of your system:") + print(" lua options.lua -l /path/to/ssl.h\n") +end + +-- +local function printf(str, ...) + print(string.format(str, ...)) +end + +local function generate(options, version) + print([[ +/*-------------------------------------------------------------------------- + * LuaSec 1.1.1 + * + * Copyright (C) 2006-2021 Bruno Silvestre + * + *--------------------------------------------------------------------------*/ + +#include <openssl/ssl.h> + +#include "options.h" + +/* If you need to generate these options again, see options.lua */ + +]]) + + printf([[ +/* + OpenSSL version: %s +*/ +]], version) + + print([[static lsec_ssl_option_t ssl_options[] = {]]) + + for k, option in ipairs(options) do + local name = string.lower(string.sub(option, 8)) + print(string.format([[#if defined(%s)]], option)) + print(string.format([[ {"%s", %s},]], name, option)) + print([[#endif]]) + end + print([[ {NULL, 0L}]]) + print([[ +}; + +LSEC_API lsec_ssl_option_t* lsec_get_ssl_options() { + return ssl_options; +} +]]) +end + +local function loadoptions(file) + local options = {} + local f = assert(io.open(file, "r")) + for line in f:lines() do + local op = string.match(line, "define%s+(SSL_OP_BIT%()") + if not op then + op = string.match(line, "define%s+(SSL_OP_%S+)") + if op then + table.insert(options, op) + end + end + end + table.sort(options, function(a,b) return a<b end) + return options +end +-- + +local options +local flag, file, version = ... + +version = version or "Unknown" + +if not file then + usage() +elseif flag == "-g" then + options = loadoptions(file) + generate(options, version) +elseif flag == "-l" then + options = loadoptions(file) + for k, option in ipairs(options) do + print(option) + end +else + usage() +end diff --git a/src/luasec/ssl.c b/src/luasec/ssl.c new file mode 100644 index 0000000..9f83b13 --- /dev/null +++ b/src/luasec/ssl.c @@ -0,0 +1,970 @@ +/*-------------------------------------------------------------------------- + * LuaSec 1.0.2 + * + * Copyright (C) 2014-2021 Kim Alvefur, Paul Aurich, Tobias Markmann, + * Matthew Wild. + * Copyright (C) 2006-2021 Bruno Silvestre. + * + *--------------------------------------------------------------------------*/ + +#include <errno.h> +#include <string.h> + +#if defined(WIN32) +#include <winsock2.h> +#endif + +#include <openssl/ssl.h> +#include <openssl/x509v3.h> +#include <openssl/x509_vfy.h> +#include <openssl/err.h> +#include <openssl/dh.h> + +#include "../lua.h" +#include "../lauxlib.h" + +#include "luasocket/io.h" +#include "luasocket/buffer.h" +#include "luasocket/timeout.h" +#include "luasocket/socket.h" + +#include "x509.h" +#include "context.h" +#include "ssl.h" + + +#ifndef LSEC_API_OPENSSL_1_1_0 +#define SSL_is_server(s) (s->server) +#define SSL_up_ref(ssl) CRYPTO_add(&(ssl)->references, 1, CRYPTO_LOCK_SSL) +#define X509_up_ref(c) CRYPTO_add(&c->references, 1, CRYPTO_LOCK_X509) +#endif + + +/** + * Underline socket error. + */ +static int lsec_socket_error() +{ +#if defined(WIN32) + return WSAGetLastError(); +#else +#if defined(LSEC_OPENSSL_1_1_1) + // Bug in OpenSSL 1.1.1 + if (errno == 0) + return LSEC_IO_SSL; +#endif + return errno; +#endif +} + +/** + * Map error code into string. + */ +static const char *ssl_ioerror(void *ctx, int err) +{ + if (err == LSEC_IO_SSL) { + p_ssl ssl = (p_ssl) ctx; + switch(ssl->error) { + case SSL_ERROR_NONE: return "No error"; + case SSL_ERROR_ZERO_RETURN: return "closed"; + case SSL_ERROR_WANT_READ: return "wantread"; + case SSL_ERROR_WANT_WRITE: return "wantwrite"; + case SSL_ERROR_WANT_CONNECT: return "'connect' not completed"; + case SSL_ERROR_WANT_ACCEPT: return "'accept' not completed"; + case SSL_ERROR_WANT_X509_LOOKUP: return "Waiting for callback"; + case SSL_ERROR_SYSCALL: return "System error"; + case SSL_ERROR_SSL: return ERR_reason_error_string(ERR_get_error()); + default: return "Unknown SSL error"; + } + } + return socket_strerror(err); +} + +/** + * Close the connection before the GC collect the object. + */ +static int meth_destroy(lua_State *L) +{ + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + if (ssl->state == LSEC_STATE_CONNECTED) { + socket_setblocking(&ssl->sock); + SSL_shutdown(ssl->ssl); + } + if (ssl->sock != SOCKET_INVALID) { + socket_destroy(&ssl->sock); + } + ssl->state = LSEC_STATE_CLOSED; + if (ssl->ssl) { + /* Clear the registries */ + luaL_getmetatable(L, "SSL:Verify:Registry"); + lua_pushlightuserdata(L, (void*)ssl->ssl); + lua_pushnil(L); + lua_settable(L, -3); + luaL_getmetatable(L, "SSL:SNI:Registry"); + lua_pushlightuserdata(L, (void*)ssl->ssl); + lua_pushnil(L); + lua_settable(L, -3); + /* Destroy the object */ + SSL_free(ssl->ssl); + ssl->ssl = NULL; + } + return 0; +} + +/** + * Perform the TLS/SSL handshake + */ +static int handshake(p_ssl ssl) +{ + int err; + p_timeout tm = timeout_markstart(&ssl->tm); + if (ssl->state == LSEC_STATE_CLOSED) + return IO_CLOSED; + for ( ; ; ) { + ERR_clear_error(); + err = SSL_do_handshake(ssl->ssl); + ssl->error = SSL_get_error(ssl->ssl, err); + switch (ssl->error) { + case SSL_ERROR_NONE: + ssl->state = LSEC_STATE_CONNECTED; + return IO_DONE; + case SSL_ERROR_WANT_READ: + err = socket_waitfd(&ssl->sock, WAITFD_R, tm); + if (err == IO_TIMEOUT) return LSEC_IO_SSL; + if (err != IO_DONE) return err; + break; + case SSL_ERROR_WANT_WRITE: + err = socket_waitfd(&ssl->sock, WAITFD_W, tm); + if (err == IO_TIMEOUT) return LSEC_IO_SSL; + if (err != IO_DONE) return err; + break; + case SSL_ERROR_SYSCALL: + if (ERR_peek_error()) { + ssl->error = SSL_ERROR_SSL; + return LSEC_IO_SSL; + } + if (err == 0) + return IO_CLOSED; + return lsec_socket_error(); + default: + return LSEC_IO_SSL; + } + } + return IO_UNKNOWN; +} + +/** + * Send data + */ +static int ssl_send(void *ctx, const char *data, size_t count, size_t *sent, + p_timeout tm) +{ + int err; + p_ssl ssl = (p_ssl)ctx; + if (ssl->state != LSEC_STATE_CONNECTED) + return IO_CLOSED; + *sent = 0; + for ( ; ; ) { + ERR_clear_error(); + err = SSL_write(ssl->ssl, data, (int)count); + ssl->error = SSL_get_error(ssl->ssl, err); + switch (ssl->error) { + case SSL_ERROR_NONE: + *sent = err; + return IO_DONE; + case SSL_ERROR_WANT_READ: + err = socket_waitfd(&ssl->sock, WAITFD_R, tm); + if (err == IO_TIMEOUT) return LSEC_IO_SSL; + if (err != IO_DONE) return err; + break; + case SSL_ERROR_WANT_WRITE: + err = socket_waitfd(&ssl->sock, WAITFD_W, tm); + if (err == IO_TIMEOUT) return LSEC_IO_SSL; + if (err != IO_DONE) return err; + break; + case SSL_ERROR_SYSCALL: + if (ERR_peek_error()) { + ssl->error = SSL_ERROR_SSL; + return LSEC_IO_SSL; + } + if (err == 0) + return IO_CLOSED; + return lsec_socket_error(); + default: + return LSEC_IO_SSL; + } + } + return IO_UNKNOWN; +} + +/** + * Receive data + */ +static int ssl_recv(void *ctx, char *data, size_t count, size_t *got, + p_timeout tm) +{ + int err; + p_ssl ssl = (p_ssl)ctx; + *got = 0; + if (ssl->state != LSEC_STATE_CONNECTED) + return IO_CLOSED; + for ( ; ; ) { + ERR_clear_error(); + err = SSL_read(ssl->ssl, data, (int)count); + ssl->error = SSL_get_error(ssl->ssl, err); + switch (ssl->error) { + case SSL_ERROR_NONE: + *got = err; + return IO_DONE; + case SSL_ERROR_ZERO_RETURN: + return IO_CLOSED; + case SSL_ERROR_WANT_READ: + err = socket_waitfd(&ssl->sock, WAITFD_R, tm); + if (err == IO_TIMEOUT) return LSEC_IO_SSL; + if (err != IO_DONE) return err; + break; + case SSL_ERROR_WANT_WRITE: + err = socket_waitfd(&ssl->sock, WAITFD_W, tm); + if (err == IO_TIMEOUT) return LSEC_IO_SSL; + if (err != IO_DONE) return err; + break; + case SSL_ERROR_SYSCALL: + if (ERR_peek_error()) { + ssl->error = SSL_ERROR_SSL; + return LSEC_IO_SSL; + } + if (err == 0) + return IO_CLOSED; + return lsec_socket_error(); + default: + return LSEC_IO_SSL; + } + } + return IO_UNKNOWN; +} + +static SSL_CTX* luaossl_testcontext(lua_State *L, int arg) { + SSL_CTX **ctx = luaL_testudata(L, arg, "SSL_CTX*"); + if (ctx) + return *ctx; + return NULL; +} + +static SSL* luaossl_testssl(lua_State *L, int arg) { + SSL **ssl = luaL_testudata(L, arg, "SSL*"); + if (ssl) + return *ssl; + return NULL; +} + +/** + * Create a new TLS/SSL object and mark it as new. + */ +static int meth_create(lua_State *L) +{ + p_ssl ssl; + int mode; + SSL_CTX *ctx; + + lua_settop(L, 1); + + ssl = (p_ssl)lua_newuserdata(L, sizeof(t_ssl)); + if (!ssl) { + lua_pushnil(L); + lua_pushstring(L, "error creating SSL object"); + return 2; + } + + if ((ctx = lsec_testcontext(L, 1))) { + mode = lsec_getmode(L, 1); + if (mode == LSEC_MODE_INVALID) { + lua_pushnil(L); + lua_pushstring(L, "invalid mode"); + return 2; + } + ssl->ssl = SSL_new(ctx); + if (!ssl->ssl) { + lua_pushnil(L); + lua_pushfstring(L, "error creating SSL object (%s)", + ERR_reason_error_string(ERR_get_error())); + return 2; + } + } else if ((ctx = luaossl_testcontext(L, 1))) { + ssl->ssl = SSL_new(ctx); + if (!ssl->ssl) { + lua_pushnil(L); + lua_pushfstring(L, "error creating SSL object (%s)", + ERR_reason_error_string(ERR_get_error())); + return 2; + } + mode = SSL_is_server(ssl->ssl) ? LSEC_MODE_SERVER : LSEC_MODE_CLIENT; + } else if ((ssl->ssl = luaossl_testssl(L, 1))) { + SSL_up_ref(ssl->ssl); + mode = SSL_is_server(ssl->ssl) ? LSEC_MODE_SERVER : LSEC_MODE_CLIENT; + } else { + return luaL_argerror(L, 1, "invalid context"); + } + ssl->state = LSEC_STATE_NEW; + SSL_set_fd(ssl->ssl, (int)SOCKET_INVALID); + SSL_set_mode(ssl->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE | + SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); + SSL_set_mode(ssl->ssl, SSL_MODE_RELEASE_BUFFERS); + if (mode == LSEC_MODE_SERVER) + SSL_set_accept_state(ssl->ssl); + else + SSL_set_connect_state(ssl->ssl); + + io_init(&ssl->io, (p_send)ssl_send, (p_recv)ssl_recv, + (p_error) ssl_ioerror, ssl); + timeout_init(&ssl->tm, -1, -1); + buffer_init(&ssl->buf, &ssl->io, &ssl->tm); + + luaL_getmetatable(L, "SSL:Connection"); + lua_setmetatable(L, -2); + return 1; +} + +/** + * Buffer send function + */ +static int meth_send(lua_State *L) { + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + return buffer_meth_send(L, &ssl->buf); +} + +/** + * Buffer receive function + */ +static int meth_receive(lua_State *L) { + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + return buffer_meth_receive(L, &ssl->buf); +} + +/** + * Get the buffer's statistics. + */ +static int meth_getstats(lua_State *L) { + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + return buffer_meth_getstats(L, &ssl->buf); +} + +/** + * Set the buffer's statistics. + */ +static int meth_setstats(lua_State *L) { + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + return buffer_meth_setstats(L, &ssl->buf); +} + +/** + * Select support methods + */ +static int meth_getfd(lua_State *L) +{ + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + lua_pushnumber(L, ssl->sock); + return 1; +} + +/** + * Set the TLS/SSL file descriptor. + * Call it *before* the handshake. + */ +static int meth_setfd(lua_State *L) +{ + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + if (ssl->state != LSEC_STATE_NEW) + luaL_argerror(L, 1, "invalid SSL object state"); + ssl->sock = (t_socket)luaL_checkinteger(L, 2); + socket_setnonblocking(&ssl->sock); + SSL_set_fd(ssl->ssl, (int)ssl->sock); + return 0; +} + +/** + * Lua handshake function. + */ +static int meth_handshake(lua_State *L) +{ + int err; + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + p_context ctx = (p_context)SSL_CTX_get_app_data(SSL_get_SSL_CTX(ssl->ssl)); + ctx->L = L; + err = handshake(ssl); + if (ctx->dh_param) { + DH_free(ctx->dh_param); + ctx->dh_param = NULL; + } + if (ctx->alpn) { + free(ctx->alpn); + ctx->alpn = NULL; + } + if (err == IO_DONE) { + lua_pushboolean(L, 1); + return 1; + } + lua_pushboolean(L, 0); + lua_pushstring(L, ssl_ioerror((void*)ssl, err)); + return 2; +} + +/** + * Close the connection. + */ +static int meth_close(lua_State *L) +{ + meth_destroy(L); + return 0; +} + +/** + * Set timeout. + */ +static int meth_settimeout(lua_State *L) +{ + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + return timeout_meth_settimeout(L, &ssl->tm); +} + +/** + * Check if there is data in the buffer. + */ +static int meth_dirty(lua_State *L) +{ + int res = 0; + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + if (ssl->state != LSEC_STATE_CLOSED) + res = !buffer_isempty(&ssl->buf) || SSL_pending(ssl->ssl); + lua_pushboolean(L, res); + return 1; +} + +/** + * Return the state information about the SSL object. + */ +static int meth_want(lua_State *L) +{ + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + int code = (ssl->state == LSEC_STATE_CLOSED) + ? SSL_NOTHING + : SSL_want(ssl->ssl); + switch(code) { + case SSL_NOTHING: lua_pushstring(L, "nothing"); break; + case SSL_READING: lua_pushstring(L, "read"); break; + case SSL_WRITING: lua_pushstring(L, "write"); break; + case SSL_X509_LOOKUP: lua_pushstring(L, "x509lookup"); break; + } + return 1; +} + +/** + * Return the compression method used. + */ +static int meth_compression(lua_State *L) +{ +#ifdef OPENSSL_NO_COMP + const void *comp; +#else + const COMP_METHOD *comp; +#endif + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + if (ssl->state != LSEC_STATE_CONNECTED) { + lua_pushnil(L); + lua_pushstring(L, "closed"); + return 2; + } + comp = SSL_get_current_compression(ssl->ssl); + if (comp) + lua_pushstring(L, SSL_COMP_get_name(comp)); + else + lua_pushnil(L); + return 1; +} + +/** + * Return the nth certificate of the peer's chain. + */ +static int meth_getpeercertificate(lua_State *L) +{ + int n; + X509 *cert; + STACK_OF(X509) *certs; + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + if (ssl->state != LSEC_STATE_CONNECTED) { + lua_pushnil(L); + lua_pushstring(L, "closed"); + return 2; + } + /* Default to the first cert */ + n = (int)luaL_optinteger(L, 2, 1); + /* This function is 1-based, but OpenSSL is 0-based */ + --n; + if (n < 0) { + lua_pushnil(L); + lua_pushliteral(L, "invalid certificate index"); + return 2; + } + if (n == 0) { + cert = SSL_get_peer_certificate(ssl->ssl); + if (cert) + lsec_pushx509(L, cert); + else + lua_pushnil(L); + return 1; + } + /* In a server-context, the stack doesn't contain the peer cert, + * so adjust accordingly. + */ + if (SSL_is_server(ssl->ssl)) + --n; + certs = SSL_get_peer_cert_chain(ssl->ssl); + if (n >= sk_X509_num(certs)) { + lua_pushnil(L); + return 1; + } + cert = sk_X509_value(certs, n); + /* Increment the reference counting of the object. */ + /* See SSL_get_peer_certificate() source code. */ + X509_up_ref(cert); + lsec_pushx509(L, cert); + return 1; +} + +/** + * Return the chain of certificate of the peer. + */ +static int meth_getpeerchain(lua_State *L) +{ + int i; + int idx = 1; + int n_certs; + X509 *cert; + STACK_OF(X509) *certs; + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + if (ssl->state != LSEC_STATE_CONNECTED) { + lua_pushnil(L); + lua_pushstring(L, "closed"); + return 2; + } + lua_newtable(L); + if (SSL_is_server(ssl->ssl)) { + lsec_pushx509(L, SSL_get_peer_certificate(ssl->ssl)); + lua_rawseti(L, -2, idx++); + } + certs = SSL_get_peer_cert_chain(ssl->ssl); + n_certs = sk_X509_num(certs); + for (i = 0; i < n_certs; i++) { + cert = sk_X509_value(certs, i); + /* Increment the reference counting of the object. */ + /* See SSL_get_peer_certificate() source code. */ + X509_up_ref(cert); + lsec_pushx509(L, cert); + lua_rawseti(L, -2, idx++); + } + return 1; +} + +/** + * Copy the table src to the table dst. + */ +static void copy_error_table(lua_State *L, int src, int dst) +{ + lua_pushnil(L); + while (lua_next(L, src) != 0) { + if (lua_istable(L, -1)) { + /* Replace the table with its copy */ + lua_newtable(L); + copy_error_table(L, dst+2, dst+3); + lua_remove(L, dst+2); + } + lua_pushvalue(L, -2); + lua_pushvalue(L, -2); + lua_rawset(L, dst); + /* Remove the value and leave the key */ + lua_pop(L, 1); + } +} + +/** + * Return the verification state of the peer chain. + */ +static int meth_getpeerverification(lua_State *L) +{ + long err; + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + if (ssl->state != LSEC_STATE_CONNECTED) { + lua_pushboolean(L, 0); + lua_pushstring(L, "closed"); + return 2; + } + err = SSL_get_verify_result(ssl->ssl); + if (err == X509_V_OK) { + lua_pushboolean(L, 1); + return 1; + } + luaL_getmetatable(L, "SSL:Verify:Registry"); + lua_pushlightuserdata(L, (void*)ssl->ssl); + lua_gettable(L, -2); + if (lua_isnil(L, -1)) + lua_pushstring(L, X509_verify_cert_error_string(err)); + else { + /* Copy the table of errors to avoid modifications */ + lua_newtable(L); + copy_error_table(L, lua_gettop(L)-1, lua_gettop(L)); + } + lua_pushboolean(L, 0); + lua_pushvalue(L, -2); + return 2; +} + +/** + * Get the latest "Finished" message sent out. + */ +static int meth_getfinished(lua_State *L) +{ + size_t len = 0; + char *buffer = NULL; + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + if (ssl->state != LSEC_STATE_CONNECTED) { + lua_pushnil(L); + lua_pushstring(L, "closed"); + return 2; + } + if ((len = SSL_get_finished(ssl->ssl, NULL, 0)) == 0) + return 0; + buffer = (char*)malloc(len); + if (!buffer) { + lua_pushnil(L); + lua_pushstring(L, "out of memory"); + return 2; + } + SSL_get_finished(ssl->ssl, buffer, len); + lua_pushlstring(L, buffer, len); + free(buffer); + return 1; +} + +/** + * Gets the latest "Finished" message received. + */ +static int meth_getpeerfinished(lua_State *L) +{ + size_t len = 0; + char *buffer = NULL; + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + if (ssl->state != LSEC_STATE_CONNECTED) { + lua_pushnil(L); + lua_pushstring(L, "closed"); + return 0; + } + if ((len = SSL_get_peer_finished(ssl->ssl, NULL, 0)) == 0) + return 0; + buffer = (char*)malloc(len); + if (!buffer) { + lua_pushnil(L); + lua_pushstring(L, "out of memory"); + return 2; + } + SSL_get_peer_finished(ssl->ssl, buffer, len); + lua_pushlstring(L, buffer, len); + free(buffer); + return 1; +} + +/** + * Object information -- tostring metamethod + */ +static int meth_tostring(lua_State *L) +{ + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + lua_pushfstring(L, "SSL connection: %p%s", ssl, + ssl->state == LSEC_STATE_CLOSED ? " (closed)" : ""); + return 1; +} + +/** + * Add a method in the SSL metatable. + */ +static int meth_setmethod(lua_State *L) +{ + luaL_getmetatable(L, "SSL:Connection"); + lua_pushstring(L, "__index"); + lua_gettable(L, -2); + lua_pushvalue(L, 1); + lua_pushvalue(L, 2); + lua_settable(L, -3); + return 0; +} + +/** + * Return information about the connection. + */ +static int meth_info(lua_State *L) +{ + int bits = 0; + int algbits = 0; + char buf[256] = {0}; + const SSL_CIPHER *cipher; + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + cipher = SSL_get_current_cipher(ssl->ssl); + if (!cipher) + return 0; + SSL_CIPHER_description(cipher, buf, sizeof(buf)); + bits = SSL_CIPHER_get_bits(cipher, &algbits); + lua_pushstring(L, buf); + lua_pushnumber(L, bits); + lua_pushnumber(L, algbits); + lua_pushstring(L, SSL_get_version(ssl->ssl)); + return 4; +} + +static int sni_cb(SSL *ssl, int *ad, void *arg) +{ + int strict; + SSL_CTX *newctx = NULL; + SSL_CTX *ctx = SSL_get_SSL_CTX(ssl); + lua_State *L = ((p_context)SSL_CTX_get_app_data(ctx))->L; + const char *name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); + /* No name, use default context */ + if (!name) + return SSL_TLSEXT_ERR_NOACK; + /* Retrieve struct from registry */ + luaL_getmetatable(L, "SSL:SNI:Registry"); + lua_pushlightuserdata(L, (void*)ssl); + lua_gettable(L, -2); + /* Strict search? */ + lua_pushstring(L, "strict"); + lua_gettable(L, -2); + strict = lua_toboolean(L, -1); + lua_pop(L, 1); + /* Search for the name in the map */ + lua_pushstring(L, "map"); + lua_gettable(L, -2); + lua_pushstring(L, name); + lua_gettable(L, -2); + if (lua_isuserdata(L, -1)) + newctx = lsec_checkcontext(L, -1); + lua_pop(L, 4); + /* Found, use this context */ + if (newctx) { + p_context pctx = (p_context)SSL_CTX_get_app_data(newctx); + pctx->L = L; + SSL_set_SSL_CTX(ssl, newctx); + return SSL_TLSEXT_ERR_OK; + } + /* Not found, but use initial context */ + if (!strict) + return SSL_TLSEXT_ERR_OK; + return SSL_TLSEXT_ERR_ALERT_FATAL; +} + +static int meth_sni(lua_State *L) +{ + int strict; + SSL_CTX *aux; + const char *name; + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + SSL_CTX *ctx = SSL_get_SSL_CTX(ssl->ssl); + p_context pctx = (p_context)SSL_CTX_get_app_data(ctx); + if (pctx->mode == LSEC_MODE_CLIENT) { + name = luaL_checkstring(L, 2); + SSL_set_tlsext_host_name(ssl->ssl, name); + return 0; + } else if (pctx->mode == LSEC_MODE_SERVER) { + luaL_checktype(L, 2, LUA_TTABLE); + strict = lua_toboolean(L, 3); + /* Check if the table contains only (string -> context) */ + lua_pushnil(L); + while (lua_next(L, 2)) { + luaL_checkstring(L, -2); + aux = lsec_checkcontext(L, -1); + /* Set callback in every context */ + SSL_CTX_set_tlsext_servername_callback(aux, sni_cb); + /* leave the next key on the stack */ + lua_pop(L, 1); + } + /* Save table in the register */ + luaL_getmetatable(L, "SSL:SNI:Registry"); + lua_pushlightuserdata(L, (void*)ssl->ssl); + lua_newtable(L); + lua_pushstring(L, "map"); + lua_pushvalue(L, 2); + lua_settable(L, -3); + lua_pushstring(L, "strict"); + lua_pushboolean(L, strict); + lua_settable(L, -3); + lua_settable(L, -3); + /* Set callback in the default context */ + SSL_CTX_set_tlsext_servername_callback(ctx, sni_cb); + } + return 0; +} + +static int meth_getsniname(lua_State *L) +{ + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + const char *name = SSL_get_servername(ssl->ssl, TLSEXT_NAMETYPE_host_name); + if (name) + lua_pushstring(L, name); + else + lua_pushnil(L); + return 1; +} + +static int meth_getalpn(lua_State *L) +{ + unsigned len; + const unsigned char *data; + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + SSL_get0_alpn_selected(ssl->ssl, &data, &len); + if (data == NULL && len == 0) + lua_pushnil(L); + else + lua_pushlstring(L, (const char*)data, len); + return 1; +} + +static int meth_copyright(lua_State *L) +{ + lua_pushstring(L, "LuaSec 1.0.2 - Copyright (C) 2006-2021 Bruno Silvestre, UFG" +#if defined(WITH_LUASOCKET) + "\nLuaSocket 3.0-RC1 - Copyright (C) 2004-2013 Diego Nehab" +#endif + ); + return 1; +} + +#if defined(LSEC_ENABLE_DANE) +static int meth_dane(lua_State *L) +{ + int ret; + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + ret = SSL_dane_enable(ssl->ssl, luaL_checkstring(L, 2)); + lua_pushboolean(L, (ret > 0)); + return 1; +} + +static int meth_tlsa(lua_State *L) +{ + int ret; + size_t len; + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + uint8_t usage = (uint8_t)luaL_checkinteger(L, 2); + uint8_t selector = (uint8_t)luaL_checkinteger(L, 3); + uint8_t mtype = (uint8_t)luaL_checkinteger(L, 4); + unsigned char *data = (unsigned char*)luaL_checklstring(L, 5, &len); + + ERR_clear_error(); + ret = SSL_dane_tlsa_add(ssl->ssl, usage, selector, mtype, data, len); + lua_pushboolean(L, (ret > 0)); + + return 1; +} +#endif + +/*---------------------------------------------------------------------------*/ + +/** + * SSL methods + */ +static luaL_Reg methods[] = { + {"close", meth_close}, + {"getalpn", meth_getalpn}, + {"getfd", meth_getfd}, + {"getfinished", meth_getfinished}, + {"getpeercertificate", meth_getpeercertificate}, + {"getpeerchain", meth_getpeerchain}, + {"getpeerverification", meth_getpeerverification}, + {"getpeerfinished", meth_getpeerfinished}, + {"getsniname", meth_getsniname}, + {"getstats", meth_getstats}, + {"setstats", meth_setstats}, + {"dirty", meth_dirty}, + {"dohandshake", meth_handshake}, + {"receive", meth_receive}, + {"send", meth_send}, + {"settimeout", meth_settimeout}, + {"sni", meth_sni}, + {"want", meth_want}, +#if defined(LSEC_ENABLE_DANE) + {"setdane", meth_dane}, + {"settlsa", meth_tlsa}, +#endif + {NULL, NULL} +}; + +/** + * SSL metamethods. + */ +static luaL_Reg meta[] = { + {"__close", meth_destroy}, + {"__gc", meth_destroy}, + {"__tostring", meth_tostring}, + {NULL, NULL} +}; + +/** + * SSL functions. + */ +static luaL_Reg funcs[] = { + {"compression", meth_compression}, + {"create", meth_create}, + {"info", meth_info}, + {"setfd", meth_setfd}, + {"setmethod", meth_setmethod}, + {"copyright", meth_copyright}, + {NULL, NULL} +}; + +/** + * Initialize modules. + */ +LSEC_API int luaopen_ssl_core(lua_State *L) +{ +#ifndef LSEC_API_OPENSSL_1_1_0 + /* Initialize SSL */ + if (!SSL_library_init()) { + lua_pushstring(L, "unable to initialize SSL library"); + lua_error(L); + } + OpenSSL_add_all_algorithms(); + SSL_load_error_strings(); +#endif + +#if defined(WITH_LUASOCKET) + /* Initialize internal library */ + socket_open(); +#endif + + luaL_newmetatable(L, "SSL:SNI:Registry"); + + /* Register the functions and tables */ + luaL_newmetatable(L, "SSL:Connection"); + setfuncs(L, meta); + + luaL_newlib(L, methods); + lua_setfield(L, -2, "__index"); + + luaL_newlib(L, funcs); + lua_pushvalue(L, -1); + lua_setglobal(L, "ssl"); + + lua_pushstring(L, "SOCKET_INVALID"); + lua_pushinteger(L, SOCKET_INVALID); + lua_rawset(L, -3); + + return 1; +} + +//------------------------------------------------------------------------------ + +#if defined(_MSC_VER) + +/* Empty implementation to allow building with LuaRocks and MS compilers */ +LSEC_API int luaopen_ssl(lua_State *L) { + lua_pushstring(L, "you should not call this function"); + lua_error(L); + return 0; +} + +#endif diff --git a/src/luasec/ssl.h b/src/luasec/ssl.h new file mode 100644 index 0000000..61bd807 --- /dev/null +++ b/src/luasec/ssl.h @@ -0,0 +1,41 @@ +#ifndef LSEC_SSL_H +#define LSEC_SSL_H + +/*-------------------------------------------------------------------------- + * LuaSec 1.0.2 + * + * Copyright (C) 2006-2021 Bruno Silvestre + * + *--------------------------------------------------------------------------*/ + +#include <openssl/ssl.h> +#include "../lua.h" + +#include "luasocket/io.h" +#include "luasocket/buffer.h" +#include "luasocket/timeout.h" +#include "luasocket/socket.h" + +#include "compat.h" +#include "context.h" + +#define LSEC_STATE_NEW 1 +#define LSEC_STATE_CONNECTED 2 +#define LSEC_STATE_CLOSED 3 + +#define LSEC_IO_SSL -100 + +typedef struct t_ssl_ { + t_socket sock; + t_io io; + t_buffer buf; + t_timeout tm; + SSL *ssl; + int state; + int error; +} t_ssl; +typedef t_ssl* p_ssl; + +LSEC_API int luaopen_ssl_core(lua_State *L); + +#endif diff --git a/src/luasec/ssl.lua b/src/luasec/ssl.lua new file mode 100644 index 0000000..137b07e --- /dev/null +++ b/src/luasec/ssl.lua @@ -0,0 +1,277 @@ +------------------------------------------------------------------------------ +-- LuaSec 1.0.2 +-- +-- Copyright (C) 2006-2021 Bruno Silvestre +-- +------------------------------------------------------------------------------ + +local core = ssl + +-- We must prevent the contexts to be collected before the connections, +-- otherwise the C registry will be cleared. +local registry = setmetatable({}, {__mode="k"}) + +-- +-- +-- +local function optexec(func, param, ctx) + if param then + if type(param) == "table" then + return func(ctx, unpack(param)) + else + return func(ctx, param) + end + end + return true +end + +-- +-- Convert an array of strings to wire-format +-- +local function array2wireformat(array) + local str = "" + for k, v in ipairs(array) do + if type(v) ~= "string" then return nil end + local len = #v + if len == 0 then + return nil, "invalid ALPN name (empty string)" + elseif len > 255 then + return nil, "invalid ALPN name (length > 255)" + end + str = str .. string.char(len) .. v + end + if str == "" then return nil, "invalid ALPN list (empty)" end + return str +end + +-- +-- Convert wire-string format to array +-- +local function wireformat2array(str) + local i = 1 + local array = {} + while i < #str do + local len = str:byte(i) + array[#array + 1] = str:sub(i + 1, i + len) + i = i + len + 1 + end + return array +end + +-- +-- +-- +local function newcontext(cfg) + local succ, msg, ctx + -- Create the context + ctx, msg = context.create(cfg.protocol) + if not ctx then return nil, msg end + -- Mode + succ, msg = context.setmode(ctx, cfg.mode) + if not succ then return nil, msg end + local certificates = cfg.certificates + if not certificates then + certificates = { + { certificate = cfg.certificate, key = cfg.key, password = cfg.password } + } + end + for _, certificate in ipairs(certificates) do + -- Load the key + if certificate.key then + if certificate.password and + type(certificate.password) ~= "function" and + type(certificate.password) ~= "string" + then + return nil, "invalid password type" + end + succ, msg = context.loadkey(ctx, certificate.key, certificate.password) + if not succ then return nil, msg end + end + -- Load the certificate(s) + if certificate.certificate then + succ, msg = context.loadcert(ctx, certificate.certificate) + if not succ then return nil, msg end + if certificate.key and context.checkkey then + succ = context.checkkey(ctx) + if not succ then return nil, "private key does not match public key" end + end + end + end + -- Load the CA certificates + if cfg.cafile or cfg.capath then + succ, msg = context.locations(ctx, cfg.cafile, cfg.capath) + if not succ then return nil, msg end + end + -- Set SSL ciphers + if cfg.ciphers then + succ, msg = context.setcipher(ctx, cfg.ciphers) + if not succ then return nil, msg end + end + -- Set SSL cipher suites + if cfg.ciphersuites then + succ, msg = context.setciphersuites(ctx, cfg.ciphersuites) + if not succ then return nil, msg end + end + -- Set the verification options + succ, msg = optexec(context.setverify, cfg.verify, ctx) + if not succ then return nil, msg end + -- Set SSL options + succ, msg = optexec(context.setoptions, cfg.options, ctx) + if not succ then return nil, msg end + -- Set the depth for certificate verification + if cfg.depth then + succ, msg = context.setdepth(ctx, cfg.depth) + if not succ then return nil, msg end + end + + -- NOTE: Setting DH parameters and elliptic curves needs to come after + -- setoptions(), in case the user has specified the single_{dh,ecdh}_use + -- options. + + -- Set DH parameters + if cfg.dhparam then + if type(cfg.dhparam) ~= "function" then + return nil, "invalid DH parameter type" + end + context.setdhparam(ctx, cfg.dhparam) + end + + -- Set elliptic curves + if (not config.algorithms.ec) and (cfg.curve or cfg.curveslist) then + return false, "elliptic curves not supported" + end + if config.capabilities.curves_list and cfg.curveslist then + succ, msg = context.setcurveslist(ctx, cfg.curveslist) + if not succ then return nil, msg end + elseif cfg.curve then + succ, msg = context.setcurve(ctx, cfg.curve) + if not succ then return nil, msg end + end + + -- Set extra verification options + if cfg.verifyext and ctx.setverifyext then + succ, msg = optexec(ctx.setverifyext, cfg.verifyext, ctx) + if not succ then return nil, msg end + end + + -- ALPN + if cfg.mode == "server" and cfg.alpn then + if type(cfg.alpn) == "function" then + local alpncb = cfg.alpn + -- This callback function has to return one value only + succ, msg = context.setalpncb(ctx, function(str) + local protocols = alpncb(wireformat2array(str)) + if type(protocols) == "string" then + protocols = { protocols } + elseif type(protocols) ~= "table" then + return nil + end + return (array2wireformat(protocols)) -- use "()" to drop error message + end) + if not succ then return nil, msg end + elseif type(cfg.alpn) == "table" then + local protocols = cfg.alpn + -- check if array is valid before use it + succ, msg = array2wireformat(protocols) + if not succ then return nil, msg end + -- This callback function has to return one value only + succ, msg = context.setalpncb(ctx, function() + return (array2wireformat(protocols)) -- use "()" to drop error message + end) + if not succ then return nil, msg end + else + return nil, "invalid ALPN parameter" + end + elseif cfg.mode == "client" and cfg.alpn then + local alpn + if type(cfg.alpn) == "string" then + alpn, msg = array2wireformat({ cfg.alpn }) + elseif type(cfg.alpn) == "table" then + alpn, msg = array2wireformat(cfg.alpn) + else + return nil, "invalid ALPN parameter" + end + if not alpn then return nil, msg end + succ, msg = context.setalpn(ctx, alpn) + if not succ then return nil, msg end + end + + if config.capabilities.dane and cfg.dane then + context.setdane(ctx) + end + + return ctx +end + +-- +-- +-- +local function wrap(sock, cfg) + local ctx, msg + if type(cfg) == "table" then + ctx, msg = newcontext(cfg) + if not ctx then return nil, msg end + else + ctx = cfg + end + local s, msg = core.create(ctx) + if s then + core.setfd(s, sock:getfd()) + sock:setfd(core.SOCKET_INVALID) + registry[s] = ctx + return s + end + return nil, msg +end + +-- +-- Extract connection information. +-- +local function info(ssl, field) + local str, comp, err, protocol + comp, err = core.compression(ssl) + if err then + return comp, err + end + -- Avoid parser + if field == "compression" then + return comp + end + local info = {compression = comp} + str, info.bits, info.algbits, protocol = core.info(ssl) + if str then + info.cipher, info.protocol, info.key, + info.authentication, info.encryption, info.mac = + string.match(str, + "^(%S+)%s+(%S+)%s+Kx=(%S+)%s+Au=(%S+)%s+Enc=(%S+)%s+Mac=(%S+)") + info.export = (string.match(str, "%sexport%s*$") ~= nil) + end + if protocol then + info.protocol = protocol + end + if field then + return info[field] + end + -- Empty? + return ( (next(info)) and info ) +end + +-- +-- Set method for SSL connections. +-- +core.setmethod("info", info) + +-------------------------------------------------------------------------------- +-- Export module +-- + +local _M = { + _VERSION = "1.0.2", + _COPYRIGHT = core.copyright(), + config = config, + loadcertificate = x509.load, + newcontext = newcontext, + wrap = wrap, +} + +return _M diff --git a/src/luasec/x509.c b/src/luasec/x509.c new file mode 100644 index 0000000..84f22ee --- /dev/null +++ b/src/luasec/x509.c @@ -0,0 +1,748 @@ +/*-------------------------------------------------------------------------- + * LuaSec 1.0.2 + * + * Copyright (C) 2014-2021 Kim Alvefur, Paul Aurich, Tobias Markmann + * Matthew Wild, Bruno Silvestre. + * + *--------------------------------------------------------------------------*/ + +#include <stdio.h> +#include <string.h> + +#if defined(WIN32) +#include <ws2tcpip.h> +#include <windows.h> +#else +#include <sys/types.h> +#include <sys/socket.h> +#include <netinet/in.h> +#include <arpa/inet.h> +#endif + +#include <openssl/ssl.h> +#include <openssl/x509v3.h> +#include <openssl/evp.h> +#include <openssl/err.h> +#include <openssl/asn1.h> +#include <openssl/bio.h> +#include <openssl/bn.h> + +#include "../lua.h" +#include "../lauxlib.h" + +#include "x509.h" + + +#ifndef LSEC_API_OPENSSL_1_1_0 +#define X509_get0_notBefore X509_get_notBefore +#define X509_get0_notAfter X509_get_notAfter +#define ASN1_STRING_get0_data ASN1_STRING_data +#endif + +static const char* hex_tab = "0123456789abcdef"; + +/** + * Push the certificate on the stack. + */ +void lsec_pushx509(lua_State* L, X509 *cert) +{ + p_x509 cert_obj = (p_x509)lua_newuserdata(L, sizeof(t_x509)); + cert_obj->cert = cert; + cert_obj->encode = LSEC_AI5_STRING; + luaL_getmetatable(L, "SSL:Certificate"); + lua_setmetatable(L, -2); +} + +/** + * Return the OpenSSL certificate X509. + */ +X509* lsec_checkx509(lua_State* L, int idx) +{ + return ((p_x509)luaL_checkudata(L, idx, "SSL:Certificate"))->cert; +} + +/** + * Return LuaSec certificate X509 representation. + */ +p_x509 lsec_checkp_x509(lua_State* L, int idx) +{ + return (p_x509)luaL_checkudata(L, idx, "SSL:Certificate"); +} + +/*---------------------------------------------------------------------------*/ + +#if defined(LUASEC_INET_NTOP) +/* + * For WinXP (SP3), set the following preprocessor macros: + * LUASEC_INET_NTOP + * WINVER=0x0501 + * _WIN32_WINNT=0x0501 + * NTDDI_VERSION=0x05010300 + * + * For IPv6 addresses, you need to add IPv6 Protocol to your interface. + * + */ +static const char *inet_ntop(int af, const char *src, char *dst, socklen_t size) +{ + int addrsize; + struct sockaddr *addr; + struct sockaddr_in addr4; + struct sockaddr_in6 addr6; + + switch (af) { + case AF_INET: + memset((void*)&addr4, 0, sizeof(addr4)); + addr4.sin_family = AF_INET; + memcpy((void*)&addr4.sin_addr, src, sizeof(struct in_addr)); + addr = (struct sockaddr*)&addr4; + addrsize = sizeof(struct sockaddr_in); + break; + case AF_INET6: + memset((void*)&addr6, 0, sizeof(addr6)); + addr6.sin6_family = AF_INET6; + memcpy((void*)&addr6.sin6_addr, src, sizeof(struct in6_addr)); + addr = (struct sockaddr*)&addr6; + addrsize = sizeof(struct sockaddr_in6); + break; + default: + return NULL; + } + + if(getnameinfo(addr, addrsize, dst, size, NULL, 0, NI_NUMERICHOST) != 0) + return NULL; + return dst; +} +#endif + +/*---------------------------------------------------------------------------*/ + +/** + * Convert the buffer 'in' to hexadecimal. + */ +static void to_hex(const char* in, int length, char* out) +{ + int i; + for (i = 0; i < length; i++) { + out[i*2] = hex_tab[(in[i] >> 4) & 0xF]; + out[i*2+1] = hex_tab[(in[i]) & 0xF]; + } +} + +/** + * Converts the ASN1_OBJECT into a textual representation and put it + * on the Lua stack. + */ +static void push_asn1_objname(lua_State* L, ASN1_OBJECT *object, int no_name) +{ + char buffer[256]; + int len = OBJ_obj2txt(buffer, sizeof(buffer), object, no_name); + len = (len < sizeof(buffer)) ? len : sizeof(buffer); + lua_pushlstring(L, buffer, len); +} + +/** + * Push the ASN1 string on the stack. + */ +static void push_asn1_string(lua_State* L, ASN1_STRING *string, int encode) +{ + int len; + unsigned char *data; + if (!string) { + lua_pushnil(L); + return; + } + switch (encode) { + case LSEC_AI5_STRING: + lua_pushlstring(L, (char*)ASN1_STRING_get0_data(string), ASN1_STRING_length(string)); + break; + case LSEC_UTF8_STRING: + len = ASN1_STRING_to_UTF8(&data, string); + if (len >= 0) { + lua_pushlstring(L, (char*)data, len); + OPENSSL_free(data); + } + else + lua_pushnil(L); + } +} + +/** + * Return a human readable time. + */ +static int push_asn1_time(lua_State *L, const ASN1_UTCTIME *tm) +{ + char *tmp; + long size; + BIO *out = BIO_new(BIO_s_mem()); + ASN1_TIME_print(out, tm); + size = BIO_get_mem_data(out, &tmp); + lua_pushlstring(L, tmp, size); + BIO_free(out); + return 1; +} + +/** + * Return a human readable IP address. + */ +static void push_asn1_ip(lua_State *L, ASN1_STRING *string) +{ + int af; + char dst[INET6_ADDRSTRLEN]; + unsigned char *ip = (unsigned char*)ASN1_STRING_get0_data(string); + switch(ASN1_STRING_length(string)) { + case 4: + af = AF_INET; + break; + case 16: + af = AF_INET6; + break; + default: + lua_pushnil(L); + return; + } + if(inet_ntop(af, ip, dst, INET6_ADDRSTRLEN)) + lua_pushstring(L, dst); + else + lua_pushnil(L); +} + +/** + * + */ +static int push_subtable(lua_State* L, int idx) +{ + lua_pushvalue(L, -1); + lua_gettable(L, idx-1); + if (lua_isnil(L, -1)) { + lua_pop(L, 1); + lua_newtable(L); + lua_pushvalue(L, -2); + lua_pushvalue(L, -2); + lua_settable(L, idx-3); + lua_replace(L, -2); /* Replace key with table */ + return 1; + } + lua_replace(L, -2); /* Replace key with table */ + return 0; +} + +/** + * Retrieve the general names from the object. + */ +static int push_x509_name(lua_State* L, X509_NAME *name, int encode) +{ + int i; + int n_entries; + ASN1_OBJECT *object; + X509_NAME_ENTRY *entry; + lua_newtable(L); + n_entries = X509_NAME_entry_count(name); + for (i = 0; i < n_entries; i++) { + entry = X509_NAME_get_entry(name, i); + object = X509_NAME_ENTRY_get_object(entry); + lua_newtable(L); + push_asn1_objname(L, object, 1); + lua_setfield(L, -2, "oid"); + push_asn1_objname(L, object, 0); + lua_setfield(L, -2, "name"); + push_asn1_string(L, X509_NAME_ENTRY_get_data(entry), encode); + lua_setfield(L, -2, "value"); + lua_rawseti(L, -2, i+1); + } + return 1; +} + +/*---------------------------------------------------------------------------*/ + +/** + * Retrieve the Subject from the certificate. + */ +static int meth_subject(lua_State* L) +{ + p_x509 px = lsec_checkp_x509(L, 1); + return push_x509_name(L, X509_get_subject_name(px->cert), px->encode); +} + +/** + * Retrieve the Issuer from the certificate. + */ +static int meth_issuer(lua_State* L) +{ + p_x509 px = lsec_checkp_x509(L, 1); + return push_x509_name(L, X509_get_issuer_name(px->cert), px->encode); +} + +/** + * Retrieve the extensions from the certificate. + */ +int meth_extensions(lua_State* L) +{ + int j; + int i = -1; + int n_general_names; + OTHERNAME *otherName; + X509_EXTENSION *extension; + GENERAL_NAME *general_name; + STACK_OF(GENERAL_NAME) *values; + p_x509 px = lsec_checkp_x509(L, 1); + X509 *peer = px->cert; + + /* Return (ret) */ + lua_newtable(L); + + while ((i = X509_get_ext_by_NID(peer, NID_subject_alt_name, i)) != -1) { + extension = X509_get_ext(peer, i); + if (extension == NULL) + break; + values = X509V3_EXT_d2i(extension); + if (values == NULL) + break; + + /* Push ret[oid] */ + push_asn1_objname(L, X509_EXTENSION_get_object(extension), 1); + push_subtable(L, -2); + + /* Set ret[oid].name = name */ + push_asn1_objname(L, X509_EXTENSION_get_object(extension), 0); + lua_setfield(L, -2, "name"); + + n_general_names = sk_GENERAL_NAME_num(values); + for (j = 0; j < n_general_names; j++) { + general_name = sk_GENERAL_NAME_value(values, j); + switch (general_name->type) { + case GEN_OTHERNAME: + otherName = general_name->d.otherName; + push_asn1_objname(L, otherName->type_id, 1); + if (push_subtable(L, -2)) { + push_asn1_objname(L, otherName->type_id, 0); + lua_setfield(L, -2, "name"); + } + push_asn1_string(L, otherName->value->value.asn1_string, px->encode); + lua_rawseti(L, -2, lua_rawlen(L, -2) + 1); + lua_pop(L, 1); + break; + case GEN_DNS: + lua_pushstring(L, "dNSName"); + push_subtable(L, -2); + push_asn1_string(L, general_name->d.dNSName, px->encode); + lua_rawseti(L, -2, lua_rawlen(L, -2) + 1); + lua_pop(L, 1); + break; + case GEN_EMAIL: + lua_pushstring(L, "rfc822Name"); + push_subtable(L, -2); + push_asn1_string(L, general_name->d.rfc822Name, px->encode); + lua_rawseti(L, -2, lua_rawlen(L, -2) + 1); + lua_pop(L, 1); + break; + case GEN_URI: + lua_pushstring(L, "uniformResourceIdentifier"); + push_subtable(L, -2); + push_asn1_string(L, general_name->d.uniformResourceIdentifier, px->encode); + lua_rawseti(L, -2, lua_rawlen(L, -2)+1); + lua_pop(L, 1); + break; + case GEN_IPADD: + lua_pushstring(L, "iPAddress"); + push_subtable(L, -2); + push_asn1_ip(L, general_name->d.iPAddress); + lua_rawseti(L, -2, lua_rawlen(L, -2)+1); + lua_pop(L, 1); + break; + case GEN_X400: + /* x400Address */ + /* not supported */ + break; + case GEN_DIRNAME: + /* directoryName */ + /* not supported */ + break; + case GEN_EDIPARTY: + /* ediPartyName */ + /* not supported */ + break; + case GEN_RID: + /* registeredID */ + /* not supported */ + break; + } + GENERAL_NAME_free(general_name); + } + sk_GENERAL_NAME_free(values); + lua_pop(L, 1); /* ret[oid] */ + i++; /* Next extension */ + } + return 1; +} + +/** + * Convert the certificate to PEM format. + */ +static int meth_pem(lua_State* L) +{ + char* data; + long bytes; + X509* cert = lsec_checkx509(L, 1); + BIO *bio = BIO_new(BIO_s_mem()); + if (!PEM_write_bio_X509(bio, cert)) { + lua_pushnil(L); + return 1; + } + bytes = BIO_get_mem_data(bio, &data); + if (bytes > 0) + lua_pushlstring(L, data, bytes); + else + lua_pushnil(L); + BIO_free(bio); + return 1; +} + +/** + * Extract public key in PEM format. + */ +static int meth_pubkey(lua_State* L) +{ + char* data; + long bytes; + int ret = 1; + X509* cert = lsec_checkx509(L, 1); + BIO *bio = BIO_new(BIO_s_mem()); + EVP_PKEY *pkey = X509_get_pubkey(cert); + if(PEM_write_bio_PUBKEY(bio, pkey)) { + bytes = BIO_get_mem_data(bio, &data); + if (bytes > 0) { + lua_pushlstring(L, data, bytes); + switch(EVP_PKEY_base_id(pkey)) { + case EVP_PKEY_RSA: + lua_pushstring(L, "RSA"); + break; + case EVP_PKEY_DSA: + lua_pushstring(L, "DSA"); + break; + case EVP_PKEY_DH: + lua_pushstring(L, "DH"); + break; + case EVP_PKEY_EC: + lua_pushstring(L, "EC"); + break; + default: + lua_pushstring(L, "Unknown"); + break; + } + lua_pushinteger(L, EVP_PKEY_bits(pkey)); + ret = 3; + } + else + lua_pushnil(L); + } + else + lua_pushnil(L); + /* Cleanup */ + BIO_free(bio); + EVP_PKEY_free(pkey); + return ret; +} + +/** + * Compute the fingerprint. + */ +static int meth_digest(lua_State* L) +{ + unsigned int bytes; + const EVP_MD *digest = NULL; + unsigned char buffer[EVP_MAX_MD_SIZE]; + char hex_buffer[EVP_MAX_MD_SIZE*2]; + X509 *cert = lsec_checkx509(L, 1); + const char *str = luaL_optstring(L, 2, NULL); + if (!str) + digest = EVP_sha1(); + else { + if (!strcmp(str, "sha1")) + digest = EVP_sha1(); + else if (!strcmp(str, "sha256")) + digest = EVP_sha256(); + else if (!strcmp(str, "sha512")) + digest = EVP_sha512(); + } + if (!digest) { + lua_pushnil(L); + lua_pushfstring(L, "digest algorithm not supported (%s)", str); + return 2; + } + if (!X509_digest(cert, digest, buffer, &bytes)) { + lua_pushnil(L); + lua_pushfstring(L, "error processing the certificate (%s)", + ERR_reason_error_string(ERR_get_error())); + return 2; + } + to_hex((char*)buffer, bytes, hex_buffer); + lua_pushlstring(L, hex_buffer, bytes*2); + return 1; +} + +/** + * Check if the certificate is valid in a given time. + */ +static int meth_valid_at(lua_State* L) +{ + int nb, na; + X509* cert = lsec_checkx509(L, 1); + time_t time = luaL_checkinteger(L, 2); + nb = X509_cmp_time(X509_get0_notBefore(cert), &time); + time -= 1; + na = X509_cmp_time(X509_get0_notAfter(cert), &time); + lua_pushboolean(L, nb == -1 && na == 1); + return 1; +} + +/** + * Return the serial number. + */ +static int meth_serial(lua_State *L) +{ + char *tmp; + BIGNUM *bn; + ASN1_INTEGER *serial; + X509* cert = lsec_checkx509(L, 1); + serial = X509_get_serialNumber(cert); + bn = ASN1_INTEGER_to_BN(serial, NULL); + tmp = BN_bn2hex(bn); + lua_pushstring(L, tmp); + BN_free(bn); + OPENSSL_free(tmp); + return 1; +} + +/** + * Return not before date. + */ +static int meth_notbefore(lua_State *L) +{ + X509* cert = lsec_checkx509(L, 1); + return push_asn1_time(L, X509_get0_notBefore(cert)); +} + +/** + * Return not after date. + */ +static int meth_notafter(lua_State *L) +{ + X509* cert = lsec_checkx509(L, 1); + return push_asn1_time(L, X509_get0_notAfter(cert)); +} + +/** + * Check if this certificate issued some other certificate + */ +static int meth_issued(lua_State* L) +{ + int ret, i, len; + + X509_STORE_CTX* ctx = NULL; + X509_STORE* root = NULL; + STACK_OF(X509)* chain = NULL; + + X509* issuer = lsec_checkx509(L, 1); + X509* subject = lsec_checkx509(L, 2); + X509* cert = NULL; + + len = lua_gettop(L); + + /* Check that all arguments are certificates */ + + for (i = 3; i <= len; i++) { + lsec_checkx509(L, i); + } + + /* Before allocating things that require freeing afterwards */ + + chain = sk_X509_new_null(); + ctx = X509_STORE_CTX_new(); + root = X509_STORE_new(); + + if (ctx == NULL || root == NULL) { + lua_pushnil(L); + lua_pushstring(L, "X509_STORE_new() or X509_STORE_CTX_new() error"); + ret = 2; + goto cleanup; + } + + ret = X509_STORE_add_cert(root, issuer); + + if(!ret) { + lua_pushnil(L); + lua_pushstring(L, "X509_STORE_add_cert() error"); + ret = 2; + goto cleanup; + } + + for (i = 3; i <= len && lua_isuserdata(L, i); i++) { + cert = lsec_checkx509(L, i); + sk_X509_push(chain, cert); + } + + ret = X509_STORE_CTX_init(ctx, root, subject, chain); + + if(!ret) { + lua_pushnil(L); + lua_pushstring(L, "X509_STORE_CTX_init() error"); + ret = 2; + goto cleanup; + } + + /* Actual verification */ + if (X509_verify_cert(ctx) <= 0) { + ret = X509_STORE_CTX_get_error(ctx); + lua_pushnil(L); + lua_pushstring(L, X509_verify_cert_error_string(ret)); + ret = 2; + } else { + lua_pushboolean(L, 1); + ret = 1; + } + +cleanup: + + if (ctx != NULL) { + X509_STORE_CTX_free(ctx); + } + + if (chain != NULL) { + X509_STORE_free(root); + } + + sk_X509_free(chain); + + return ret; +} + +/** + * Collect X509 objects. + */ +static int meth_destroy(lua_State* L) +{ + p_x509 px = lsec_checkp_x509(L, 1); + if (px->cert) { + X509_free(px->cert); + px->cert = NULL; + } + return 0; +} + +static int meth_tostring(lua_State *L) +{ + X509* cert = lsec_checkx509(L, 1); + lua_pushfstring(L, "X509 certificate: %p", cert); + return 1; +} + +/** + * Set the encode for ASN.1 string. + */ +static int meth_set_encode(lua_State* L) +{ + int succ = 0; + p_x509 px = lsec_checkp_x509(L, 1); + const char *enc = luaL_checkstring(L, 2); + if (strncmp(enc, "ai5", 3) == 0) { + succ = 1; + px->encode = LSEC_AI5_STRING; + } else if (strncmp(enc, "utf8", 4) == 0) { + succ = 1; + px->encode = LSEC_UTF8_STRING; + } + lua_pushboolean(L, succ); + return 1; +} + +/** + * Get signature name. + */ +static int meth_get_signature_name(lua_State* L) +{ + p_x509 px = lsec_checkp_x509(L, 1); + int nid = X509_get_signature_nid(px->cert); + const char *name = OBJ_nid2sn(nid); + if (!name) + lua_pushnil(L); + else + lua_pushstring(L, name); + return 1; +} + +/*---------------------------------------------------------------------------*/ + +static int load_cert(lua_State* L) +{ + X509 *cert; + size_t bytes; + const char* data; + BIO *bio = BIO_new(BIO_s_mem()); + data = luaL_checklstring(L, 1, &bytes); + BIO_write(bio, data, bytes); + cert = PEM_read_bio_X509(bio, NULL, NULL, NULL); + if (cert) + lsec_pushx509(L, cert); + else + lua_pushnil(L); + BIO_free(bio); + return 1; +} + +/*---------------------------------------------------------------------------*/ + +/** + * Certificate methods. + */ +static luaL_Reg methods[] = { + {"digest", meth_digest}, + {"setencode", meth_set_encode}, + {"extensions", meth_extensions}, + {"getsignaturename", meth_get_signature_name}, + {"issuer", meth_issuer}, + {"notbefore", meth_notbefore}, + {"notafter", meth_notafter}, + {"issued", meth_issued}, + {"pem", meth_pem}, + {"pubkey", meth_pubkey}, + {"serial", meth_serial}, + {"subject", meth_subject}, + {"validat", meth_valid_at}, + {NULL, NULL} +}; + +/** + * X509 metamethods. + */ +static luaL_Reg meta[] = { + {"__close", meth_destroy}, + {"__gc", meth_destroy}, + {"__tostring", meth_tostring}, + {NULL, NULL} +}; + +/** + * X509 functions. + */ +static luaL_Reg funcs[] = { + {"load", load_cert}, + {NULL, NULL} +}; + +/*--------------------------------------------------------------------------*/ + +LSEC_API int luaopen_ssl_x509(lua_State *L) +{ + /* Register the functions and tables */ + luaL_newmetatable(L, "SSL:Certificate"); + setfuncs(L, meta); + + luaL_newlib(L, methods); + lua_setfield(L, -2, "__index"); + + luaL_newlib(L, funcs); + lua_pushvalue(L, -1); + lua_setglobal(L, "x509"); + + return 1; +} diff --git a/src/luasec/x509.h b/src/luasec/x509.h new file mode 100644 index 0000000..1109837 --- /dev/null +++ b/src/luasec/x509.h @@ -0,0 +1,31 @@ +/*-------------------------------------------------------------------------- + * LuaSec 1.0.2 + * + * Copyright (C) 2014-2021 Kim Alvefur, Paul Aurich, Tobias Markmann + * Matthew Wild, Bruno Silvestre. + * + *--------------------------------------------------------------------------*/ + +#ifndef LSEC_X509_H +#define LSEC_X509_H + +#include <openssl/x509v3.h> +#include "../lua.h" + +#include "compat.h" + +/* We do not support UniversalString nor BMPString as ASN.1 String types */ +enum { LSEC_AI5_STRING, LSEC_UTF8_STRING }; + +typedef struct t_x509_ { + X509 *cert; + int encode; +} t_x509; +typedef t_x509* p_x509; + +void lsec_pushx509(lua_State* L, X509* cert); +X509* lsec_checkx509(lua_State* L, int idx); + +LSEC_API int luaopen_ssl_x509(lua_State *L); + +#endif |