summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--svc/db_test.go18
-rw-r--r--svc/init_test.go50
2 files changed, 27 insertions, 41 deletions
diff --git a/svc/db_test.go b/svc/db_test.go
index 8cd0b61..abf7626 100644
--- a/svc/db_test.go
+++ b/svc/db_test.go
@@ -9,16 +9,20 @@ import (
 
 func Test_pushpullDatabase(t *testing.T) {
 	initTestConf()
-	initDatabase()
+	initTestDB()
+
 	out, _, err := registry.GetTwtxt("https://gbmor.dev/twtxt.txt")
 	if err != nil {
 		t.Errorf("Couldn't set up test: %v\n", err)
 	}
+
 	statusmap, err := registry.ParseUserTwtxt(out, "gbmor", "https://gbmor.dev/twtxt.txt")
 	if err != nil {
 		t.Errorf("Couldn't set up test: %v\n", err)
 	}
+
 	twtxtCache.AddUser("gbmor", "https://gbmor.dev/twtxt.txt", "", net.ParseIP("127.0.0.1"), statusmap)
+
 	remoteRegistries.Mu.Lock()
 	remoteRegistries.List = append(remoteRegistries.List, "https://twtxt.tilde.institute/api/plain/users")
 	remoteRegistries.Mu.Unlock()
@@ -39,6 +43,7 @@ func Test_pushpullDatabase(t *testing.T) {
 
 	t.Run("Pulling from Database", func(t *testing.T) {
 		pullDB()
+
 		twtxtCache.Mu.RLock()
 		if _, ok := twtxtCache.Users["https://gbmor.dev/twtxt.txt"]; !ok {
 			t.Errorf("Missing user previously pushed to database\n")
@@ -49,10 +54,7 @@ func Test_pushpullDatabase(t *testing.T) {
 }
 func Benchmark_pushDatabase(b *testing.B) {
 	initTestConf()
-
-	if len(dbChan) < 1 {
-		initDatabase()
-	}
+	initTestDB()
 
 	if _, ok := twtxtCache.Users["https://gbmor.dev/twtxt.txt"]; !ok {
 		out, _, err := registry.GetTwtxt("https://gbmor.dev/twtxt.txt")
@@ -79,10 +81,8 @@ func Benchmark_pushDatabase(b *testing.B) {
 }
 func Benchmark_pullDatabase(b *testing.B) {
 	initTestConf()
-
-	if len(dbChan) < 1 {
-		initDatabase()
-	}
+	initTestDB()
+	b.ResetTimer()
 
 	for i := 0; i < b.N; i++ {
 		pullDB()
diff --git a/svc/init_test.go b/svc/init_test.go
index 9c99a3b..cac2fe6 100644
--- a/svc/init_test.go
+++ b/svc/init_test.go
@@ -9,23 +9,35 @@ import (
 	"sync"
 	"time"
 
-	"github.com/fsnotify/fsnotify"
 	"github.com/spf13/viper"
 )
 
-var testport = fmt.Sprintf(":%v", confObj.Port)
-var hasInit = false
-
-var initTestOnce sync.Once
+var (
+	testport     string
+	initTestOnce sync.Once
+	initDBOnce   sync.Once
+)
 
 func initTestConf() {
 	initTestOnce.Do(func() {
+
 		testConfig()
 		tmpls = testTemplates()
+
+		confObj.Mu.RLock()
+		testport = fmt.Sprintf(":%v", confObj.Port)
+		confObj.Mu.RUnlock()
+
 		logToNull()
 	})
 }
 
+func initTestDB() {
+	initDBOnce.Do(func() {
+		initDatabase()
+	})
+}
+
 func logToNull() {
 	hush, err := os.Open("/dev/null")
 	if err != nil {
@@ -44,27 +56,12 @@ func testConfig() {
 	viper.SetConfigType("yml")
 	viper.AddConfigPath("..")
 
-	log.Printf("Loading configuration ...\n")
-	if err := viper.ReadInConfig(); err != nil {
-		log.Printf("%v\n", err.Error())
-		log.Printf("Using defaults ...\n")
-	} else {
-		viper.WatchConfig()
-		viper.OnConfigChange(func(e fsnotify.Event) {
-			log.Printf("Config file change detected. Reloading...\n")
-			rebindConfig()
-		})
-	}
-
 	viper.SetDefault("ListenPort", 9001)
-	viper.SetDefault("LogFile", "getwtxt.log")
 	viper.SetDefault("DatabasePath", "getwtxt.db")
 	viper.SetDefault("AssetsDirectory", "assets")
 	viper.SetDefault("DatabaseType", "leveldb")
-	viper.SetDefault("StdoutLogging", false)
 	viper.SetDefault("ReCacheInterval", "1h")
 	viper.SetDefault("DatabasePushInterval", "5m")
-
 	viper.SetDefault("Instance.SiteName", "getwtxt")
 	viper.SetDefault("Instance.OwnerName", "Anonymous Microblogger")
 	viper.SetDefault("Instance.Email", "nobody@knows")
@@ -74,25 +71,14 @@ func testConfig() {
 	confObj.Mu.Lock()
 
 	confObj.Port = viper.GetInt("ListenPort")
-	confObj.LogFile = viper.GetString("LogFile")
+	confObj.AssetsDir = "../" + viper.GetString("AssetsDirectory")
 
 	confObj.DBType = strings.ToLower(viper.GetString("DatabaseType"))
-
 	confObj.DBPath = viper.GetString("DatabasePath")
 	log.Printf("Using %v database: %v\n", confObj.DBType, confObj.DBPath)
 
-	confObj.AssetsDir = "../" + viper.GetString("AssetsDirectory")
-
-	confObj.StdoutLogging = viper.GetBool("StdoutLogging")
-	if confObj.StdoutLogging {
-		log.Printf("Logging to stdout\n")
-	} else {
-		log.Printf("Logging to %v\n", confObj.LogFile)
-	}
-
 	confObj.CacheInterval = viper.GetDuration("StatusFetchInterval")
 	log.Printf("User status fetch interval: %v\n", confObj.CacheInterval)
-
 	confObj.DBInterval = viper.GetDuration("DatabasePushInterval")
 	log.Printf("Database push interval: %v\n", confObj.DBInterval)