Skip to content

Commit 1a5bc81

Browse files
committed
update Credentials struct
1 parent 795b157 commit 1a5bc81

File tree

7 files changed

+43
-37
lines changed

7 files changed

+43
-37
lines changed

internal/aws/types.go

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,29 @@ package aws
1212

1313
import (
1414
"io"
15+
"time"
1516
)
1617

1718
// Credentials represents AWS credentials.
1819
type Credentials struct {
19-
AccessKeyID string
20-
SecretAccessKey string
21-
SessionToken string
22-
ExpirationCallback func() bool
20+
AccessKeyID string
21+
SecretAccessKey string
22+
SessionToken string
23+
Source string
24+
CanExpire bool
25+
Expires time.Time
26+
AccountID string
27+
}
28+
29+
func (v Credentials) Expired() bool {
30+
if v.CanExpire {
31+
// Calling Round(0) on the current time will truncate the monotonic
32+
// reading only. Ensures credential expiry time is always based on
33+
// reported wall-clock time.
34+
return !v.Expires.After(time.Now().Round(0))
35+
}
36+
37+
return false
2338
}
2439

2540
// ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Allows the

internal/credproviders/aws_provider.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ func (a *AwsProvider) Retrieve(ctx context.Context) (credentials.Value, error) {
4040

4141
// IsExpired returns true if the credentials have not been retrieved.
4242
func (a *AwsProvider) IsExpired() bool {
43-
if a.credentials == nil || a.credentials.ExpirationCallback == nil {
43+
if a.credentials == nil {
4444
return true
4545
}
46-
return a.credentials.ExpirationCallback()
46+
return a.credentials.Expired()
4747
}

mongo/client.go

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -600,19 +600,14 @@ func (c *Client) newMongoCrypt(opts *options.AutoEncryptionOptions) (*mongocrypt
600600

601601
providers := make(map[string]credentials.Provider)
602602
for k, fn := range opts.CredentialProviders {
603-
if k == "aws" && fn != nil {
603+
if provider := fn; k == "aws" && provider != nil {
604604
providers[k] = &credproviders.AwsProvider{
605605
Provider: func(ctx context.Context) (aws.Credentials, error) {
606-
var creds aws.Credentials
607-
c, err := fn(ctx)
606+
c, err := provider(ctx)
608607
if err != nil {
609-
return creds, err
608+
return aws.Credentials{}, err
610609
}
611-
creds.AccessKeyID = c.AccessKeyID
612-
creds.SecretAccessKey = c.SecretAccessKey
613-
creds.SessionToken = c.SessionToken
614-
creds.ExpirationCallback = c.ExpirationCallback
615-
return creds, nil
610+
return aws.Credentials(c), nil
616611
},
617612
}
618613
}

mongo/client_encryption.go

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,19 +58,14 @@ func NewClientEncryption(keyVaultClient *Client, opts ...options.Lister[options.
5858

5959
providers := make(map[string]credentials.Provider)
6060
for k, fn := range cea.CredentialProviders {
61-
if k == "aws" && fn != nil {
61+
if provider := fn; k == "aws" && provider != nil {
6262
providers[k] = &credproviders.AwsProvider{
6363
Provider: func(ctx context.Context) (aws.Credentials, error) {
64-
var creds aws.Credentials
65-
c, err := fn(ctx)
64+
c, err := provider(ctx)
6665
if err != nil {
67-
return creds, err
66+
return aws.Credentials{}, err
6867
}
69-
creds.AccessKeyID = c.AccessKeyID
70-
creds.SecretAccessKey = c.SecretAccessKey
71-
creds.SessionToken = c.SessionToken
72-
creds.ExpirationCallback = c.ExpirationCallback
73-
return creds, nil
68+
return aws.Credentials(c), nil
7469
},
7570
}
7671
}

mongo/client_examples_test.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -363,10 +363,9 @@ func ExampleConnect_aWS() {
363363
AwsCredentialsProvider: func(_ context.Context) (
364364
options.Credentials, error) {
365365
return options.Credentials{
366-
AccessKeyID: accessKeyID,
367-
SecretAccessKey: secretAccessKey,
368-
SessionToken: sessionToken,
369-
ExpirationCallback: func() bool { return false },
366+
AccessKeyID: accessKeyID,
367+
SecretAccessKey: secretAccessKey,
368+
SessionToken: sessionToken,
370369
}, nil
371370
},
372371
}

mongo/options/clientoptions.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import (
2525
"github.com/youmark/pkcs8"
2626
"go.mongodb.org/mongo-driver/v2/bson"
2727
"go.mongodb.org/mongo-driver/v2/event"
28-
"go.mongodb.org/mongo-driver/v2/internal/aws"
2928
"go.mongodb.org/mongo-driver/v2/internal/httputil"
3029
"go.mongodb.org/mongo-driver/v2/internal/optionsutil"
3130
"go.mongodb.org/mongo-driver/v2/mongo/readconcern"
@@ -117,7 +116,7 @@ type Credential struct {
117116
PasswordSet bool
118117
OIDCMachineCallback OIDCCallback
119118
OIDCHumanCallback OIDCCallback
120-
AwsCredentialsProvider func(context.Context) (Credentials, error)
119+
AwsCredentialsProvider CredentialsProvider
121120
}
122121

123122
// OIDCCallback is the type for both Human and Machine Callback flows.
@@ -150,7 +149,15 @@ type IDPInfo struct {
150149
type CredentialsProvider func(context.Context) (Credentials, error)
151150

152151
// Credentials represents AWS credentials.
153-
type Credentials aws.Credentials
152+
type Credentials struct {
153+
AccessKeyID string
154+
SecretAccessKey string
155+
SessionToken string
156+
Source string
157+
CanExpire bool
158+
Expires time.Time
159+
AccountID string
160+
}
154161

155162
// BSONOptions are optional BSON marshaling and unmarshaling behaviors.
156163
type BSONOptions struct {

x/mongo/driver/topology/topology_options.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,7 @@ func ConvertCreds(cred *options.Credential) *driver.Cred {
117117
if cred.AwsCredentialsProvider != nil {
118118
awsCredentialsProvider = func(ctx context.Context) (aws.Credentials, error) {
119119
creds, err := cred.AwsCredentialsProvider(ctx)
120-
return aws.Credentials{
121-
AccessKeyID: creds.AccessKeyID,
122-
SecretAccessKey: creds.SecretAccessKey,
123-
SessionToken: creds.SessionToken,
124-
ExpirationCallback: creds.ExpirationCallback,
125-
}, err
120+
return aws.Credentials(creds), err
126121
}
127122
}
128123

0 commit comments

Comments
 (0)