summary refs log tree commit diff stats
path: root/init.go
diff options
context:
space:
mode:
Diffstat (limited to 'init.go')
-rw-r--r--init.go35
1 files changed, 30 insertions, 5 deletions
diff --git a/init.go b/init.go
index ba3fdfc..076a7ca 100644
--- a/init.go
+++ b/init.go
@@ -1,6 +1,7 @@
 package main
 
 import (
+	"database/sql"
 	"fmt"
 	"html/template"
 	"log"
@@ -33,7 +34,7 @@ var closeLog = make(chan bool, 1)
 
 // used to transmit database pointer after
 // initialization
-var dbChan = make(chan *leveldb.DB, 1)
+var dbChan = make(chan dbase, 1)
 
 var tmpls *template.Template
 
@@ -136,6 +137,7 @@ func initConfig() {
 	confObj.Port = viper.GetInt("ListenPort")
 	confObj.LogFile = viper.GetString("LogFile")
 
+	confObj.DBType = strings.ToLower(viper.GetString("DatabaseType"))
 	confObj.DBPath = viper.GetString("DatabasePath")
 	log.Printf("Using database: %v\n", confObj.DBPath)
 
@@ -213,6 +215,7 @@ func rebindConfig() {
 	confObj.Mu.Lock()
 
 	confObj.LogFile = viper.GetString("LogFile")
+	confObj.DBType = strings.ToLower(viper.GetString("DatabaseType"))
 	confObj.DBPath = viper.GetString("DatabasePath")
 	confObj.StdoutLogging = viper.GetBool("StdoutLogging")
 	confObj.CacheInterval = viper.GetDuration("StatusFetchInterval")
@@ -235,9 +238,24 @@ func initTemplates() *template.Template {
 
 // Pull DB data into cache, if available.
 func initDatabase() {
+	var db dbase
+	var err error
+
 	confObj.Mu.RLock()
-	db, err := leveldb.OpenFile(confObj.DBPath, nil)
+	switch confObj.DBType {
+
+	case "leveldb":
+		var lvl *leveldb.DB
+		lvl, err = leveldb.OpenFile(confObj.DBPath, nil)
+		db = &dbLevel{db: lvl}
+
+	case "sqlite":
+		var lite *sql.DB
+		db = &dbSqlite{db: lite}
+
+	}
 	confObj.Mu.RUnlock()
+
 	if err != nil {
 		log.Fatalf("%v\n", err.Error())
 	}
@@ -259,11 +277,18 @@ func watchForInterrupt() {
 
 			log.Printf("\n\nCaught %v. Cleaning up ...\n", sigint)
 			confObj.Mu.RLock()
-
 			log.Printf("Closing database connection to %v...\n", confObj.DBPath)
+
 			db := <-dbChan
-			if err := db.Close(); err != nil {
-				log.Printf("%v\n", err.Error())
+
+			switch dbType := db.(type) {
+
+			case *dbLevel:
+				lvl := dbType
+				if err := lvl.db.Close(); err != nil {
+					log.Printf("%v\n", err.Error())
+				}
+
 			}
 
 			if !confObj.StdoutLogging {