summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorBen Morrison <ben@gbmor.dev>2019-06-13 22:52:36 -0400
committerBen Morrison <ben@gbmor.dev>2019-06-13 22:52:36 -0400
commitd20ad73467c0636edea6133d7bd4b38bbfdefc55 (patch)
tree9e4b3fa8b44fb18dfe3f78eed9d27da85a4e5a41
parent5d21ab1d75b6a4ef47a98f7be2e326bbd99bba51 (diff)
downloadgetwtxt-d20ad73467c0636edea6133d7bd4b38bbfdefc55.tar.gz
mutex and readability cleanup
-rw-r--r--svc/conf.go10
-rw-r--r--svc/handlers.go5
-rw-r--r--svc/leveldb.go16
-rw-r--r--svc/query.go6
-rw-r--r--svc/sqlite.go14
-rw-r--r--svc/svc.go16
6 files changed, 22 insertions, 45 deletions
diff --git a/svc/conf.go b/svc/conf.go
index c1b9abb..e80234f 100644
--- a/svc/conf.go
+++ b/svc/conf.go
@@ -74,8 +74,9 @@ func initConfig() {
 // to the default logger, and the same for the
 // request logger.
 func initLogging() {
-
 	confObj.Mu.RLock()
+	defer confObj.Mu.RUnlock()
+
 	if confObj.StdoutLogging {
 		log.SetOutput(os.Stdout)
 		reqLog = log.New(os.Stdout, "", log.LstdFlags)
@@ -100,7 +101,6 @@ func initLogging() {
 		log.SetOutput(msgLog)
 		reqLog = log.New(reqLogFile, "", log.LstdFlags)
 	}
-	confObj.Mu.RUnlock()
 }
 
 // Default values should a config file
@@ -189,14 +189,15 @@ func bindConfig() {
 	if *flagAssets != "" {
 		confObj.AssetsDir = *flagAssets
 	}
-	confObj.Mu.Unlock()
 
+	confObj.Mu.Unlock()
 	announceConfig()
-
 }
 
 func announceConfig() {
 	confObj.Mu.RLock()
+	defer confObj.Mu.RUnlock()
+
 	if confObj.IsProxied {
 		log.Printf("Behind reverse proxy, not using host matching\n")
 	} else {
@@ -216,5 +217,4 @@ func announceConfig() {
 	log.Printf("Using %v database: %v\n", confObj.DBType, confObj.DBPath)
 	log.Printf("Database push interval: %v\n", confObj.DBInterval)
 	log.Printf("User status fetch interval: %v\n", confObj.CacheInterval)
-	confObj.Mu.RUnlock()
 }
diff --git a/svc/handlers.go b/svc/handlers.go
index 924b487..973bd62 100644
--- a/svc/handlers.go
+++ b/svc/handlers.go
@@ -20,14 +20,14 @@ func getEtag(modtime time.Time) string {
 
 func servStatic(w http.ResponseWriter, isCSS bool) error {
 	pingAssets()
-
 	staticCache.mu.RLock()
+	defer staticCache.mu.RUnlock()
+
 	if isCSS {
 		etag := getEtag(staticCache.cssMod)
 		w.Header().Set("ETag", "\""+etag+"\"")
 		w.Header().Set("Content-Type", cssutf8)
 		_, err := w.Write(staticCache.css)
-		staticCache.mu.RUnlock()
 		return err
 	}
 
@@ -35,7 +35,6 @@ func servStatic(w http.ResponseWriter, isCSS bool) error {
 	w.Header().Set("ETag", "\""+etag+"\"")
 	w.Header().Set("Content-Type", htmlutf8)
 	_, err := w.Write(staticCache.index)
-	staticCache.mu.RUnlock()
 	return err
 }
 
diff --git a/svc/leveldb.go b/svc/leveldb.go
index 7446dad..91d7ca9 100644
--- a/svc/leveldb.go
+++ b/svc/leveldb.go
@@ -15,9 +15,10 @@ type dbLevel struct {
 
 func (lvl *dbLevel) push() error {
 	twtxtCache.Mu.RLock()
+	defer twtxtCache.Mu.RUnlock()
+
 	var dbBasket = &leveldb.Batch{}
 	for k, v := range twtxtCache.Users {
-
 		dbBasket.Put([]byte(k+"*Nick"), []byte(v.Nick))
 		dbBasket.Put([]byte(k+"*URL"), []byte(v.URL))
 		dbBasket.Put([]byte(k+"*IP"), []byte(v.IP.String()))
@@ -29,7 +30,6 @@ func (lvl *dbLevel) push() error {
 			dbBasket.Put([]byte(k+"*Status*"+rfc), []byte(e))
 		}
 	}
-	twtxtCache.Mu.RUnlock()
 
 	for k, v := range remoteRegistries.List {
 		dbBasket.Put([]byte("remote*"+string(k)), []byte(v))
@@ -40,6 +40,8 @@ func (lvl *dbLevel) push() error {
 
 func (lvl *dbLevel) pull() {
 	iter := lvl.db.NewIterator(nil, nil)
+	twtxtCache.Mu.Lock()
+	defer twtxtCache.Mu.Unlock()
 
 	for iter.Next() {
 		key := string(iter.Key())
@@ -54,13 +56,9 @@ func (lvl *dbLevel) pull() {
 		}
 
 		data := registry.NewUser()
-		twtxtCache.Mu.RLock()
 		if _, ok := twtxtCache.Users[urls]; ok {
-			twtxtCache.Users[urls].Mu.RLock()
 			data = twtxtCache.Users[urls]
-			twtxtCache.Users[urls].Mu.RUnlock()
 		}
-		twtxtCache.Mu.RUnlock()
 
 		data.Mu.Lock()
 		switch field {
@@ -79,15 +77,11 @@ func (lvl *dbLevel) pull() {
 			errLog("", err)
 			data.Status[thetime] = val
 		}
-		data.Mu.Unlock()
-
-		twtxtCache.Mu.Lock()
 		twtxtCache.Users[urls] = data
-		twtxtCache.Mu.Unlock()
+		data.Mu.Unlock()
 	}
 
 	remoteRegistries.List = dedupe(remoteRegistries.List)
-
 	iter.Release()
 	errLog("Error while pulling DB into registry cache: ", iter.Error())
 }
diff --git a/svc/query.go b/svc/query.go
index 25f0f44..6122850 100644
--- a/svc/query.go
+++ b/svc/query.go
@@ -43,7 +43,6 @@ func parseQueryOut(out []string) []byte {
 
 	for i, e := range out {
 		data = append(data, []byte(e)...)
-
 		if !strings.HasSuffix(e, "\n") && i != len(out)-1 {
 			data = append(data, byte('\n'))
 		}
@@ -86,7 +85,6 @@ func apiEndpointQuery(w http.ResponseWriter, r *http.Request) error {
 			out2, err = twtxtCache.QueryUser(urls)
 			apiErrCheck(err, r)
 		}
-
 		if query != "" && urls != "" {
 			out = joinQueryOuts(out2)
 		}
@@ -108,11 +106,10 @@ func apiEndpointQuery(w http.ResponseWriter, r *http.Request) error {
 
 	out = registry.ReduceToPage(page, out)
 	data := parseQueryOut(out)
-
 	etag := fmt.Sprintf("%x", sha256.Sum256(data))
+
 	w.Header().Set("ETag", etag)
 	w.Header().Set("Content-Type", txtutf8)
-
 	_, err = w.Write(data)
 
 	return err
@@ -123,7 +120,6 @@ func joinQueryOuts(data ...[]string) []string {
 	for _, e := range data {
 		single = append(single, e...)
 	}
-
 	return dedupe(single)
 }
 
diff --git a/svc/sqlite.go b/svc/sqlite.go
index b10e5ae..8a7b1a4 100644
--- a/svc/sqlite.go
+++ b/svc/sqlite.go
@@ -16,7 +16,6 @@ type dbSqlite struct {
 }
 
 func initSqlite() *dbSqlite {
-
 	confObj.Mu.RLock()
 	dbpath := confObj.DBPath
 	confObj.Mu.RUnlock()
@@ -52,6 +51,8 @@ func (lite *dbSqlite) push() error {
 	txst := tx.Stmt(lite.pushStmt)
 
 	twtxtCache.Mu.RLock()
+	defer twtxtCache.Mu.RUnlock()
+
 	for i, e := range twtxtCache.Users {
 		e.Mu.RLock()
 
@@ -68,10 +69,8 @@ func (lite *dbSqlite) push() error {
 			_, err = txst.Exec(i, true, k.Format(time.RFC3339), v)
 			errLog("", err)
 		}
-
 		e.Mu.RUnlock()
 	}
-	twtxtCache.Mu.RUnlock()
 
 	for _, e := range remoteRegistries.List {
 		_, err = txst.Exec(e, false, "REMOTE REGISTRY", "NULL")
@@ -83,13 +82,11 @@ func (lite *dbSqlite) push() error {
 		errLog("", tx.Rollback())
 		return err
 	}
-
 	return nil
 }
 
 func (lite *dbSqlite) pull() {
 	errLog("Error pinging sqlite DB: ", lite.db.Ping())
-
 	rows, err := lite.pullStmt.Query()
 	errLog("", err)
 
@@ -98,6 +95,8 @@ func (lite *dbSqlite) pull() {
 	}(rows)
 
 	twtxtCache.Mu.Lock()
+	defer twtxtCache.Mu.Unlock()
+
 	for rows.Next() {
 		var uid int
 		var urls string
@@ -106,7 +105,6 @@ func (lite *dbSqlite) pull() {
 		var dBlob []byte
 
 		errLog("", rows.Scan(&uid, &urls, &isUser, &dataKey, &dBlob))
-
 		if !isUser {
 			remoteRegistries.List = append(remoteRegistries.List, urls)
 			continue
@@ -117,7 +115,6 @@ func (lite *dbSqlite) pull() {
 			user = twtxtCache.Users[urls]
 		}
 		user.Mu.Lock()
-
 		switch dataKey {
 		case "nickname":
 			user.Nick = string(dBlob)
@@ -132,11 +129,8 @@ func (lite *dbSqlite) pull() {
 			errLog("While pulling statuses from SQLite: ", err)
 			user.Status[thetime] = string(dBlob)
 		}
-
 		twtxtCache.Users[urls] = user
 		user.Mu.Unlock()
 	}
-	twtxtCache.Mu.Unlock()
-
 	remoteRegistries.List = dedupe(remoteRegistries.List)
 }
diff --git a/svc/svc.go b/svc/svc.go
index 1748bf0..a7786d9 100644
--- a/svc/svc.go
+++ b/svc/svc.go
@@ -19,6 +19,9 @@ func Start() {
 	// to serve the same content without duplicating
 	// handlers/paths
 	index := mux.NewRouter().StrictSlash(true)
+	setIndexRouting(index)
+	api := index.PathPrefix("/api").Subrouter()
+	setEndpointRouting(api)
 
 	confObj.Mu.RLock()
 	portnum := fmt.Sprintf(":%v", confObj.Port)
@@ -30,13 +33,10 @@ func Start() {
 	TLSKey := confObj.TLS.Key
 	confObj.Mu.RUnlock()
 
-	setIndexRouting(index)
-	api := index.PathPrefix("/api").Subrouter()
-	setEndpointRouting(api)
-
 	server := newServer(portnum, index)
 	log.Printf("*** Listening on %v\n", portnum)
 	log.Printf("*** getwtxt %v Startup finished at %v, took %v\n\n", Vers, time.Now().Format(time.RFC3339), time.Since(before))
+
 	if TLS {
 		errLog("", server.ListenAndServeTLS(TLSCert, TLSKey))
 	} else {
@@ -52,6 +52,7 @@ func Start() {
 
 func newServer(port string, index *mux.Router) *http.Server {
 	// handlers.CompressHandler gzips all responses.
+	// ipMiddleware passes the request IP along.
 	// Write/Read timeouts are self explanatory.
 	return &http.Server{
 		Handler:      handlers.CompressHandler(ipMiddleware(index)),
@@ -65,11 +66,9 @@ func setIndexRouting(index *mux.Router) {
 	index.Path("/").
 		Methods("GET", "HEAD").
 		HandlerFunc(staticHandler)
-
 	index.Path("/css").
 		Methods("GET", "HEAD").
 		HandlerFunc(staticHandler)
-
 	index.Path("/api").
 		Methods("GET", "HEAD").
 		HandlerFunc(apiBaseHandler)
@@ -93,7 +92,6 @@ func setEndpointRouting(api *mux.Router) {
 	api.Path("/{format:(?:plain)}/{endpoint:(?:mentions|users|tweets)}").
 		Methods("GET", "HEAD").
 		HandlerFunc(apiEndpointHandler)
-
 	api.Path("/{format:(?:plain)}/{endpoint:(?:mentions|users|tweets)}").
 		Queries("url", "{url}", "q", "{query}", "page", "{[0-9]+}").
 		Methods("GET", "HEAD").
@@ -105,14 +103,12 @@ func setEndpointRouting(api *mux.Router) {
 		Queries("url", "{url}", "nickname", "{nickname:[a-zA-Z0-9_-]+}").
 		Methods("POST").
 		HandlerFunc(apiEndpointPOSTHandler)
-
 	// This is for submitting new users incorrectly
 	// and letting the requester know about their error.
 	api.Path("/{format:(?:plain)}/{endpoint:users}").
 		Queries("url", "{url}").
 		Methods("POST").
 		HandlerFunc(apiEndpointPOSTHandler)
-
 	// This is also for submitting new users incorrectly
 	// and letting the requester know about their error.
 	api.Path("/{format:(?:plain)}/{endpoint:users}").
@@ -124,7 +120,6 @@ func setEndpointRouting(api *mux.Router) {
 	api.Path("/{format:(?:plain)}/tags").
 		Methods("GET", "HEAD").
 		HandlerFunc(apiTagsBaseHandler)
-
 	// Show Nth page of all observed tags
 	api.Path("/{format:(?:plain)}/tags").
 		Queries("page", "{[0-9]+}").
@@ -135,7 +130,6 @@ func setEndpointRouting(api *mux.Router) {
 	api.Path("/{format:(?:plain)}/tags/{tags:[a-zA-Z0-9_-]+}").
 		Methods("GET", "HEAD").
 		HandlerFunc(apiTagsHandler)
-
 	// Requests Nth page of statuses with a specific tag
 	api.Path("/{format:(?:plain)}/tags/{tags:[a-zA-Z0-9_-]+}").
 		Queries("page", "{[0-9]+}").