summary refs log tree commit diff stats
path: root/commands/msgview
diff options
context:
space:
mode:
Diffstat (limited to 'commands/msgview')
-rw-r--r--commands/msgview/save.go195
1 files changed, 125 insertions, 70 deletions
diff --git a/commands/msgview/save.go b/commands/msgview/save.go
index c017e70..7f236cb 100644
--- a/commands/msgview/save.go
+++ b/commands/msgview/save.go
@@ -1,11 +1,9 @@
 package msgview
 
 import (
-	"encoding/base64"
 	"errors"
 	"fmt"
 	"io"
-	"mime/quotedprintable"
 	"os"
 	"path/filepath"
 	"strings"
@@ -15,6 +13,7 @@ import (
 	"github.com/mitchellh/go-homedir"
 
 	"git.sr.ht/~sircmpwn/aerc/commands"
+	"git.sr.ht/~sircmpwn/aerc/models"
 	"git.sr.ht/~sircmpwn/aerc/widgets"
 )
 
@@ -34,102 +33,158 @@ func (Save) Complete(aerc *widgets.Aerc, args []string) []string {
 }
 
 func (Save) Execute(aerc *widgets.Aerc, args []string) error {
-	if len(args) == 1 {
-		return errors.New("Usage: :save [-p] <path>")
-	}
-	opts, optind, err := getopt.Getopts(args, "p")
+	opts, optind, err := getopt.Getopts(args, "fp")
 	if err != nil {
 		return err
 	}
 
 	var (
-		mkdirs bool
-		path   string = strings.Join(args[optind:], " ")
+		force         bool
+		createDirs    bool
+		trailingSlash bool
 	)
 
 	for _, opt := range opts {
 		switch opt.Option {
+		case 'f':
+			force = true
 		case 'p':
-			mkdirs = true
+			createDirs = true
 		}
 	}
-	if defaultPath := aerc.Config().General.DefaultSavePath; defaultPath != "" {
-		path = defaultPath
+
+	defaultPath := aerc.Config().General.DefaultSavePath
+	// we either need a path or a defaultPath
+	if defaultPath == "" && len(args) == optind {
+		return errors.New("Usage: :save [-fp] <path>")
 	}
 
-	mv := aerc.SelectedTab().(*widgets.MessageViewer)
-	p := mv.SelectedMessagePart()
+	// as a convenience we join with spaces, so that the user doesn't need to
+	// quote filenames containing spaces
+	path := strings.Join(args[optind:], " ")
+
+	// needs to be determined prior to calling filepath.Clean / filepath.Join
+	// it gets stripped by Clean.
+	// we auto generate a name if a directory was given
+	if len(path) > 0 {
+		trailingSlash = path[len(path)-1] == '/'
+	} else if len(defaultPath) > 0 && len(path) == 0 {
+		// empty path, so we might have a default that ends in a trailingSlash
+		trailingSlash = defaultPath[len(defaultPath)-1] == '/'
+	}
 
-	p.Store.FetchBodyPart(p.Msg.Uid, p.Msg.BodyStructure, p.Index, func(reader io.Reader) {
-		// email parts are encoded as 7bit (plaintext), quoted-printable, or base64
+	// Absolute paths are taken as is so that the user can override the default
+	// if they want to
+	if !isAbsPath(path) {
+		path = filepath.Join(defaultPath, path)
+	}
 
-		if strings.EqualFold(p.Part.Encoding, "base64") {
-			reader = base64.NewDecoder(base64.StdEncoding, reader)
-		} else if strings.EqualFold(p.Part.Encoding, "quoted-printable") {
-			reader = quotedprintable.NewReader(reader)
-		}
+	path, err = homedir.Expand(path)
+	if err != nil {
+		return err
+	}
 
-		var pathIsDir bool
-		if path[len(path)-1:] == "/" {
-			pathIsDir = true
-		}
-		// Note: path expansion has to happen after test for trailing /,
-		// since it is stripped when path is expanded
-		path, err := homedir.Expand(path)
+	mv, ok := aerc.SelectedTab().(*widgets.MessageViewer)
+	if !ok {
+		return fmt.Errorf("SelectedTab is not a MessageViewer")
+	}
+	pi := mv.SelectedMessagePart()
+
+	if trailingSlash || isDirExists(path) {
+		filename := generateFilename(pi.Part)
+		path = filepath.Join(path, filename)
+	}
+
+	dir := filepath.Dir(path)
+	if createDirs && dir != "" {
+		err := os.MkdirAll(dir, 0755)
 		if err != nil {
-			aerc.PushError(" " + err.Error())
+			return err
 		}
+	}
 
-		pathinfo, err := os.Stat(path)
-		if err == nil && pathinfo.IsDir() {
-			pathIsDir = true
-		} else if os.IsExist(err) && pathIsDir {
-			aerc.PushError("The given directory is an existing file")
-		}
-		var (
-			save_file string
-			save_dir  string
-		)
-		if pathIsDir {
-			save_dir = path
-			if filename, ok := p.Part.DispositionParams["filename"]; ok {
-				save_file = filename
-			} else if filename, ok := p.Part.Params["name"]; ok {
-				save_file = filename
-			} else {
-				timestamp := time.Now().Format("2006-01-02-150405")
-				save_file = fmt.Sprintf("aerc_%v", timestamp)
+	if pathExists(path) && !force {
+		return fmt.Errorf("%q already exists and -f not given", path)
+	}
+
+	ch := make(chan error, 1)
+	pi.Store.FetchBodyPart(
+		pi.Msg.Uid, pi.Msg.BodyStructure, pi.Index, func(reader io.Reader) {
+			f, err := os.Create(path)
+			if err != nil {
+				ch <- err
+				return
 			}
-		} else {
-			save_file = filepath.Base(path)
-			save_dir = filepath.Dir(path)
-		}
-		if _, err := os.Stat(save_dir); os.IsNotExist(err) {
-			if mkdirs {
-				os.MkdirAll(save_dir, 0755)
-			} else {
-				aerc.PushError("Target directory does not exist, use " +
-					":save with the -p option to create it")
+			defer f.Close()
+			_, err = io.Copy(f, reader)
+			if err != nil {
+				ch <- err
 				return
 			}
-		}
-		target := filepath.Clean(filepath.Join(save_dir, save_file))
+			ch <- nil
+		})
 
-		f, err := os.Create(target)
+	// we need to wait for the callback prior to displaying a result
+	go func() {
+		err := <-ch
 		if err != nil {
-			aerc.PushError(" " + err.Error())
+			aerc.PushError(fmt.Sprintf("Save failed: %v", err))
 			return
 		}
-		defer f.Close()
+		aerc.PushStatus("Saved to "+path, 10*time.Second)
+	}()
+	return nil
+}
 
-		_, err = io.Copy(f, reader)
-		if err != nil {
-			aerc.PushError(" " + err.Error())
-			return
-		}
+//isDir returns true if path is a directory and exists
+func isDirExists(path string) bool {
+	pathinfo, err := os.Stat(path)
+	if err != nil {
+		return false // we don't really care
+	}
+	if pathinfo.IsDir() {
+		return true
+	}
+	return false
+}
 
-		aerc.PushStatus("Saved to "+target, 10*time.Second)
-	})
+//pathExists returns true if path exists
+func pathExists(path string) bool {
+	_, err := os.Stat(path)
+	if err != nil {
+		return false // we don't really care why it failed
+	}
+	return true
+}
 
-	return nil
+//isAbsPath returns true if path given is anchored to / or . or ~
+func isAbsPath(path string) bool {
+	if len(path) == 0 {
+		return false
+	}
+	switch path[0] {
+	case '/':
+		return true
+	case '.':
+		return true
+	case '~':
+		return true
+	default:
+		return false
+	}
+}
+
+// generateFilename tries to get the filename from the given part.
+// if that fails it will fallback to a generated one based on the date
+func generateFilename(part *models.BodyStructure) string {
+	var filename string
+	if fn, ok := part.DispositionParams["filename"]; ok {
+		filename = fn
+	} else if fn, ok := part.Params["name"]; ok {
+		filename = fn
+	} else {
+		timestamp := time.Now().Format("2006-01-02-150405")
+		filename = fmt.Sprintf("aerc_%v", timestamp)
+	}
+	return filename
 }