about summary refs log tree commit diff stats
path: root/worker/notmuch/message.go
diff options
context:
space:
mode:
Diffstat (limited to 'worker/notmuch/message.go')
-rw-r--r--worker/notmuch/message.go134
1 files changed, 49 insertions, 85 deletions
diff --git a/worker/notmuch/message.go b/worker/notmuch/message.go
index aa16cee..c51e2e9 100644
--- a/worker/notmuch/message.go
+++ b/worker/notmuch/message.go
@@ -11,22 +11,24 @@ import (
 
 	"git.sr.ht/~sircmpwn/aerc/models"
 	"git.sr.ht/~sircmpwn/aerc/worker/lib"
+	notmuch "git.sr.ht/~sircmpwn/aerc/worker/notmuch/lib"
 	"github.com/emersion/go-message"
 	_ "github.com/emersion/go-message/charset"
-	notmuch "github.com/zenhack/go.notmuch"
 )
 
 type Message struct {
-	uid     uint32
-	key     string
-	msg     *notmuch.Message
-	rwDB    func() (*notmuch.DB, error) // used to open a db for writing
-	refresh func(*Message) error        // called after msg modification
+	uid uint32
+	key string
+	db  *notmuch.DB
 }
 
 // NewReader reads a message into memory and returns an io.Reader for it.
 func (m *Message) NewReader() (io.Reader, error) {
-	f, err := os.Open(m.msg.Filename())
+	name, err := m.Filename()
+	if err != nil {
+		return nil, err
+	}
+	f, err := os.Open(name)
 	if err != nil {
 		return nil, err
 	}
@@ -46,7 +48,11 @@ func (m *Message) MessageInfo() (*models.MessageInfo, error) {
 // NewBodyPartReader creates a new io.Reader for the requested body part(s) of
 // the message.
 func (m *Message) NewBodyPartReader(requestedParts []int) (io.Reader, error) {
-	f, err := os.Open(m.msg.Filename())
+	name, err := m.Filename()
+	if err != nil {
+		return nil, err
+	}
+	f, err := os.Open(name)
 	if err != nil {
 		return nil, err
 	}
@@ -61,7 +67,11 @@ func (m *Message) NewBodyPartReader(requestedParts []int) (io.Reader, error) {
 // MarkRead either adds or removes the maildir.FlagSeen flag from the message.
 func (m *Message) MarkRead(seen bool) error {
 	haveUnread := false
-	for _, t := range m.tags() {
+	tags, err := m.Tags()
+	if err != nil {
+		return err
+	}
+	for _, t := range tags {
 		if t == "unread" {
 			haveUnread = true
 			break
@@ -80,7 +90,7 @@ func (m *Message) MarkRead(seen bool) error {
 		return nil
 	}
 
-	err := m.AddTag("unread")
+	err = m.AddTag("unread")
 	if err != nil {
 		return err
 	}
@@ -88,86 +98,18 @@ func (m *Message) MarkRead(seen bool) error {
 }
 
 // tags returns the notmuch tags of a message
-func (m *Message) tags() []string {
-	ts := m.msg.Tags()
-	var tags []string
-	var tag *notmuch.Tag
-	for ts.Next(&tag) {
-		tags = append(tags, tag.Value)
-	}
-	return tags
-}
-
-func (m *Message) modify(cb func(*notmuch.Message) error) error {
-	db, err := m.rwDB()
-	if err != nil {
-		return err
-	}
-	defer db.Close()
-	msg, err := db.FindMessage(m.key)
-	if err != nil {
-		return err
-	}
-	err = cb(msg)
-	if err != nil {
-		return err
-	}
-	// we need to explicitly close here, else we don't commit
-	dcerr := db.Close()
-	if dcerr != nil && err == nil {
-		err = dcerr
-	}
-	// next we need to refresh the notmuch msg, else we serve stale tags
-	rerr := m.refresh(m)
-	if rerr != nil && err == nil {
-		err = rerr
-	}
-	return err
-}
-
-func (m *Message) AddTag(tag string) error {
-	err := m.modify(func(msg *notmuch.Message) error {
-		return msg.AddTag(tag)
-	})
-	return err
-}
-
-func (m *Message) AddTags(tags []string) error {
-	err := m.modify(func(msg *notmuch.Message) error {
-		ierr := msg.Atomic(func(msg *notmuch.Message) {
-			for _, t := range tags {
-				msg.AddTag(t)
-			}
-		})
-		return ierr
-	})
-	return err
-}
-
-func (m *Message) RemoveTag(tag string) error {
-	err := m.modify(func(msg *notmuch.Message) error {
-		return msg.RemoveTag(tag)
-	})
-	return err
-}
-
-func (m *Message) RemoveTags(tags []string) error {
-	err := m.modify(func(msg *notmuch.Message) error {
-		ierr := msg.Atomic(func(msg *notmuch.Message) {
-			for _, t := range tags {
-				msg.RemoveTag(t)
-			}
-		})
-		return ierr
-	})
-	return err
+func (m *Message) Tags() ([]string, error) {
+	return m.db.MsgTags(m.key)
 }
 
 func (m *Message) ModelFlags() ([]models.Flag, error) {
 	var flags []models.Flag
 	seen := true
-
-	for _, tag := range m.tags() {
+	tags, err := m.Tags()
+	if err != nil {
+		return nil, err
+	}
+	for _, tag := range tags {
 		switch tag {
 		case "replied":
 			flags = append(flags, models.AnsweredFlag)
@@ -188,3 +130,25 @@ func (m *Message) ModelFlags() ([]models.Flag, error) {
 func (m *Message) UID() uint32 {
 	return m.uid
 }
+
+func (m *Message) Filename() (string, error) {
+	return m.db.MsgFilename(m.key)
+}
+
+//AddTag adds a single tag.
+//Consider using *Message.ModifyTags for multiple additions / removals
+//instead of looping over a tag array
+func (m *Message) AddTag(tag string) error {
+	return m.ModifyTags([]string{tag}, nil)
+}
+
+//RemoveTag removes a single tag.
+//Consider using *Message.ModifyTags for multiple additions / removals
+//instead of looping over a tag array
+func (m *Message) RemoveTag(tag string) error {
+	return m.ModifyTags(nil, []string{tag})
+}
+
+func (m *Message) ModifyTags(add, remove []string) error {
+	return m.db.MsgModifyTags(m.key, add, remove)
+}