diff --git a/config/config.go b/config/config.go index 59cb706..7dca277 100644 --- a/config/config.go +++ b/config/config.go @@ -1,10 +1,12 @@ package config import ( + "crypto/tls" "encoding/json" "flag" "io/ioutil" "log" + "strings" "github.com/ian-kent/envconf" "github.com/mailhog/MailHog-Server/monkey" @@ -49,6 +51,8 @@ type Config struct { OutgoingSMTPFile string OutgoingSMTP map[string]*OutgoingSMTP WebPath string + CertsPaths string + TLSConfig *tls.Config } // OutgoingSMTP is an outgoing SMTP server config @@ -112,6 +116,32 @@ func Configure() *Config { cfg.OutgoingSMTP = o } + if cfg.CertsPaths != "" { + pairCandidates := strings.Split(cfg.CertsPaths, ";") + + if len(pairCandidates) > 0 { + certificates := make([]tls.Certificate, len(pairCandidates)) + for i, pairCandidate := range pairCandidates { + pair := strings.Split(pairCandidate, ",") + + if len(pair) != 2 { + log.Fatalf("Certificate path pair %d must be in form certPath,keyPath", i) + } + + cert, err := tls.LoadX509KeyPair(pair[0], pair[1]) + if err != nil { + log.Fatal(err) + } + + certificates[i] = cert + } + + cfg.TLSConfig = &tls.Config{ + Certificates: certificates, + } + } + } + return cfg } @@ -128,5 +158,6 @@ func RegisterFlags() { flag.StringVar(&cfg.MaildirPath, "maildir-path", envconf.FromEnvP("MH_MAILDIR_PATH", "").(string), "Maildir path (if storage type is 'maildir')") flag.BoolVar(&cfg.InviteJim, "invite-jim", envconf.FromEnvP("MH_INVITE_JIM", false).(bool), "Decide whether to invite Jim (beware, he causes trouble)") flag.StringVar(&cfg.OutgoingSMTPFile, "outgoing-smtp", envconf.FromEnvP("MH_OUTGOING_SMTP", "").(string), "JSON file containing outgoing SMTP servers") + flag.StringVar(&cfg.CertsPaths, "certs-paths", envconf.FromEnvP("MH_CERTS_PATHS", "").(string), "A comma separated list of tls certificates, in schema cert1Path,key1Path;cert1Path,key2Path ... etc") Jim.RegisterFlags() } diff --git a/smtp/session.go b/smtp/session.go index 10a220d..79ffc3d 100644 --- a/smtp/session.go +++ b/smtp/session.go @@ -31,7 +31,7 @@ type Session struct { } // Accept starts a new SMTP session using io.ReadWriteCloser -func Accept(remoteAddress string, conn io.ReadWriteCloser, storage storage.Storage, messageChan chan *data.Message, hostname string, monkey monkey.ChaosMonkey) { +func Accept(remoteAddress string, conn io.ReadWriteCloser, tlsUpgrader func() io.ReadWriteCloser, storage storage.Storage, messageChan chan *data.Message, hostname string, monkey monkey.ChaosMonkey) { defer conn.Close() proto := smtp.NewProtocol() @@ -56,10 +56,30 @@ func Accept(remoteAddress string, conn io.ReadWriteCloser, storage storage.Stora proto.ValidateAuthenticationHandler = session.validateAuthentication proto.GetAuthenticationMechanismsHandler = func() []string { return []string{"PLAIN"} } + if tlsUpgrader != nil { + proto.TLSHandler = func(done func(ok bool)) (errorReply *smtp.Reply, callback func(), ok bool) { + done(true) + return nil, func() { + newCon := tlsUpgrader() + + session.reader = io.Reader(newCon) + session.writer = io.Writer(newCon) + if monkey != nil { + linkSpeed := monkey.LinkSpeed() + if linkSpeed != nil { + link = linkio.NewLink(*linkSpeed * linkio.BytePerSecond) + session.reader = link.NewLinkReader(io.Reader(newCon)) + session.writer = link.NewLinkWriter(io.Writer(newCon)) + } + } + }, true + } + } + session.logf("Starting session") session.Write(proto.Start()) for session.Read() == true { - if monkey != nil && monkey.Disconnect != nil && monkey.Disconnect() { + if monkey != nil && monkey.Disconnect() { session.conn.Close() break } @@ -160,4 +180,8 @@ func (c *Session) Write(reply *smtp.Reply) { c.logf("Sent %d bytes: '%s'", len(l), logText) c.writer.Write([]byte(l)) } + + if reply.Done != nil { + reply.Done() + } } diff --git a/smtp/session_test.go b/smtp/session_test.go index 27f3cef..4dafc64 100644 --- a/smtp/session_test.go +++ b/smtp/session_test.go @@ -2,6 +2,7 @@ package smtp import ( "errors" + "io" "sync" "testing" @@ -40,7 +41,7 @@ func TestAccept(t *testing.T) { Convey("Accept should handle a connection", t, func() { frw := &fakeRw{} mChan := make(chan *data.Message) - Accept("1.1.1.1:11111", frw, storage.CreateInMemory(), mChan, "localhost", nil) + Accept("1.1.1.1:11111", frw, nil, storage.CreateInMemory(), mChan, "localhost", nil) }) } @@ -52,58 +53,134 @@ func TestSocketError(t *testing.T) { }, } mChan := make(chan *data.Message) - Accept("1.1.1.1:11111", frw, storage.CreateInMemory(), mChan, "localhost", nil) + Accept("1.1.1.1:11111", frw, nil, storage.CreateInMemory(), mChan, "localhost", nil) }) } func TestAcceptMessage(t *testing.T) { Convey("acceptMessage should be called", t, func() { - mbuf := "EHLO localhost\nMAIL FROM:\nRCPT TO:\nDATA\nHi.\r\n.\r\nQUIT\n" - var rbuf []byte - frw := &fakeRw{ - _read: func(p []byte) (n int, err error) { - if len(p) >= len(mbuf) { - ba := []byte(mbuf) - mbuf = "" - for i, b := range ba { - p[i] = b - } - return len(ba), nil - } + mbuf := "EHLO localhost\r\n" + + "MAIL FROM:\r\n" + + "RCPT TO:\r\n" + + "DATA\r\n" + + "Hi.\r\n" + + ".\r\n" + + "QUIT\n" + + frw, obuf := getBuffer(mbuf) + mChan := make(chan *data.Message) + var wg sync.WaitGroup + wg.Add(1) + handlerCalled := false + var storedMessage *data.Message + go func() { + handlerCalled = true + storedMessage = <-mChan + wg.Done() + }() + Accept("1.1.1.1:11111", frw, nil, storage.CreateInMemory(), mChan, "localhost", nil) + wg.Wait() - ba := []byte(mbuf[0:len(p)]) - mbuf = mbuf[len(p):] - for i, b := range ba { - p[i] = b - } - return len(ba), nil - }, - _write: func(p []byte) (n int, err error) { - rbuf = append(rbuf, p...) - return len(p), nil - }, - _close: func() error { - return nil - }, - } + So(handlerCalled, ShouldBeTrue) + + So(storedMessage, ShouldNotBeNil) + So(string(*obuf), ShouldEqual, + "220 localhost ESMTP MailHog\r\n"+ + "250-Hello localhost\r\n"+ + "250-PIPELINING\r\n"+ + "250 AUTH PLAIN\r\n"+ + "250 Sender test ok\r\n"+ + "250 Recipient test ok\r\n"+ + "354 End data with .\r\n"+ + "250 Ok: queued as "+storedMessage.ID+"\r\n", + ) + }) +} + +func TestAcceptTLSUpgrade(t *testing.T) { + Convey("acceptMessage should be called", t, func() { + mbuf1 := "STARTTLS\r\n" + mbuf2 := "EHLO localhost\r\n" + + "MAIL FROM:\r\n" + + "RCPT TO:\r\n" + + "DATA\r\n" + + "Hi.\r\n" + + ".\r\n" + + "QUIT\n" + + frw1, obuf1 := getBuffer(mbuf1) + frw2, obuf2 := getBuffer(mbuf2) mChan := make(chan *data.Message) var wg sync.WaitGroup wg.Add(1) handlerCalled := false + var storedMessage *data.Message go func() { handlerCalled = true - <-mChan - //FIXME breaks some tests (in drone.io) - //m := <-mChan - //So(m, ShouldNotBeNil) + storedMessage = <-mChan wg.Done() }() - Accept("1.1.1.1:11111", frw, storage.CreateInMemory(), mChan, "localhost", nil) + + tlsWasUpgraded := false + tlsUpgrade := func() io.ReadWriteCloser { + tlsWasUpgraded = true + return frw2 + } + + Accept("1.1.1.1:11111", frw1, tlsUpgrade, storage.CreateInMemory(), mChan, "localhost", nil) wg.Wait() + So(handlerCalled, ShouldBeTrue) + So(tlsWasUpgraded, ShouldBeTrue) + + So(storedMessage, ShouldNotBeNil) + So(string(*obuf1), ShouldEqual, + "220 localhost ESMTP MailHog\r\n"+ + "220 Ready to start TLS\r\n", + ) + So(string(*obuf2), ShouldEqual, + "250-Hello localhost\r\n"+ + "250-PIPELINING\r\n"+ + "250 AUTH PLAIN\r\n"+ + "250 Sender test ok\r\n"+ + "250 Recipient test ok\r\n"+ + "354 End data with .\r\n"+ + "250 Ok: queued as "+storedMessage.ID+"\r\n", + ) }) } +func getBuffer(input string) (io.ReadWriteCloser, *[]byte) { + var rbuf []byte + frw := &fakeRw{ + _read: func(p []byte) (n int, err error) { + if len(p) >= len(input) { + ba := []byte(input) + input = "" + for i, b := range ba { + p[i] = b + } + return len(ba), nil + } + + ba := []byte(input[0:len(p)]) + input = input[len(p):] + for i, b := range ba { + p[i] = b + } + return len(ba), nil + }, + _write: func(p []byte) (n int, err error) { + rbuf = append(rbuf, p...) + return len(p), nil + }, + _close: func() error { + return nil + }, + } + return frw, &rbuf +} + func TestValidateAuthentication(t *testing.T) { Convey("validateAuthentication is always successful", t, func() { c := &Session{} diff --git a/smtp/smtp.go b/smtp/smtp.go index 38a9b51..a045e80 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -1,6 +1,7 @@ package smtp import ( + "crypto/tls" "io" "log" "net" @@ -31,9 +32,17 @@ func Listen(cfg *config.Config, exitCh chan int) *net.TCPListener { } } + var tlsUpgrader func() io.ReadWriteCloser + if cfg.TLSConfig != nil { + tlsUpgrader = func() io.ReadWriteCloser { + return io.ReadWriteCloser(tls.Server(conn, cfg.TLSConfig)) + } + } + go Accept( conn.(*net.TCPConn).RemoteAddr().String(), io.ReadWriteCloser(conn), + tlsUpgrader, cfg.Storage, cfg.MessageChan, cfg.Hostname,