summary refs log blame commit diff stats
path: root/cache_test.go
blob: 00f0ddf7d17619c8bc59db0d536070b7bdb59e81 (plain) (tree)
1
2
3
4


            
                 



























                                                                                           
package main

import (
	"testing"
)

// Test_refreshCache checks that calling refreshCache() moves
// confObj.LastCache forward in time.
//
// The timestamp before the refresh is read under confObj.Mu's read lock,
// matching how the subtest reads the value afterwards.
// NOTE(review): assumes LastCache is a time.Time (it is compared with After).
func Test_refreshCache(t *testing.T) {
	initTestConf()
	confObj.Mu.RLock()
	prevtime := confObj.LastCache
	confObj.Mu.RUnlock()

	t.Run("Cache Time Check", func(t *testing.T) {
		refreshCache()
		confObj.Mu.RLock()
		newtime := confObj.LastCache
		confObj.Mu.RUnlock()

		// After is a strict comparison, so it already fails for equal
		// timestamps; an extra == check is redundant (and == on time.Time
		// also compares location and monotonic reading, so it is discouraged).
		if !newtime.After(prevtime) {
			t.Errorf("Cache time did not update, check refreshCache() logic: prev=%v new=%v", prevtime, newtime)
		}
	})
}

// Benchmark_refreshCache times repeated calls to refreshCache().
// Test configuration is initialised first and excluded from the measurement
// by resetting the benchmark timer.
func Benchmark_refreshCache(b *testing.B) {
	initTestConf()
	b.ResetTimer()

	for remaining := b.N; remaining > 0; remaining-- {
		refreshCache()
	}
}
ght: bold } /* Name.Exception */ .highlight .nf { color: #0066bb; font-weight: bold } /* Name.Function */ .highlight .nl { color: #336699; font-style: italic } /* Name.Label */ .highlight .nn { color: #bb0066; font-weight: bold } /* Name.Namespace */ .highlight .py { color: #336699; font-weight: bold } /* Name.Property */ .highlight .nt { color: #bb0066; font-weight: bold } /* Name.Tag */ .highlight .nv { color: #336699 } /* Name.Variable */ .highlight .ow { color: #008800 } /* Operator.Word */ .highlight .w { color: #bbbbbb } /* Text.Whitespace */ .highlight .mb { color: #0000DD; font-weight: bold } /* Literal.Number.Bin */ .highlight .mf { color: #0000DD; font-weight: bold } /* Literal.Number.Float */ .highlight .mh { color: #0000DD; font-weight: bold } /* Literal.Number.Hex */ .highlight .mi { color: #0000DD; font-weight: bold } /* Literal.Number.Integer */ .highlight .mo { color: #0000DD; font-weight: bold } /* Literal.Number.Oct */ .highlight .sa { color: #dd2200; background-color: #fff0f0 } /* Literal.String.Affix */ .highlight .sb { color: #dd2200; background-color: #fff0f0 } /* Literal.String.Backtick */ .highlight .sc { color: #dd2200; background-color: #fff0f0 } /* Literal.String.Char */ .highlight .dl { color: #dd2200; background-color: #fff0f0 } /* Literal.String.Delimiter */ .highlight .sd { color: #dd2200; background-color: #fff0f0 } /* Literal.String.Doc */ .highlight .s2 { color: #dd2200; background-color: #fff0f0 } /* Literal.String.Double */ .highlight .se { color: #0044dd; background-color: #fff0f0 } /* Literal.String.Escape */ .highlight .sh { color: #dd2200; background-color: #fff0f0 } /* Literal.String.Heredoc */ .highlight .si { color: #3333bb; background-color: #fff0f0 } /* Literal.String.Interpol */ .highlight .sx { color: #22bb22; background-color: #f0fff0 } /* Literal.String.Other */ .highlight .sr { color: #008800; background-color: #fff0ff } /* Literal.String.Regex */ .highlight .s1 { color: #dd2200; background-color: #fff0f0 } /* 
Literal.String.Single */ .highlight .ss { color: #aa6600; background-color: #fff0f0 } /* Literal.String.Symbol */ .highlight .bp { color: #003388 } /* Name.Builtin.Pseudo */ .highlight .fm { color: #0066bb; font-weight: bold } /* Name.Function.Magic */ .highlight .vc { color: #336699 } /* Name.Variable.Class */ .highlight .vg { color: #dd7700 } /* Name.Variable.Global */ .highlight .vi { color: #3333bb } /* Name.Variable.Instance */ .highlight .vm { color: #336699 } /* Name.Variable.Magic */ .highlight .il { color: #0000DD; font-weight: bold } /* Literal.Number.Integer.Long */
------------------------------------------------------------------------------
-- LuaSec 1.0.2
--
-- Copyright (C) 2006-2021 Bruno Silvestre
--
------------------------------------------------------------------------------

-- Handle to the C core of the binding (create, setfd, info, compression,
-- setmethod, copyright, SOCKET_INVALID are used below).
-- NOTE(review): this reads a global `ssl` rather than `require("ssl.core")`;
-- presumably the host environment preloads it — confirm. The code in this
-- file likewise refers to `context`, `config` and `x509` without declaring
-- them as locals, so they also resolve as globals.
local core    = ssl

-- We must prevent the contexts to be collected before the connections,
-- otherwise the C registry will be cleared.
-- wrap() stores registry[connection] = context; because the keys are weak
-- (__mode="k"), an entry only keeps its context alive while the connection
-- object itself is still reachable.
local registry = setmetatable({}, {__mode="k"})

--
-- Call `func(ctx, ...)` only when the optional setting `param` is present.
-- A table `param` is expanded into positional arguments; any other truthy
-- value is passed as a single argument.
-- Returns whatever `func` returns, or true when `param` is nil/false
-- (nothing to configure).
--
local function optexec(func, param, ctx)
  if not param then
    return true
  end
  if type(param) == "table" then
    -- The global `unpack` is deprecated in Lua 5.2 and missing from default
    -- 5.3+ builds (it lives in `table.unpack`); support both.
    return func(ctx, (table.unpack or unpack)(param))
  end
  return func(ctx, param)
end

--
-- Convert an array of protocol names to ALPN wire format: every name is
-- prefixed by one byte holding its length.
-- Returns the encoded string, or nil plus an error message when an entry is
-- not a string, is empty, is longer than 255 bytes, or the list is empty.
--
local function array2wireformat(array)
   local parts = {}
   for i, name in ipairs(array) do
      if type(name) ~= "string" then
         -- Previously returned a bare nil, leaving callers without a message
         return nil, "invalid ALPN name (not a string)"
      end
      local len = #name
      if len == 0 then
         return nil, "invalid ALPN name (empty string)"
      elseif len > 255 then
         -- The length prefix is a single byte
         return nil, "invalid ALPN name (length > 255)"
      end
      parts[i] = string.char(len) .. name
   end
   if #parts == 0 then return nil, "invalid ALPN list (empty)" end
   -- table.concat avoids quadratic repeated string concatenation
   return table.concat(parts)
end

--
-- Decode an ALPN wire-format string (a sequence of length-prefixed names)
-- into an array of names. A trailing entry whose declared length runs past
-- the end of the string is returned truncated.
--
local function wireformat2array(str)
   local names = {}
   local pos, total = 1, #str
   while pos < total do
      local size = str:byte(pos)
      names[#names + 1] = str:sub(pos + 1, pos + size)
      pos = pos + 1 + size
   end
   return names
end

--
-- Load every configured private key / certificate pair into `ctx`.
-- `cfg.certificates`, when given, is a list of tables with `certificate`,
-- `key` and `password` fields; otherwise the top-level `cfg.certificate`,
-- `cfg.key` and `cfg.password` describe a single pair.
-- Returns true on success, or nil plus an error message.
--
local function setupcertificates(ctx, cfg)
   local certificates = cfg.certificates or {
      { certificate = cfg.certificate, key = cfg.key, password = cfg.password }
   }
   for _, certificate in ipairs(certificates) do
      local succ, msg
      -- Load the key; the password must be a string or a function
      if certificate.key then
         local password = certificate.password
         if password and
            type(password) ~= "function" and
            type(password) ~= "string"
         then
            return nil, "invalid password type"
         end
         succ, msg = context.loadkey(ctx, certificate.key, password)
         if not succ then return nil, msg end
      end
      -- Load the certificate(s)
      if certificate.certificate then
         succ, msg = context.loadcert(ctx, certificate.certificate)
         if not succ then return nil, msg end
         -- Cross-check the pair only when the binding provides checkkey()
         if certificate.key and context.checkkey then
            succ = context.checkkey(ctx)
            if not succ then
               return nil, "private key does not match public key"
            end
         end
      end
   end
   return true
end

--
-- Configure ALPN on `ctx` from `cfg.alpn` according to `cfg.mode`.
-- Server mode: `cfg.alpn` is either a function or an array of names.
--   A function receives the decoded protocol list from the C callback
--   (presumably the names offered by the peer — confirm) and returns the
--   chosen name(s) as a string or array. An array is validated up front
--   and then re-encoded on each callback.
-- Client mode: `cfg.alpn` is a name or an array of names.
-- Any other mode leaves ALPN untouched.
-- Returns true on success, or nil plus an error message.
--
local function setupalpn(ctx, cfg)
   local succ, msg
   local alpn = cfg.alpn
   if cfg.mode == "server" then
      if type(alpn) == "function" then
         -- This callback function has to return one value only; the
         -- parentheses drop array2wireformat's error message.
         succ, msg = context.setalpncb(ctx, function(str)
            local protocols = alpn(wireformat2array(str))
            if type(protocols) == "string" then
               protocols = { protocols }
            elseif type(protocols) ~= "table" then
               return nil
            end
            return (array2wireformat(protocols))
         end)
      elseif type(alpn) == "table" then
         -- Validate the list now so an invalid one is rejected when the
         -- context is created.
         succ, msg = array2wireformat(alpn)
         if not succ then return nil, msg end
         -- This callback function has to return one value only
         succ, msg = context.setalpncb(ctx, function()
            return (array2wireformat(alpn))
         end)
      else
         return nil, "invalid ALPN parameter"
      end
      if not succ then return nil, msg end
   elseif cfg.mode == "client" then
      local wire
      if type(alpn) == "string" then
         wire, msg = array2wireformat({ alpn })
      elseif type(alpn) == "table" then
         wire, msg = array2wireformat(alpn)
      else
         return nil, "invalid ALPN parameter"
      end
      if not wire then return nil, msg end
      succ, msg = context.setalpn(ctx, wire)
      if not succ then return nil, msg end
   end
   return true
end

--
-- Build an SSL context from the configuration table `cfg`.
-- Settings are applied in a fixed order: protocol, mode, key/certificate
-- pairs, CA locations, ciphers, cipher suites, verify, options, depth,
-- DH parameters, elliptic curves, extra verification, ALPN, DANE.
-- Returns the context, or nil plus an error message on the first failure.
--
local function newcontext(cfg)
   local succ, msg, ctx
   -- Create the context
   ctx, msg = context.create(cfg.protocol)
   if not ctx then return nil, msg end
   -- Mode
   succ, msg = context.setmode(ctx, cfg.mode)
   if not succ then return nil, msg end
   -- Keys and certificates
   succ, msg = setupcertificates(ctx, cfg)
   if not succ then return nil, msg end
   -- Load the CA certificates
   if cfg.cafile or cfg.capath then
      succ, msg = context.locations(ctx, cfg.cafile, cfg.capath)
      if not succ then return nil, msg end
   end
   -- Set SSL ciphers
   if cfg.ciphers then
      succ, msg = context.setcipher(ctx, cfg.ciphers)
      if not succ then return nil, msg end
   end
   -- Set SSL cipher suites
   if cfg.ciphersuites then
      succ, msg = context.setciphersuites(ctx, cfg.ciphersuites)
      if not succ then return nil, msg end
   end
   -- Set the verification options
   succ, msg = optexec(context.setverify, cfg.verify, ctx)
   if not succ then return nil, msg end
   -- Set SSL options
   succ, msg = optexec(context.setoptions, cfg.options, ctx)
   if not succ then return nil, msg end
   -- Set the depth for certificate verification
   if cfg.depth then
      succ, msg = context.setdepth(ctx, cfg.depth)
      if not succ then return nil, msg end
   end

   -- NOTE: Setting DH parameters and elliptic curves needs to come after
   -- setoptions(), in case the user has specified the single_{dh,ecdh}_use
   -- options.

   -- Set DH parameters (must be supplied as a function)
   if cfg.dhparam then
      if type(cfg.dhparam) ~= "function" then
         return nil, "invalid DH parameter type"
      end
      context.setdhparam(ctx, cfg.dhparam)
   end

   -- Set elliptic curves
   if (not config.algorithms.ec) and (cfg.curve or cfg.curveslist) then
      -- Report as nil, msg like every other failure path (was false, msg)
      return nil, "elliptic curves not supported"
   end
   if config.capabilities.curves_list and cfg.curveslist then
      succ, msg = context.setcurveslist(ctx, cfg.curveslist)
      if not succ then return nil, msg end
   elseif cfg.curve then
      succ, msg = context.setcurve(ctx, cfg.curve)
      if not succ then return nil, msg end
   end

   -- Set extra verification options (only if the context exposes them)
   if cfg.verifyext and ctx.setverifyext then
      succ, msg = optexec(ctx.setverifyext, cfg.verifyext, ctx)
      if not succ then return nil, msg end
   end

   -- ALPN
   if cfg.alpn then
      succ, msg = setupalpn(ctx, cfg)
      if not succ then return nil, msg end
   end

   -- DANE, only when the build advertises the capability
   if config.capabilities.dane and cfg.dane then
      context.setdane(ctx)
   end

   return ctx
end

--
-- Wrap a socket in an SSL connection object.
-- `cfg` is either a configuration table (passed to newcontext) or an
-- already-created context used as-is.
-- On success, the socket's descriptor is handed to the new SSL object and
-- the socket's own descriptor is set to core.SOCKET_INVALID.
-- Returns the SSL object, or nil plus an error message.
--
local function wrap(sock, cfg)
   local ctx, msg
   if type(cfg) == "table" then
      ctx, msg = newcontext(cfg)
      if not ctx then return nil, msg end
   else
      ctx = cfg
   end
   -- Reuse `msg` instead of redeclaring it (the old `local s, msg` shadowed it)
   local conn
   conn, msg = core.create(ctx)
   if not conn then
      return nil, msg
   end
   core.setfd(conn, sock:getfd())
   sock:setfd(core.SOCKET_INVALID)
   -- Keep the context reachable for as long as the connection object is
   registry[conn] = ctx
   return conn
end

--
-- Extract connection information from an SSL object.
-- With `field` set, returns just that entry ("compression" is answered
-- without querying anything else); otherwise returns a table with
-- compression, bits, algbits, cipher, protocol, key, authentication,
-- encryption, mac and export, or nil if nothing could be collected.
-- If the compression query reports an error, its two results are returned.
--
local function info(ssl, field)
  local compression, err = core.compression(ssl)
  if err then
    return compression, err
  end
  -- Shortcut: no need to parse the cipher description
  if field == "compression" then
    return compression
  end

  local result = { compression = compression }
  local description, bits, algbits, protocol = core.info(ssl)
  result.bits, result.algbits = bits, algbits

  if description then
    result.cipher, result.protocol, result.key,
    result.authentication, result.encryption, result.mac =
        string.match(description, 
          "^(%S+)%s+(%S+)%s+Kx=(%S+)%s+Au=(%S+)%s+Enc=(%S+)%s+Mac=(%S+)")
    result.export = string.match(description, "%sexport%s*$") ~= nil
  end

  -- A protocol reported directly by the core wins over the parsed one
  if protocol then
    result.protocol = protocol
  end

  if field then
    return result[field]
  end

  if next(result) == nil then
    return nil
  end
  return result
end

--
-- Set method for SSL connections: register info() under the name "info"
-- via core.setmethod (presumably making it callable as conn:info([field])
-- on objects returned by wrap() — confirm against the C core).
--
core.setmethod("info", info)

--------------------------------------------------------------------------------
-- Export module
--
-- Public API table. All fields are evaluated once, when this file is loaded.
-- NOTE(review): no local `config` or `x509` is declared in this file, so
-- both resolve as globals here; loading fails if `x509` is not defined.
--

local _M = {
  _VERSION        = "1.0.2",
  _COPYRIGHT      = core.copyright(),
  config          = config,
  loadcertificate = x509.load,
  newcontext      = newcontext,
  wrap            = wrap,
}

return _M