Skip to content

Commit 4fea597

Browse files
committed
update Credentials struct
1 parent 795b157 commit 4fea597

File tree

8 files changed

+43
-39
lines changed

8 files changed

+43
-39
lines changed

internal/aws/types.go

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,29 @@ package aws
1212

1313
import (
1414
"io"
15+
"time"
1516
)
1617

1718
// Credentials represents AWS credentials.
1819
type Credentials struct {
19-
AccessKeyID string
20-
SecretAccessKey string
21-
SessionToken string
22-
ExpirationCallback func() bool
20+
AccessKeyID string
21+
SecretAccessKey string
22+
SessionToken string
23+
Source string
24+
CanExpire bool
25+
Expires time.Time
26+
AccountID string
27+
}
28+
29+
func (v Credentials) Expired() bool {
30+
if v.CanExpire {
31+
// Calling Round(0) on the current time will truncate the monotonic
32+
// reading only. Ensures credential expiry time is always based on
33+
// reported wall-clock time.
34+
return !v.Expires.After(time.Now().Round(0))
35+
}
36+
37+
return false
2338
}
2439

2540
// ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Allows the

internal/credproviders/aws_provider.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ func (a *AwsProvider) Retrieve(ctx context.Context) (credentials.Value, error) {
4040

4141
// IsExpired returns true if the credentials have not been retrieved.
4242
func (a *AwsProvider) IsExpired() bool {
43-
if a.credentials == nil || a.credentials.ExpirationCallback == nil {
43+
if a.credentials == nil {
4444
return true
4545
}
46-
return a.credentials.ExpirationCallback()
46+
return a.credentials.Expired()
4747
}

internal/integration/client_side_encryption_prose_test.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3190,9 +3190,8 @@ func TestCustomAwsCredentialsProse(t *testing.T) {
31903190
"aws": func(_ context.Context) (options.Credentials, error) {
31913191
calledCount++
31923192
return options.Credentials{
3193-
AccessKeyID: awsAccessKeyID,
3194-
SecretAccessKey: awsSecretAccessKey,
3195-
ExpirationCallback: func() bool { return false },
3193+
AccessKeyID: awsAccessKeyID,
3194+
SecretAccessKey: awsSecretAccessKey,
31963195
}, nil
31973196
},
31983197
})
@@ -3248,9 +3247,8 @@ func TestCustomAwsCredentialsProse(t *testing.T) {
32483247
"aws": func(ctx context.Context) (options.Credentials, error) {
32493248
calledCount++
32503249
return options.Credentials{
3251-
AccessKeyID: awsAccessKeyID,
3252-
SecretAccessKey: awsSecretAccessKey,
3253-
ExpirationCallback: func() bool { return false },
3250+
AccessKeyID: awsAccessKeyID,
3251+
SecretAccessKey: awsSecretAccessKey,
32543252
}, nil
32553253
},
32563254
})

mongo/client.go

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -603,16 +603,11 @@ func (c *Client) newMongoCrypt(opts *options.AutoEncryptionOptions) (*mongocrypt
603603
if k == "aws" && fn != nil {
604604
providers[k] = &credproviders.AwsProvider{
605605
Provider: func(ctx context.Context) (aws.Credentials, error) {
606-
var creds aws.Credentials
607606
c, err := fn(ctx)
608607
if err != nil {
609-
return creds, err
608+
return aws.Credentials{}, err
610609
}
611-
creds.AccessKeyID = c.AccessKeyID
612-
creds.SecretAccessKey = c.SecretAccessKey
613-
creds.SessionToken = c.SessionToken
614-
creds.ExpirationCallback = c.ExpirationCallback
615-
return creds, nil
610+
return aws.Credentials(c), nil
616611
},
617612
}
618613
}

mongo/client_encryption.go

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,11 @@ func NewClientEncryption(keyVaultClient *Client, opts ...options.Lister[options.
6161
if k == "aws" && fn != nil {
6262
providers[k] = &credproviders.AwsProvider{
6363
Provider: func(ctx context.Context) (aws.Credentials, error) {
64-
var creds aws.Credentials
6564
c, err := fn(ctx)
6665
if err != nil {
67-
return creds, err
66+
return aws.Credentials{}, err
6867
}
69-
creds.AccessKeyID = c.AccessKeyID
70-
creds.SecretAccessKey = c.SecretAccessKey
71-
creds.SessionToken = c.SessionToken
72-
creds.ExpirationCallback = c.ExpirationCallback
73-
return creds, nil
68+
return aws.Credentials(c), nil
7469
},
7570
}
7671
}

mongo/client_examples_test.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -363,10 +363,9 @@ func ExampleConnect_aWS() {
363363
AwsCredentialsProvider: func(_ context.Context) (
364364
options.Credentials, error) {
365365
return options.Credentials{
366-
AccessKeyID: accessKeyID,
367-
SecretAccessKey: secretAccessKey,
368-
SessionToken: sessionToken,
369-
ExpirationCallback: func() bool { return false },
366+
AccessKeyID: accessKeyID,
367+
SecretAccessKey: secretAccessKey,
368+
SessionToken: sessionToken,
370369
}, nil
371370
},
372371
}

mongo/options/clientoptions.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import (
2525
"github.com/youmark/pkcs8"
2626
"go.mongodb.org/mongo-driver/v2/bson"
2727
"go.mongodb.org/mongo-driver/v2/event"
28-
"go.mongodb.org/mongo-driver/v2/internal/aws"
2928
"go.mongodb.org/mongo-driver/v2/internal/httputil"
3029
"go.mongodb.org/mongo-driver/v2/internal/optionsutil"
3130
"go.mongodb.org/mongo-driver/v2/mongo/readconcern"
@@ -117,7 +116,7 @@ type Credential struct {
117116
PasswordSet bool
118117
OIDCMachineCallback OIDCCallback
119118
OIDCHumanCallback OIDCCallback
120-
AwsCredentialsProvider func(context.Context) (Credentials, error)
119+
AwsCredentialsProvider CredentialsProvider
121120
}
122121

123122
// OIDCCallback is the type for both Human and Machine Callback flows.
@@ -150,7 +149,15 @@ type IDPInfo struct {
150149
type CredentialsProvider func(context.Context) (Credentials, error)
151150

152151
// Credentials represents AWS credentials.
153-
type Credentials aws.Credentials
152+
type Credentials struct {
153+
AccessKeyID string
154+
SecretAccessKey string
155+
SessionToken string
156+
Source string
157+
CanExpire bool
158+
Expires time.Time
159+
AccountID string
160+
}
154161

155162
// BSONOptions are optional BSON marshaling and unmarshaling behaviors.
156163
type BSONOptions struct {

x/mongo/driver/topology/topology_options.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,7 @@ func ConvertCreds(cred *options.Credential) *driver.Cred {
117117
if cred.AwsCredentialsProvider != nil {
118118
awsCredentialsProvider = func(ctx context.Context) (aws.Credentials, error) {
119119
creds, err := cred.AwsCredentialsProvider(ctx)
120-
return aws.Credentials{
121-
AccessKeyID: creds.AccessKeyID,
122-
SecretAccessKey: creds.SecretAccessKey,
123-
SessionToken: creds.SessionToken,
124-
ExpirationCallback: creds.ExpirationCallback,
125-
}, err
120+
return aws.Credentials(creds), err
126121
}
127122
}
128123

0 commit comments

Comments
 (0)