diff options
Diffstat (limited to 'init.go')
-rw-r--r-- | init.go | 35 |
1 files changed, 30 insertions, 5 deletions
diff --git a/init.go b/init.go index ba3fdfc..076a7ca 100644 --- a/init.go +++ b/init.go @@ -1,6 +1,7 @@ package main import ( + "database/sql" "fmt" "html/template" "log" @@ -33,7 +34,7 @@ var closeLog = make(chan bool, 1) // used to transmit database pointer after // initialization -var dbChan = make(chan *leveldb.DB, 1) +var dbChan = make(chan dbase, 1) var tmpls *template.Template @@ -136,6 +137,7 @@ func initConfig() { confObj.Port = viper.GetInt("ListenPort") confObj.LogFile = viper.GetString("LogFile") + confObj.DBType = strings.ToLower(viper.GetString("DatabaseType")) confObj.DBPath = viper.GetString("DatabasePath") log.Printf("Using database: %v\n", confObj.DBPath) @@ -213,6 +215,7 @@ func rebindConfig() { confObj.Mu.Lock() confObj.LogFile = viper.GetString("LogFile") + confObj.DBType = strings.ToLower(viper.GetString("DatabaseType")) confObj.DBPath = viper.GetString("DatabasePath") confObj.StdoutLogging = viper.GetBool("StdoutLogging") confObj.CacheInterval = viper.GetDuration("StatusFetchInterval") @@ -235,9 +238,24 @@ func initTemplates() *template.Template { // Pull DB data into cache, if available. func initDatabase() { + var db dbase + var err error + confObj.Mu.RLock() - db, err := leveldb.OpenFile(confObj.DBPath, nil) + switch confObj.DBType { + + case "leveldb": + var lvl *leveldb.DB + lvl, err = leveldb.OpenFile(confObj.DBPath, nil) + db = &dbLevel{db: lvl} + + case "sqlite": + var lite *sql.DB + db = &dbSqlite{db: lite} + + } confObj.Mu.RUnlock() + if err != nil { log.Fatalf("%v\n", err.Error()) } @@ -259,11 +277,18 @@ func watchForInterrupt() { log.Printf("\n\nCaught %v. Cleaning up ...\n", sigint) confObj.Mu.RLock() - log.Printf("Closing database connection to %v...\n", confObj.DBPath) + db := <-dbChan - if err := db.Close(); err != nil { - log.Printf("%v\n", err.Error()) + + switch dbType := db.(type) { + + case *dbLevel: + lvl := dbType + if err := lvl.db.Close(); err != nil { + log.Printf("%v\n", err.Error()) + } + } if !confObj.StdoutLogging { |