about summary refs log tree commit diff stats
path: root/svc
diff options
context:
space:
mode:
Diffstat (limited to 'svc')
-rw-r--r--svc/common.go16
-rw-r--r--svc/common_test.go34
-rw-r--r--svc/conf.go13
3 files changed, 63 insertions, 0 deletions
diff --git a/svc/common.go b/svc/common.go
new file mode 100644
index 0000000..5f169af
--- /dev/null
+++ b/svc/common.go
@@ -0,0 +1,16 @@
+package svc
+
+import "golang.org/x/crypto/bcrypt"
+
+// HashPass returns the bcrypt hash of the provided string.
+// If an empty string is provided, return an empty string.
+func HashPass(s string) (string, error) {
+	if s == "" {
+		return "", nil
+	}
+	h, err := bcrypt.GenerateFromPassword([]byte(s), 14)
+	if err != nil {
+		return "", err
+	}
+	return string(h), nil
+}
diff --git a/svc/common_test.go b/svc/common_test.go
new file mode 100644
index 0000000..d9a08b3
--- /dev/null
+++ b/svc/common_test.go
@@ -0,0 +1,34 @@
+package svc
+
+import (
+	"testing"
+)
+
+func TestHashPass(t *testing.T) {
+	cases := []struct {
+		in, name   string
+		shouldFail bool
+	}{
+		{
+			in:         "foo",
+			name:       "non-empty password",
+			shouldFail: false,
+		},
+		{
+			in:         "",
+			name:       "empty password",
+			shouldFail: true,
+		},
+	}
+	for _, v := range cases {
+		t.Run(v.name, func(t *testing.T) {
+			out, err := HashPass(v.in)
+			if err != nil && !v.shouldFail {
+				t.Errorf("Shouldn't have failed: Case %s, Error: %s", v.name, err)
+			}
+			if out == "" && v.in != "" {
+				t.Errorf("Got empty out for case %s input %s", v.name, v.in)
+			}
+		})
+	}
+}
diff --git a/svc/conf.go b/svc/conf.go
index 7365b2b..5f826fb 100644
--- a/svc/conf.go
+++ b/svc/conf.go
@@ -20,6 +20,7 @@ along with Getwtxt.  If not, see <https://www.gnu.org/licenses/>.
 package svc // import "git.sr.ht/~gbmor/getwtxt/svc"
 
 import (
+	"fmt"
 	"log"
 	"os"
 	"path/filepath"
@@ -43,6 +44,7 @@ type Configuration struct {
 	DBPath        string        `yaml:"DatabasePath"`
 	AssetsDir     string        `yaml:"AssetsDirectory"`
 	StaticDir     string        `yaml:"StaticFilesDirectory"`
+	AdminPassHash string        `yaml:"-"`
 	StdoutLogging bool          `yaml:"StdoutLogging"`
 	CacheInterval time.Duration `yaml:"StatusFetchInterval"`
 	DBInterval    time.Duration `yaml:"DatabasePushInterval"`
@@ -126,6 +128,7 @@ func setConfigDefaults() {
 	viper.SetDefault("StdoutLogging", false)
 	viper.SetDefault("ReCacheInterval", "1h")
 	viper.SetDefault("DatabasePushInterval", "5m")
+	viper.SetDefault("AdminPassword", "please_change_me")
 
 	viper.SetDefault("Instance.SiteName", "getwtxt")
 	viper.SetDefault("Instance.OwnerName", "Anonymous Microblogger")
@@ -173,6 +176,16 @@ func bindConfig() {
 	confObj.StdoutLogging = viper.GetBool("StdoutLogging")
 	confObj.CacheInterval = viper.GetDuration("StatusFetchInterval")
 	confObj.DBInterval = viper.GetDuration("DatabasePushInterval")
+	txtPass := viper.GetString("AdminPassword")
+	if txtPass == "please_change_me" {
+		fmt.Println("Please set AdminPassword in getwtxt.yml")
+		os.Exit(1)
+	}
+	passHash, err := HashPass(txtPass)
+	if err != nil {
+		errFatal("Failed to hash administrator password: ", err)
+	}
+	confObj.AdminPassHash = passHash
 
 	confObj.Instance.Vers = Vers
 	confObj.Instance.Name = viper.GetString("Instance.SiteName")