summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorReto Brunner <reto@labrat.space>2020-11-11 23:50:35 +0100
committerReto Brunner <reto@labrat.space>2020-11-21 15:40:50 +0100
commit67923707ffd826ad1d02c0a5b5ebd75ffbc71364 (patch)
tree103104b96682a6afc81041d7b5fc0a323214a460
parent7b12f2d1ea791139022b63029bbd8616564355f3 (diff)
downloadaerc-67923707ffd826ad1d02c0a5b5ebd75ffbc71364.tar.gz
Refactor send command
-rw-r--r--commands/compose/send.go505
1 files changed, 304 insertions, 201 deletions
diff --git a/commands/compose/send.go b/commands/compose/send.go
index e7ef509..40d0ae3 100644
--- a/commands/compose/send.go
+++ b/commands/compose/send.go
@@ -1,6 +1,7 @@
 package compose
 
 import (
+	"bytes"
 	"crypto/tls"
 	"fmt"
 	"io"
@@ -42,6 +43,7 @@ func (Send) Execute(aerc *widgets.Aerc, args []string) error {
 		return errors.New("Usage: send")
 	}
 	composer, _ := aerc.SelectedTab().(*widgets.Composer)
+	tabName := aerc.TabNames()[aerc.SelectedTabIndex()]
 	config := composer.Config()
 
 	if config.Outgoing == "" {
@@ -49,28 +51,6 @@ func (Send) Execute(aerc *widgets.Aerc, args []string) error {
 			"No outgoing mail transport configured for this account")
 	}
 
-	aerc.Logger().Println("Sending mail")
-
-	uri, err := url.Parse(config.Outgoing)
-	if err != nil {
-		return errors.Wrap(err, "url.Parse(outgoing)")
-	}
-	var (
-		scheme string
-		auth   string = "plain"
-	)
-	if uri.Scheme != "" {
-		parts := strings.Split(uri.Scheme, "+")
-		if len(parts) == 1 {
-			scheme = parts[0]
-		} else if len(parts) == 2 {
-			scheme = parts[0]
-			auth = parts[1]
-		} else {
-			return fmt.Errorf("Unknown transfer protocol %s", uri.Scheme)
-		}
-	}
-
 	header, err := composer.PrepareHeader()
 	if err != nil {
 		return errors.Wrap(err, "PrepareHeader")
@@ -83,15 +63,187 @@ func (Send) Execute(aerc *widgets.Aerc, args []string) error {
 	if config.From == "" {
 		return errors.New("No 'From' configured for this account")
 	}
+	// TODO: the user could conceivably want to use a different From and sender
 	from, err := mail.ParseAddress(config.From)
 	if err != nil {
 		return errors.Wrap(err, "ParseAddress(config.From)")
 	}
 
-	var (
-		saslClient sasl.Client
-		conn       *smtp.Client
-	)
+	uri, err := url.Parse(config.Outgoing)
+	if err != nil {
+		return errors.Wrap(err, "url.Parse(outgoing)")
+	}
+
+	scheme, auth, err := parseScheme(uri)
+	if err != nil {
+		return err
+	}
+	var starttls bool
+	if starttls_, ok := config.Params["smtp-starttls"]; ok {
+		starttls = starttls_ == "yes"
+	}
+	ctx := sendCtx{
+		uri:      uri,
+		scheme:   scheme,
+		auth:     auth,
+		starttls: starttls,
+		from:     from,
+		rcpts:    rcpts,
+	}
+
+	var sender io.WriteCloser
+	switch ctx.scheme {
+	case "smtp":
+		fallthrough
+	case "smtps":
+		sender, err = newSmtpSender(ctx)
+	case "":
+		sender, err = newSendmailSender(ctx)
+	default:
+		sender, err = nil, fmt.Errorf("unsupported scheme %v", ctx.scheme)
+	}
+	if err != nil {
+		return errors.Wrap(err, "send:")
+	}
+
+	// if we copy via the worker we need to know the count
+	counter := datacounter.NewWriterCounter(sender)
+	var writer io.Writer = counter
+	writer = counter
+
+	var copyBuf bytes.Buffer
+	if config.CopyTo != "" {
+		writer = io.MultiWriter(writer, &copyBuf)
+	}
+
+	aerc.RemoveTab(composer)
+	aerc.PushStatus("Sending...", 10*time.Second)
+
+	ch := make(chan error)
+	go func() {
+		err := composer.WriteMessage(header, writer)
+		if err != nil {
+			ch <- err
+			return
+		}
+		ch <- sender.Close()
+	}()
+
+	// we don't want to block the UI thread while we are sending
+	go func() {
+		err = <-ch
+		if err != nil {
+			aerc.PushError(err.Error())
+			aerc.NewTab(composer, tabName)
+			return
+		}
+		if config.CopyTo != "" {
+			aerc.PushStatus("Copying to "+config.CopyTo, 10*time.Second)
+			errCh := copyToSent(composer.Worker(), config.CopyTo,
+				int(counter.Count()), &copyBuf)
+			err = <-errCh
+			if err != nil {
+				errmsg := fmt.Sprintf(
+					"message sent, but copying to %v failed: %v",
+					config.CopyTo, err.Error())
+				aerc.PushError(errmsg)
+				composer.SetSent()
+				composer.Close()
+				return
+			}
+		}
+		aerc.PushStatus("Message sent.", 10*time.Second)
+		composer.SetSent()
+		composer.Close()
+	}()
+	return nil
+}
+
+func listRecipients(h *mail.Header) ([]*mail.Address, error) {
+	var rcpts []*mail.Address
+	for _, key := range []string{"to", "cc", "bcc"} {
+		list, err := h.AddressList(key)
+		if err != nil {
+			return nil, err
+		}
+		rcpts = append(rcpts, list...)
+	}
+	return rcpts, nil
+}
+
+type sendCtx struct {
+	uri      *url.URL
+	scheme   string
+	auth     string
+	starttls bool
+	from     *mail.Address
+	rcpts    []*mail.Address
+}
+
+func newSendmailSender(ctx sendCtx) (io.WriteCloser, error) {
+	args, err := shlex.Split(ctx.uri.Path)
+	if err != nil {
+		return nil, err
+	}
+	if len(args) == 0 {
+		return nil, fmt.Errorf("no command specified")
+	}
+	bin := args[0]
+	rs := make([]string, len(ctx.rcpts), len(ctx.rcpts))
+	for i := range ctx.rcpts {
+		rs[i] = ctx.rcpts[i].Address
+	}
+	args = append(args[1:], rs...)
+	cmd := exec.Command(bin, args...)
+	s := &sendmailSender{cmd: cmd}
+	s.stdin, err = s.cmd.StdinPipe()
+	if err != nil {
+		return nil, errors.Wrap(err, "cmd.StdinPipe")
+	}
+	err = s.cmd.Start()
+	if err != nil {
+		return nil, errors.Wrap(err, "cmd.Start")
+	}
+	return s, nil
+}
+
+type sendmailSender struct {
+	cmd   *exec.Cmd
+	stdin io.WriteCloser
+}
+
+func (s *sendmailSender) Write(p []byte) (int, error) {
+	return s.stdin.Write(p)
+}
+
+func (s *sendmailSender) Close() error {
+	se := s.stdin.Close()
+	ce := s.cmd.Wait()
+	if se != nil {
+		return se
+	}
+	return ce
+}
+
+func parseScheme(uri *url.URL) (scheme string, auth string, err error) {
+	scheme = ""
+	auth = "plain"
+	if uri.Scheme != "" {
+		parts := strings.Split(uri.Scheme, "+")
+		if len(parts) == 1 {
+			scheme = parts[0]
+		} else if len(parts) == 2 {
+			scheme = parts[0]
+			auth = parts[1]
+		} else {
+			return "", "", fmt.Errorf("Unknown transfer protocol %s", uri.Scheme)
+		}
+	}
+	return scheme, auth, nil
+}
+
+func newSaslClient(auth string, uri *url.URL) (sasl.Client, error) {
+	var saslClient sasl.Client
 	switch auth {
 	case "":
 		fallthrough
@@ -105,7 +257,6 @@ func (Send) Execute(aerc *widgets.Aerc, args []string) error {
 		saslClient = sasl.NewPlainClient("", uri.User.Username(), password)
 	case "oauthbearer":
 		q := uri.Query()
-
 		oauth2 := &oauth2.Config{}
 		if q.Get("token_endpoint") != "" {
 			oauth2.ClientID = q.Get("client_id")
@@ -113,212 +264,164 @@ func (Send) Execute(aerc *widgets.Aerc, args []string) error {
 			oauth2.Scopes = []string{q.Get("scope")}
 			oauth2.Endpoint.TokenURL = q.Get("token_endpoint")
 		}
-
 		password, _ := uri.User.Password()
 		bearer := lib.OAuthBearer{
 			OAuth2:  oauth2,
 			Enabled: true,
 		}
 		if bearer.OAuth2.Endpoint.TokenURL == "" {
-			return fmt.Errorf("No 'TokenURL' configured for this account")
+			return nil, fmt.Errorf("No 'TokenURL' configured for this account")
 		}
 		token, err := bearer.ExchangeRefreshToken(password)
 		if err != nil {
-			return err
+			return nil, err
 		}
 		password = token.AccessToken
-
 		saslClient = sasl.NewOAuthBearerClient(&sasl.OAuthBearerOptions{
 			Username: uri.User.Username(),
 			Token:    password,
 		})
 	default:
-		return fmt.Errorf("Unsupported auth mechanism %s", auth)
+		return nil, fmt.Errorf("Unsupported auth mechanism %s", auth)
 	}
+	return saslClient, nil
+}
 
-	aerc.RemoveTab(composer)
+type smtpSender struct {
+	ctx  sendCtx
+	conn *smtp.Client
+	w    io.WriteCloser
+}
 
-	var starttls bool
-	if starttls_, ok := config.Params["smtp-starttls"]; ok {
-		starttls = starttls_ == "yes"
+func (s *smtpSender) Write(p []byte) (int, error) {
+	return s.w.Write(p)
+}
+
+func (s *smtpSender) Close() error {
+	we := s.w.Close()
+	ce := s.conn.Close()
+	if we != nil {
+		return we
 	}
+	return ce
+}
 
-	smtpAsync := func() (int, error) {
-		switch scheme {
-		case "smtp":
-			host := uri.Host
-			serverName := uri.Host
-			if !strings.ContainsRune(host, ':') {
-				host = host + ":587" // Default to submission port
-			} else {
-				serverName = host[:strings.IndexRune(host, ':')]
-			}
-			conn, err = smtp.Dial(host)
-			if err != nil {
-				return 0, errors.Wrap(err, "smtp.Dial")
-			}
-			defer conn.Close()
-			if sup, _ := conn.Extension("STARTTLS"); sup {
-				if !starttls {
-					err := errors.New("STARTTLS is supported by this server, " +
-						"but not set in accounts.conf. " +
-						"Add smtp-starttls=yes")
-					return 0, err
-				}
-				if err = conn.StartTLS(&tls.Config{
-					ServerName: serverName,
-				}); err != nil {
-					return 0, errors.Wrap(err, "StartTLS")
-				}
-			} else {
-				if starttls {
-					err := errors.New("STARTTLS requested, but not supported " +
-						"by this SMTP server. Is someone tampering with your " +
-						"connection?")
-					return 0, err
-				}
-			}
-		case "smtps":
-			host := uri.Host
-			serverName := uri.Host
-			if !strings.ContainsRune(host, ':') {
-				host = host + ":465" // Default to smtps port
-			} else {
-				serverName = host[:strings.IndexRune(host, ':')]
-			}
-			conn, err = smtp.DialTLS(host, &tls.Config{
-				ServerName: serverName,
-			})
-			if err != nil {
-				return 0, errors.Wrap(err, "smtp.DialTLS")
-			}
-			defer conn.Close()
-		}
+func newSmtpSender(ctx sendCtx) (io.WriteCloser, error) {
+	var (
+		err  error
+		conn *smtp.Client
+	)
+	switch ctx.scheme {
+	case "smtp":
+		conn, err = connectSmtp(ctx.starttls, ctx.uri.Host)
+	case "smtps":
+		conn, err = connectSmtps(ctx.uri.Host)
+	default:
+		return nil, fmt.Errorf("not an smtp protocol %s", ctx.scheme)
+	}
 
-		if saslClient != nil {
-			if err = conn.Auth(saslClient); err != nil {
-				return 0, errors.Wrap(err, "conn.Auth")
-			}
-		}
-		// TODO: the user could conceivably want to use a different From and sender
-		if err = conn.Mail(from.Address, nil); err != nil {
-			return 0, errors.Wrap(err, "conn.Mail")
-		}
-		aerc.Logger().Printf("rcpt to: %v", rcpts)
-		for _, rcpt := range rcpts {
-			if err = conn.Rcpt(rcpt); err != nil {
-				return 0, errors.Wrap(err, "conn.Rcpt")
-			}
+	saslclient, err := newSaslClient(ctx.auth, ctx.uri)
+	if err != nil {
+		conn.Close()
+		return nil, err
+	}
+	if saslclient != nil {
+		if err := conn.Auth(saslclient); err != nil {
+			conn.Close()
+			return nil, errors.Wrap(err, "conn.Auth")
 		}
-		wc, err := conn.Data()
-		if err != nil {
-			return 0, errors.Wrap(err, "conn.Data")
+	}
+	s := &smtpSender{
+		ctx:  ctx,
+		conn: conn,
+	}
+	if err := s.conn.Mail(s.ctx.from.Address, nil); err != nil {
+		conn.Close()
+		return nil, errors.Wrap(err, "conn.Mail")
+	}
+	for _, rcpt := range s.ctx.rcpts {
+		if err := s.conn.Rcpt(rcpt.Address); err != nil {
+			conn.Close()
+			return nil, errors.Wrap(err, "conn.Rcpt")
 		}
-		defer wc.Close()
-		ctr := datacounter.NewWriterCounter(wc)
-		composer.WriteMessage(header, ctr)
-		return int(ctr.Count()), nil
 	}
+	s.w, err = s.conn.Data()
+	if err != nil {
+		conn.Close()
+		return nil, errors.Wrap(err, "conn.Data")
+	}
+	return s.w, nil
+}
 
-	sendmailAsync := func() (int, error) {
-		args, err := shlex.Split(uri.Path)
-		if err != nil {
-			return 0, err
-		}
-		if len(args) == 0 {
-			return 0, fmt.Errorf("no command specified")
-		}
-		bin := args[0]
-		args = append(args[1:], rcpts...)
-		cmd := exec.Command(bin, args...)
-		wc, err := cmd.StdinPipe()
-		if err != nil {
-			return 0, errors.Wrap(err, "cmd.StdinPipe")
-		}
-		err = cmd.Start()
-		if err != nil {
-			return 0, errors.Wrap(err, "cmd.Start")
+func connectSmtp(starttls bool, host string) (*smtp.Client, error) {
+	serverName := host
+	if !strings.ContainsRune(host, ':') {
+		host = host + ":587" // Default to submission port
+	} else {
+		serverName = host[:strings.IndexRune(host, ':')]
+	}
+	conn, err := smtp.Dial(host)
+	if err != nil {
+		return nil, errors.Wrap(err, "smtp.Dial")
+	}
+	if sup, _ := conn.Extension("STARTTLS"); sup {
+		if !starttls {
+			err := errors.New("STARTTLS is supported by this server, " +
+				"but not set in accounts.conf. " +
+				"Add smtp-starttls=yes")
+			conn.Close()
+			return nil, err
 		}
-		ctr := datacounter.NewWriterCounter(wc)
-		composer.WriteMessage(header, ctr)
-		wc.Close() // force close to make sendmail send
-		err = cmd.Wait()
-		if err != nil {
-			return 0, errors.Wrap(err, "cmd.Wait")
+		if err = conn.StartTLS(&tls.Config{
+			ServerName: serverName,
+		}); err != nil {
+			conn.Close()
+			return nil, errors.Wrap(err, "StartTLS")
 		}
-		return int(ctr.Count()), nil
-	}
-
-	sendAsync := func() (int, error) {
-		fmt.Println(scheme)
-		switch scheme {
-		case "smtp":
-			fallthrough
-		case "smtps":
-			return smtpAsync()
-		case "":
-			return sendmailAsync()
+	} else {
+		if starttls {
+			err := errors.New("STARTTLS requested, but not supported " +
+				"by this SMTP server. Is someone tampering with your " +
+				"connection?")
+			conn.Close()
+			return nil, err
 		}
-		return 0, errors.New("Unknown scheme")
 	}
+	return conn, nil
+}
 
-	go func() {
-		aerc.PushStatus("Sending...", 10*time.Second)
-		nbytes, err := sendAsync()
-		if err != nil {
-			aerc.PushError(" " + err.Error())
-			return
-		}
-		if config.CopyTo != "" {
-			aerc.PushStatus("Copying to "+config.CopyTo, 10*time.Second)
-			worker := composer.Worker()
-			r, w := io.Pipe()
-			worker.PostAction(&types.AppendMessage{
-				Destination: config.CopyTo,
-				Flags:       []models.Flag{models.SeenFlag},
-				Date:        time.Now(),
-				Reader:      r,
-				Length:      nbytes,
-			}, func(msg types.WorkerMessage) {
-				switch msg := msg.(type) {
-				case *types.Done:
-					aerc.PushStatus("Message sent.", 10*time.Second)
-					r.Close()
-					composer.SetSent()
-					composer.Close()
-				case *types.Error:
-					aerc.PushError(" " + msg.Error.Error())
-					r.Close()
-					composer.Close()
-				}
-			})
-			header, err := composer.PrepareHeader()
-			if err != nil {
-				aerc.PushError(" " + err.Error())
-				w.Close()
-				return
-			}
-			composer.WriteMessage(header, w)
-			w.Close()
-		} else {
-			aerc.PushStatus("Message sent.", 10*time.Second)
-			composer.SetSent()
-			composer.Close()
-		}
-	}()
-	return nil
+func connectSmtps(host string) (*smtp.Client, error) {
+	serverName := host
+	if !strings.ContainsRune(host, ':') {
+		host = host + ":465" // Default to smtps port
+	} else {
+		serverName = host[:strings.IndexRune(host, ':')]
+	}
+	conn, err := smtp.DialTLS(host, &tls.Config{
+		ServerName: serverName,
+	})
+	if err != nil {
+		return nil, errors.Wrap(err, "smtp.DialTLS")
+	}
+	return conn, nil
 }
 
-func listRecipients(h *mail.Header) ([]string, error) {
-	var rcpts []string
-	for _, key := range []string{"to", "cc", "bcc"} {
-		list, err := h.AddressList(key)
-		if err != nil {
-			return nil, err
-		}
-		for _, addr := range list {
-			rcpts = append(rcpts, addr.Address)
+func copyToSent(worker *types.Worker, dest string,
+	n int, msg io.Reader) <-chan error {
+	errCh := make(chan error)
+	worker.PostAction(&types.AppendMessage{
+		Destination: dest,
+		Flags:       []models.Flag{models.SeenFlag},
+		Date:        time.Now(),
+		Reader:      msg,
+		Length:      n,
+	}, func(msg types.WorkerMessage) {
+		switch msg := msg.(type) {
+		case *types.Done:
+			errCh <- nil
+		case *types.Error:
+			errCh <- msg.Error
 		}
-	}
-	return rcpts, nil
+	})
+	return errCh
 }