Skip to content

Commit 8a47c8d

Browse files
feat(GODT-2567): Simulate Answered/Forwarded behavior in GPA server
1 parent c9bc6f7 commit 8a47c8d

File tree

4 files changed

+137
-21
lines changed

4 files changed

+137
-21
lines changed

server/backend/api.go

+22-2
Original file line numberDiff line numberDiff line change
@@ -636,19 +636,21 @@ func (b *Backend) DeleteMessage(userID, messageID string) error {
636636
})
637637
}
638638

639-
func (b *Backend) CreateDraft(userID, addrID string, draft proton.DraftTemplate, parentID string) (proton.Message, error) {
639+
func (b *Backend) CreateDraft(userID, addrID string, draft proton.DraftTemplate, parentID string, action proton.CreateDraftAction) (proton.Message, error) {
640640
return withAcc(b, userID, func(acc *account) (proton.Message, error) {
641641
return withMessages(b, func(messages map[string]*message) (proton.Message, error) {
642642
return withLabels(b, func(labels map[string]*label) (proton.Message, error) {
643643
// Convert the parentID into externalRef.\
644644
var parentRef string
645+
var internalParentID string
645646
if parentID != "" {
646647
parentMsg, ok := messages[parentID]
647648
if ok {
648649
parentRef = "<" + parentMsg.externalID + ">"
650+
internalParentID = parentID
649651
}
650652
}
651-
msg := newMessageFromTemplate(addrID, draft, parentRef)
653+
msg := newMessageFromTemplate(addrID, draft, parentRef, internalParentID, action)
652654
// Drafts automatically get the sysLabel "Drafts".
653655
msg.addLabel(proton.DraftsLabel, labels)
654656

@@ -712,6 +714,24 @@ func (b *Backend) SendMessage(userID, messageID string, packages []*proton.Messa
712714
msg.flags |= proton.MessageFlagSent
713715
msg.addLabel(proton.SentLabel, labels)
714716

717+
if parent, ok := messages[msg.internalParentID]; ok {
718+
switch msg.draftAction {
719+
case proton.ReplyAction:
720+
parent.flags |= proton.MessageFlagReplied
721+
case proton.ReplyAllAction:
722+
parent.flags |= proton.MessageFlagRepliedAll
723+
case proton.ForwardAction:
724+
parent.flags |= proton.MessageFlagForwarded
725+
}
726+
727+
updateID, err := b.newUpdate(&messageUpdated{messageID: msg.internalParentID})
728+
if err != nil {
729+
return proton.Message{}, err
730+
}
731+
732+
acc.updateIDs = append(acc.updateIDs, updateID)
733+
}
734+
715735
updateID, err := b.newUpdate(&messageUpdated{messageID: messageID})
716736
if err != nil {
717737
return proton.Message{}, err

server/backend/message.go

+29-15
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@ import (
1313
)
1414

1515
type message struct {
16-
messageID string
17-
externalID string
18-
addrID string
19-
labelIDs []string
20-
attIDs []string
21-
inReplyTo string
16+
messageID string
17+
externalID string
18+
addrID string
19+
labelIDs []string
20+
attIDs []string
21+
inReplyTo string
22+
internalParentID string
2223

2324
// sysLabel is the system label for the message.
2425
// If nil, the message's flags are used to determine the system label (inbox, sent, drafts).
@@ -34,6 +35,8 @@ type message struct {
3435
replytos []*mail.Address
3536
date time.Time
3637

38+
draftAction proton.CreateDraftAction
39+
3740
armBody string
3841
mimeType rfc822.MIMEType
3942

@@ -92,13 +95,20 @@ func newMessageFromSent(addrID, armBody string, msg *message) *message {
9295
}
9396
}
9497

95-
func newMessageFromTemplate(addrID string, template proton.DraftTemplate, parentRef string) *message {
98+
func newMessageFromTemplate(
99+
addrID string,
100+
template proton.DraftTemplate,
101+
parentRef string,
102+
internalParentID string,
103+
action proton.CreateDraftAction,
104+
) *message {
96105
return &message{
97-
messageID: uuid.NewString(),
98-
externalID: template.ExternalID,
99-
addrID: addrID,
100-
sysLabel: pointer(""),
101-
inReplyTo: parentRef,
106+
messageID: uuid.NewString(),
107+
externalID: template.ExternalID,
108+
addrID: addrID,
109+
sysLabel: pointer(""),
110+
inReplyTo: parentRef,
111+
internalParentID: internalParentID,
102112

103113
subject: template.Subject,
104114
sender: template.Sender,
@@ -107,6 +117,8 @@ func newMessageFromTemplate(addrID string, template proton.DraftTemplate, parent
107117
bccList: template.BCCList,
108118
unread: bool(template.Unread),
109119

120+
draftAction: action,
121+
110122
armBody: template.Body,
111123
mimeType: template.MIMEType,
112124
}
@@ -186,9 +198,11 @@ func (msg *message) toMetadata(attData map[string][]byte, att map[string]*attach
186198
ReplyTos: msg.replytos,
187199
Size: messageSize,
188200

189-
Flags: msg.flags,
190-
Unread: proton.Bool(msg.unread),
191-
IsForwarded: msg.flags&proton.MessageFlagForwarded != 0,
201+
Flags: msg.flags,
202+
Unread: proton.Bool(msg.unread),
203+
IsForwarded: msg.flags&proton.MessageFlagForwarded != 0,
204+
IsReplied: msg.flags&proton.MessageFlagReplied != 0,
205+
IsRepliedAll: msg.flags&proton.MessageFlagRepliedAll != 0,
192206

193207
NumAttachments: len(attData),
194208
}

server/messages.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func (s *Server) postMailMessages(c *gin.Context) {
102102
return
103103
}
104104

105-
message, err := s.b.CreateDraft(c.GetString("UserID"), addrID, req.Message, req.ParentID)
105+
message, err := s.b.CreateDraft(c.GetString("UserID"), addrID, req.Message, req.ParentID, req.Action)
106106
if err != nil {
107107
c.AbortWithStatus(http.StatusUnprocessableEntity)
108108
return

server/server_test.go

+85-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"encoding/json"
77
"errors"
88
"fmt"
9-
"github.com/ProtonMail/go-proton-api/server/backend"
109
"net/http"
1110
"net/mail"
1211
"net/url"
@@ -17,14 +16,14 @@ import (
1716
"testing"
1817
"time"
1918

20-
"github.com/bradenaw/juniper/parallel"
21-
2219
"github.com/Masterminds/semver/v3"
2320
"github.com/ProtonMail/gluon/async"
2421
"github.com/ProtonMail/gluon/rfc822"
2522
"github.com/ProtonMail/go-proton-api"
23+
"github.com/ProtonMail/go-proton-api/server/backend"
2624
"github.com/ProtonMail/gopenpgp/v2/crypto"
2725
"github.com/bradenaw/juniper/iterator"
26+
"github.com/bradenaw/juniper/parallel"
2827
"github.com/bradenaw/juniper/stream"
2928
"github.com/bradenaw/juniper/xslices"
3029
"github.com/google/uuid"
@@ -2232,6 +2231,89 @@ func TestServer_GetMessageGroupCount(t *testing.T) {
22322231
})
22332232
}
22342233

2234+
func TestServer_TestDraftActions(t *testing.T) {
2235+
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
2236+
withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
2237+
ctx, cancel := context.WithCancel(ctx)
2238+
defer cancel()
2239+
2240+
user, err := c.GetUser(ctx)
2241+
require.NoError(t, err)
2242+
2243+
addr, err := c.GetAddresses(ctx)
2244+
require.NoError(t, err)
2245+
2246+
salt, err := c.GetSalts(ctx)
2247+
require.NoError(t, err)
2248+
2249+
pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
2250+
require.NoError(t, err)
2251+
2252+
_, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{})
2253+
require.NoError(t, err)
2254+
2255+
type testData struct {
2256+
action proton.CreateDraftAction
2257+
flag proton.MessageFlag
2258+
}
2259+
2260+
tests := []testData{
2261+
{
2262+
action: proton.ReplyAction,
2263+
flag: proton.MessageFlagReplied,
2264+
},
2265+
{
2266+
action: proton.ReplyAllAction,
2267+
flag: proton.MessageFlagRepliedAll,
2268+
},
2269+
{
2270+
action: proton.ForwardAction,
2271+
flag: proton.MessageFlagForwarded,
2272+
},
2273+
}
2274+
2275+
importedMessages := importMessages(ctx, t, c, addr[0].ID, addrKRs[addr[0].ID], []string{}, 0, len(tests))
2276+
2277+
for i := 0; i < len(tests); i++ {
2278+
importedMessageID := importedMessages[i].MessageID
2279+
2280+
msg, err := c.GetMessage(ctx, importedMessageID)
2281+
require.NoError(t, err)
2282+
2283+
{
2284+
kr := addrKRs[addr[0].ID]
2285+
msg, err := c.CreateDraft(ctx, kr, proton.CreateDraftReq{
2286+
Message: proton.DraftTemplate{
2287+
Subject: "Foo",
2288+
Sender: &mail.Address{Address: addr[0].Email},
2289+
ToList: []*mail.Address{{Address: "foo@bar"}},
2290+
CCList: nil,
2291+
BCCList: nil,
2292+
},
2293+
AttachmentKeyPackets: nil,
2294+
ParentID: msg.ID,
2295+
Action: tests[i].action,
2296+
})
2297+
2298+
require.NoError(t, err)
2299+
2300+
var sreq proton.SendDraftReq
2301+
2302+
require.NoError(t, sreq.AddTextPackage(kr, "Hello", "text/plain", map[string]proton.SendPreferences{}, map[string]*crypto.SessionKey{}))
2303+
2304+
_, err = c.SendDraft(ctx, msg.ID, sreq)
2305+
require.NoError(t, err)
2306+
2307+
msg, err = c.GetMessage(ctx, importedMessageID)
2308+
require.NoError(t, err)
2309+
require.True(t, msg.Flags&tests[i].flag != 0)
2310+
}
2311+
}
2312+
2313+
})
2314+
})
2315+
}
2316+
22352317
func withServer(t *testing.T, fn func(ctx context.Context, s *Server, m *proton.Manager), opts ...Option) {
22362318
ctx, cancel := context.WithCancel(context.Background())
22372319
defer cancel()

0 commit comments

Comments
 (0)