summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--commands/msg/utils.go17
-rw-r--r--commands/util.go17
-rw-r--r--widgets/account.go4
-rw-r--r--widgets/msgviewer.go4
-rw-r--r--widgets/providesmessage.go2
5 files changed, 30 insertions, 14 deletions
diff --git a/commands/msg/utils.go b/commands/msg/utils.go
index ae25535..cad0f82 100644
--- a/commands/msg/utils.go
+++ b/commands/msg/utils.go
@@ -18,12 +18,7 @@ func newHelper(aerc *widgets.Aerc) *helper {
 }
 
 func (h *helper) markedOrSelectedUids() ([]uint32, error) {
-	msgs, err := commands.MarkedOrSelected(h.msgProvider)
-	if err != nil {
-		return nil, err
-	}
-	uids := commands.UidsFromMessageInfos(msgs)
-	return uids, nil
+	return commands.MarkedOrSelected(h.msgProvider)
 }
 
 func (h *helper) store() (*lib.MessageStore, error) {
@@ -43,5 +38,13 @@ func (h *helper) account() (*widgets.AccountView, error) {
 }
 
 func (h *helper) messages() ([]*models.MessageInfo, error) {
-	return commands.MarkedOrSelected(h.msgProvider)
+	uid, err := commands.MarkedOrSelected(h.msgProvider)
+	if err != nil {
+		return nil, err
+	}
+	store, err := h.store()
+	if err != nil {
+		return nil, err
+	}
+	return commands.MsgInfoFromUids(store, uid)
 }
diff --git a/commands/util.go b/commands/util.go
index 5529edb..e3395fd 100644
--- a/commands/util.go
+++ b/commands/util.go
@@ -10,6 +10,7 @@ import (
 	"strings"
 	"time"
 
+	"git.sr.ht/~sircmpwn/aerc/lib"
 	"git.sr.ht/~sircmpwn/aerc/models"
 	"git.sr.ht/~sircmpwn/aerc/widgets"
 	"github.com/gdamore/tcell"
@@ -152,7 +153,7 @@ func listDir(path string, hidden bool) []string {
 
 // MarkedOrSelected returns either all marked messages if any are marked or the
 // selected message instead
-func MarkedOrSelected(pm widgets.ProvidesMessages) ([]*models.MessageInfo, error) {
+func MarkedOrSelected(pm widgets.ProvidesMessages) ([]uint32, error) {
 	// marked has priority over the selected message
 	marked, err := pm.MarkedMessages()
 	if err != nil {
@@ -165,7 +166,7 @@ func MarkedOrSelected(pm widgets.ProvidesMessages) ([]*models.MessageInfo, error
 	if err != nil {
 		return nil, err
 	}
-	return []*models.MessageInfo{msg}, nil
+	return []uint32{msg.Uid}, nil
 }
 
 // UidsFromMessageInfos extracts a uid slice from a slice of MessageInfos
@@ -178,3 +179,15 @@ func UidsFromMessageInfos(msgs []*models.MessageInfo) []uint32 {
 	}
 	return uids
 }
+
+func MsgInfoFromUids(store *lib.MessageStore, uids []uint32) ([]*models.MessageInfo, error) {
+	infos := make([]*models.MessageInfo, len(uids))
+	for i, uid := range uids {
+		var ok bool
+		infos[i], ok = store.Messages[uid]
+		if !ok {
+			return nil, fmt.Errorf("uid not found")
+		}
+	}
+	return infos, nil
+}
diff --git a/widgets/account.go b/widgets/account.go
index 31384a5..20ed345 100644
--- a/widgets/account.go
+++ b/widgets/account.go
@@ -215,9 +215,9 @@ func (acct *AccountView) SelectedMessage() (*models.MessageInfo, error) {
 	return msg, nil
 }
 
-func (acct *AccountView) MarkedMessages() ([]*models.MessageInfo, error) {
+func (acct *AccountView) MarkedMessages() ([]uint32, error) {
 	store := acct.Store()
-	return msgInfoFromUids(store, store.Marked())
+	return store.Marked(), nil
 }
 
 func (acct *AccountView) SelectedMessagePart() *PartInfo {
diff --git a/widgets/msgviewer.go b/widgets/msgviewer.go
index e192ae6..0cfabd7 100644
--- a/widgets/msgviewer.go
+++ b/widgets/msgviewer.go
@@ -262,9 +262,9 @@ func (mv *MessageViewer) SelectedMessage() (*models.MessageInfo, error) {
 	return mv.msg.MessageInfo(), nil
 }
 
-func (mv *MessageViewer) MarkedMessages() ([]*models.MessageInfo, error) {
+func (mv *MessageViewer) MarkedMessages() ([]uint32, error) {
 	store := mv.Store()
-	return msgInfoFromUids(store, store.Marked())
+	return store.Marked(), nil
 }
 
 func (mv *MessageViewer) ToggleHeaders() {
diff --git a/widgets/providesmessage.go b/widgets/providesmessage.go
index b06825f..6e00b1c 100644
--- a/widgets/providesmessage.go
+++ b/widgets/providesmessage.go
@@ -25,5 +25,5 @@ type ProvidesMessages interface {
 	Store() *lib.MessageStore
 	SelectedAccount() *AccountView
 	SelectedMessage() (*models.MessageInfo, error)
-	MarkedMessages() ([]*models.MessageInfo, error)
+	MarkedMessages() ([]uint32, error)
 }