about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--worker/notmuch/lib/database.go179
-rw-r--r--worker/notmuch/message.go134
-rw-r--r--worker/notmuch/worker.go112
3 files changed, 250 insertions, 175 deletions
diff --git a/worker/notmuch/lib/database.go b/worker/notmuch/lib/database.go
new file mode 100644
index 0000000..3398504
--- /dev/null
+++ b/worker/notmuch/lib/database.go
@@ -0,0 +1,179 @@
+//+build notmuch
+
+package lib
+
+import (
+	"fmt"
+	"log"
+
+	notmuch "github.com/zenhack/go.notmuch"
+)
+
+type DB struct {
+	path         string
+	excludedTags []string
+	ro           *notmuch.DB
+	logger       *log.Logger
+}
+
+func NewDB(path string, excludedTags []string,
+	logger *log.Logger) *DB {
+	db := &DB{
+		path:         path,
+		excludedTags: excludedTags,
+		logger:       logger,
+	}
+	return db
+}
+
+func (db *DB) Connect() error {
+	return db.connectRO()
+}
+
+// connectRW returns a writable notmuch DB, which needs to be closed to commit
+// the changes and to release the DB lock
+func (db *DB) connectRW() (*notmuch.DB, error) {
+	rw, err := notmuch.Open(db.path, notmuch.DBReadWrite)
+	if err != nil {
+		return nil, fmt.Errorf("could not connect to notmuch db: %v", err)
+	}
+	return rw, err
+}
+
+// connectRO connects a RO db to the worker
+func (db *DB) connectRO() error {
+	if db.ro != nil {
+		if err := db.ro.Close(); err != nil {
+			db.logger.Printf("connectRO: could not close the old db: %v", err)
+		}
+	}
+	var err error
+	db.ro, err = notmuch.Open(db.path, notmuch.DBReadOnly)
+	if err != nil {
+		return fmt.Errorf("could not connect to notmuch db: %v", err)
+	}
+	return nil
+}
+
+//getQuery returns a query based on the provided query string.
+//It also configures the query as specified on the worker
+func (db *DB) newQuery(query string) (*notmuch.Query, error) {
+	if db.ro == nil {
+		return nil, fmt.Errorf("not connected to the notmuch db")
+	}
+	q := db.ro.NewQuery(query)
+	q.SetExcludeScheme(notmuch.EXCLUDE_TRUE)
+	q.SetSortScheme(notmuch.SORT_OLDEST_FIRST)
+	for _, t := range db.excludedTags {
+		err := q.AddTagExclude(t)
+		if err != nil && err != notmuch.ErrIgnored {
+			return nil, err
+		}
+	}
+	return q, nil
+}
+
+func (db *DB) MsgIDsFromQuery(q string) ([]string, error) {
+	if db.ro == nil {
+		return nil, fmt.Errorf("not connected to the notmuch db")
+	}
+	query, err := db.newQuery(q)
+	if err != nil {
+		return nil, err
+	}
+	msgs, err := query.Messages()
+	if err != nil {
+		return nil, err
+	}
+	var msg *notmuch.Message
+	var msgIDs []string
+	for msgs.Next(&msg) {
+		msgIDs = append(msgIDs, msg.ID())
+	}
+	return msgIDs, nil
+}
+
+type MessageCount struct {
+	Exists int
+	Unread int
+}
+
+func (db *DB) QueryCountMessages(q string) (MessageCount, error) {
+	query, err := db.newQuery(q)
+	if err != nil {
+		return MessageCount{}, err
+	}
+	exists := query.CountMessages()
+	query.Close()
+	uq, err := db.newQuery(fmt.Sprintf("(%v) and (tag:unread)", q))
+	if err != nil {
+		return MessageCount{}, err
+	}
+	defer uq.Close()
+	unread := uq.CountMessages()
+	return MessageCount{
+		Exists: exists,
+		Unread: unread,
+	}, nil
+}
+
+func (db *DB) MsgFilename(key string) (string, error) {
+	msg, err := db.ro.FindMessage(key)
+	if err != nil {
+		return "", err
+	}
+	defer msg.Close()
+	return msg.Filename(), nil
+}
+
+func (db *DB) MsgTags(key string) ([]string, error) {
+	msg, err := db.ro.FindMessage(key)
+	if err != nil {
+		return nil, err
+	}
+	defer msg.Close()
+	ts := msg.Tags()
+	var tags []string
+	var tag *notmuch.Tag
+	for ts.Next(&tag) {
+		tags = append(tags, tag.Value)
+	}
+	return tags, nil
+}
+
+func (db *DB) msgModify(key string,
+	cb func(*notmuch.Message) error) error {
+	defer db.connectRO()
+	db.ro.Close()
+
+	rw, err := db.connectRW()
+	if err != nil {
+		return err
+	}
+	defer rw.Close()
+
+	msg, err := rw.FindMessage(key)
+	if err != nil {
+		return err
+	}
+	defer msg.Close()
+
+	cb(msg)
+	return nil
+}
+
+func (db *DB) MsgModifyTags(key string, add, remove []string) error {
+	err := db.msgModify(key, func(msg *notmuch.Message) error {
+		ierr := msg.Atomic(func(msg *notmuch.Message) {
+			for _, t := range add {
+				msg.AddTag(t)
+			}
+			for _, t := range remove {
+				msg.RemoveTag(t)
+			}
+		})
+		return ierr
+	})
+	return err
+}
+
diff --git a/worker/notmuch/message.go b/worker/notmuch/message.go
index aa16cee..c51e2e9 100644
--- a/worker/notmuch/message.go
+++ b/worker/notmuch/message.go
@@ -11,22 +11,24 @@ import (
 
 	"git.sr.ht/~sircmpwn/aerc/models"
 	"git.sr.ht/~sircmpwn/aerc/worker/lib"
+	notmuch "git.sr.ht/~sircmpwn/aerc/worker/notmuch/lib"
 	"github.com/emersion/go-message"
 	_ "github.com/emersion/go-message/charset"
-	notmuch "github.com/zenhack/go.notmuch"
 )
 
 type Message struct {
-	uid     uint32
-	key     string
-	msg     *notmuch.Message
-	rwDB    func() (*notmuch.DB, error) // used to open a db for writing
-	refresh func(*Message) error        // called after msg modification
+	uid uint32
+	key string
+	db  *notmuch.DB
 }
 
 // NewReader reads a message into memory and returns an io.Reader for it.
 func (m *Message) NewReader() (io.Reader, error) {
-	f, err := os.Open(m.msg.Filename())
+	name, err := m.Filename()
+	if err != nil {
+		return nil, err
+	}
+	f, err := os.Open(name)
 	if err != nil {
 		return nil, err
 	}
@@ -46,7 +48,11 @@ func (m *Message) MessageInfo() (*models.MessageInfo, error) {
 // NewBodyPartReader creates a new io.Reader for the requested body part(s) of
 // the message.
 func (m *Message) NewBodyPartReader(requestedParts []int) (io.Reader, error) {
-	f, err := os.Open(m.msg.Filename())
+	name, err := m.Filename()
+	if err != nil {
+		return nil, err
+	}
+	f, err := os.Open(name)
 	if err != nil {
 		return nil, err
 	}
@@ -61,7 +67,11 @@ func (m *Message) NewBodyPartReader(requestedParts []int) (io.Reader, error) {
 // MarkRead either adds or removes the maildir.FlagSeen flag from the message.
 func (m *Message) MarkRead(seen bool) error {
 	haveUnread := false
-	for _, t := range m.tags() {
+	tags, err := m.Tags()
+	if err != nil {
+		return err
+	}
+	for _, t := range tags {
 		if t == "unread" {
 			haveUnread = true
 			break
@@ -80,7 +90,7 @@ func (m *Message) MarkRead(seen bool) error {
 		return nil
 	}
 
-	err := m.AddTag("unread")
+	err = m.AddTag("unread")
 	if err != nil {
 		return err
 	}
@@ -88,86 +98,18 @@ func (m *Message) MarkRead(seen bool) error {
 }
 
 // tags returns the notmuch tags of a message
-func (m *Message) tags() []string {
-	ts := m.msg.Tags()
-	var tags []string
-	var tag *notmuch.Tag
-	for ts.Next(&tag) {
-		tags = append(tags, tag.Value)
-	}
-	return tags
-}
-
-func (m *Message) modify(cb func(*notmuch.Message) error) error {
-	db, err := m.rwDB()
-	if err != nil {
-		return err
-	}
-	defer db.Close()
-	msg, err := db.FindMessage(m.key)
-	if err != nil {
-		return err
-	}
-	err = cb(msg)
-	if err != nil {
-		return err
-	}
-	// we need to explicitly close here, else we don't commit
-	dcerr := db.Close()
-	if dcerr != nil && err == nil {
-		err = dcerr
-	}
-	// next we need to refresh the notmuch msg, else we serve stale tags
-	rerr := m.refresh(m)
-	if rerr != nil && err == nil {
-		err = rerr
-	}
-	return err
-}
-
-func (m *Message) AddTag(tag string) error {
-	err := m.modify(func(msg *notmuch.Message) error {
-		return msg.AddTag(tag)
-	})
-	return err
-}
-
-func (m *Message) AddTags(tags []string) error {
-	err := m.modify(func(msg *notmuch.Message) error {
-		ierr := msg.Atomic(func(msg *notmuch.Message) {
-			for _, t := range tags {
-				msg.AddTag(t)
-			}
-		})
-		return ierr
-	})
-	return err
-}
-
-func (m *Message) RemoveTag(tag string) error {
-	err := m.modify(func(msg *notmuch.Message) error {
-		return msg.RemoveTag(tag)
-	})
-	return err
-}
-
-func (m *Message) RemoveTags(tags []string) error {
-	err := m.modify(func(msg *notmuch.Message) error {
-		ierr := msg.Atomic(func(msg *notmuch.Message) {
-			for _, t := range tags {
-				msg.RemoveTag(t)
-			}
-		})
-		return ierr
-	})
-	return err
+func (m *Message) Tags() ([]string, error) {
+	return m.db.MsgTags(m.key)
 }
 
 func (m *Message) ModelFlags() ([]models.Flag, error) {
 	var flags []models.Flag
 	seen := true
-
-	for _, tag := range m.tags() {
+	tags, err := m.Tags()
+	if err != nil {
+		return nil, err
+	}
+	for _, tag := range tags {
 		switch tag {
 		case "replied":
 			flags = append(flags, models.AnsweredFlag)
@@ -188,3 +130,25 @@ func (m *Message) ModelFlags() ([]models.Flag, error) {
 func (m *Message) UID() uint32 {
 	return m.uid
 }
+
+func (m *Message) Filename() (string, error) {
+	return m.db.MsgFilename(m.key)
+}
+
+//AddTag adds a single tag.
+//Consider using *Message.ModifyTags for multiple additions / removals
+//instead of looping over a tag array
+func (m *Message) AddTag(tag string) error {
+	return m.ModifyTags([]string{tag}, nil)
+}
+
+//RemoveTag removes a single tag.
+//Consider using *Message.ModifyTags for multiple additions / removals
+//instead of looping over a tag array
+func (m *Message) RemoveTag(tag string) error {
+	return m.ModifyTags(nil, []string{tag})
+}
+
+func (m *Message) ModifyTags(add, remove []string) error {
+	return m.db.MsgModifyTags(m.key, add, remove)
+}
diff --git a/worker/notmuch/worker.go b/worker/notmuch/worker.go
index 58a63ec..59624b3 100644
--- a/worker/notmuch/worker.go
+++ b/worker/notmuch/worker.go
@@ -14,9 +14,9 @@ import (
 	"git.sr.ht/~sircmpwn/aerc/lib/uidstore"
 	"git.sr.ht/~sircmpwn/aerc/models"
 	"git.sr.ht/~sircmpwn/aerc/worker/handlers"
+	notmuch "git.sr.ht/~sircmpwn/aerc/worker/notmuch/lib"
 	"git.sr.ht/~sircmpwn/aerc/worker/types"
 	"github.com/mitchellh/go-homedir"
-	notmuch "github.com/zenhack/go.notmuch"
 )
 
 func init() {
@@ -27,12 +27,10 @@ var errUnsupported = fmt.Errorf("unsupported command")
 
 type worker struct {
 	w            *types.Worker
-	pathToDB     string
-	db           *notmuch.DB
 	query        string
 	uidStore     *uidstore.Store
-	excludedTags []string
 	nameQueryMap map[string]string
+	db           *notmuch.DB
 }
 
 // NewWorker creates a new maildir worker with the provided worker.
@@ -116,46 +114,18 @@ func (w *worker) handleConfigure(msg *types.Configure) error {
 	if err != nil {
 		return fmt.Errorf("could not resolve home directory: %v", err)
 	}
-	w.pathToDB = filepath.Join(home, u.Path)
+	pathToDB := filepath.Join(home, u.Path)
 	w.uidStore = uidstore.NewStore()
-
 	if err = w.loadQueryMap(msg.Config); err != nil {
 		return fmt.Errorf("could not load query map: %v", err)
 	}
-	if err = w.loadExcludeTags(msg.Config); err != nil {
-		return fmt.Errorf("could not load excluded tags: %v", err)
-	}
-	w.w.Logger.Printf("configured db directory: %s", w.pathToDB)
-	return nil
-}
-
-// connectRW returns a writable notmuch DB, which needs to be closed to commit
-// the changes and to release the DB lock
-func (w *worker) connectRW() (*notmuch.DB, error) {
-	db, err := notmuch.Open(w.pathToDB, notmuch.DBReadWrite)
-	if err != nil {
-		return nil, fmt.Errorf("could not connect to notmuch db: %v", err)
-	}
-	return db, err
-}
-
-// connectRO connects a RO db to the worker
-func (w *worker) connectRO() error {
-	if w.db != nil {
-		if err := w.db.Close(); err != nil {
-			w.w.Logger.Printf("connectRO: could not close the old db: %v", err)
-		}
-	}
-	var err error
-	w.db, err = notmuch.Open(w.pathToDB, notmuch.DBReadOnly)
-	if err != nil {
-		return fmt.Errorf("could not connect to notmuch db: %v", err)
-	}
+	excludedTags := w.loadExcludeTags(msg.Config)
+	w.db = notmuch.NewDB(pathToDB, excludedTags, w.w.Logger)
 	return nil
 }
 
 func (w *worker) handleConnect(msg *types.Connect) error {
-	err := w.connectRO()
+	err := w.db.Connect()
 	if err != nil {
 		return err
 	}
@@ -177,21 +147,6 @@ func (w *worker) handleListDirectories(msg *types.ListDirectories) error {
 	return nil
 }
 
-//getQuery returns a query based on the provided query string.
-//It also configures the query as specified on the worker
-func (w *worker) getQuery(query string) (*notmuch.Query, error) {
-	q := w.db.NewQuery(query)
-	q.SetExcludeScheme(notmuch.EXCLUDE_TRUE)
-	q.SetSortScheme(notmuch.SORT_OLDEST_FIRST)
-	for _, t := range w.excludedTags {
-		err := q.AddTagExclude(t)
-		if err != nil && err != notmuch.ErrIgnored {
-			return nil, err
-		}
-	}
-	return q, nil
-}
-
 func (w *worker) handleOpenDirectory(msg *types.OpenDirectory) error {
 	w.w.Logger.Printf("opening %s", msg.Directory)
 	// try the friendly name first, if that fails assume it's a query
@@ -200,7 +155,7 @@ func (w *worker) handleOpenDirectory(msg *types.OpenDirectory) error {
 		q = msg.Directory
 	}
 	w.query = q
-	query, err := w.getQuery(w.query)
+	count, err := w.db.QueryCountMessages(w.query)
 	if err != nil {
 		return err
 	}
@@ -211,11 +166,11 @@ func (w *worker) handleOpenDirectory(msg *types.OpenDirectory) error {
 			Flags:    []string{},
 			ReadOnly: false,
 			// total messages
-			Exists: query.CountMessages(),
+			Exists: count.Exists,
 			// new messages since mailbox was last opened
 			Recent: 0,
 			// total unread
-			Unseen: 0,
+			Unseen: count.Unread,
 		},
 	}
 	w.w.PostMessage(info, nil)
@@ -226,11 +181,7 @@ func (w *worker) handleOpenDirectory(msg *types.OpenDirectory) error {
 
 func (w *worker) handleFetchDirectoryContents(
 	msg *types.FetchDirectoryContents) error {
-	q, err := w.getQuery(w.query)
-	if err != nil {
-		return err
-	}
-	uids, err := w.uidsFromQuery(q)
+	uids, err := w.uidsFromQuery(w.query)
 	if err != nil {
 		w.w.Logger.Printf("error scanning uids: %v", err)
 		return err
@@ -267,15 +218,14 @@ func (w *worker) handleFetchMessageHeaders(
 	return nil
 }
 
-func (w *worker) uidsFromQuery(query *notmuch.Query) ([]uint32, error) {
-	msgs, err := query.Messages()
+func (w *worker) uidsFromQuery(query string) ([]uint32, error) {
+	msgIDs, err := w.db.MsgIDsFromQuery(query)
 	if err != nil {
 		return nil, err
 	}
-	var msg *notmuch.Message
 	var uids []uint32
-	for msgs.Next(&msg) {
-		uid := w.uidStore.GetOrInsert(msg.ID())
+	for _, id := range msgIDs {
+		uid := w.uidStore.GetOrInsert(id)
 		uids = append(uids, uid)
 
 	}
@@ -287,25 +237,10 @@ func (w *worker) msgFromUid(uid uint32) (*Message, error) {
 	if !ok {
 		return nil, fmt.Errorf("Invalid uid: %v", uid)
 	}
-	nm, err := w.db.FindMessage(key)
-	if err != nil {
-		return nil, fmt.Errorf("Could not fetch message for key %q: %v", key, err)
-	}
 	msg := &Message{
-		key:  key,
-		uid:  uid,
-		msg:  nm,
-		rwDB: w.connectRW,
-		refresh: func(m *Message) error {
-			//close the old message manually, else we segfault during gc
-			m.msg.Close()
-			err := w.connectRO()
-			if err != nil {
-				return err
-			}
-			m.msg, err = w.db.FindMessage(m.key)
-			return err
-		},
+		key: key,
+		uid: uid,
+		db:  w.db,
 	}
 	return msg, nil
 }
@@ -409,11 +344,7 @@ func (w *worker) handleSearchDirectory(msg *types.SearchDirectory) error {
 	s := strings.Join(msg.Argv[1:], " ")
 	// we only want to search in the current query, so merge the two together
 	search := fmt.Sprintf("(%v) and (%v)", w.query, s)
-	query, err := w.getQuery(search)
-	if err != nil {
-		return err
-	}
-	uids, err := w.uidsFromQuery(query)
+	uids, err := w.uidsFromQuery(search)
 	if err != nil {
 		return err
 	}
@@ -452,12 +383,13 @@ func (w *worker) loadQueryMap(acctConfig *config.AccountConfig) error {
 	return nil
 }
 
-func (w *worker) loadExcludeTags(acctConfig *config.AccountConfig) error {
+func (w *worker) loadExcludeTags(
+	acctConfig *config.AccountConfig) []string {
 	raw, ok := acctConfig.Params["exclude-tags"]
 	if !ok {
 		// nothing to do
 		return nil
 	}
-	w.excludedTags = strings.Split(raw, ",")
-	return nil
+	excludedTags := strings.Split(raw, ",")
+	return excludedTags
 }