Skip to content

Commit 795b157

Browse files
committed
support custom AWS credential provider
1 parent f01f780 commit 795b157

25 files changed

+546
-96
lines changed

internal/aws/credentials/chain_provider.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
package credentials
1212

1313
import (
14+
"context"
15+
1416
"go.mongodb.org/mongo-driver/v2/internal/aws/awserr"
1517
)
1618

@@ -45,10 +47,10 @@ func NewChainCredentials(providers []Provider) *Credentials {
4547
//
4648
// If a provider is found it will be cached and any calls to IsExpired()
4749
// will return the expired state of the cached provider.
48-
func (c *ChainProvider) Retrieve() (Value, error) {
50+
func (c *ChainProvider) Retrieve(ctx context.Context) (Value, error) {
4951
var errs = make([]error, 0, len(c.Providers))
5052
for _, p := range c.Providers {
51-
creds, err := p.Retrieve()
53+
creds, err := p.Retrieve(ctx)
5254
if err == nil {
5355
c.curr = p
5456
return creds, nil

internal/aws/credentials/chain_provider_test.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
package credentials
1212

1313
import (
14+
"context"
1415
"reflect"
1516
"testing"
1617

@@ -23,7 +24,7 @@ type secondStubProvider struct {
2324
err error
2425
}
2526

26-
func (s *secondStubProvider) Retrieve() (Value, error) {
27+
func (s *secondStubProvider) Retrieve(_ context.Context) (Value, error) {
2728
s.expired = false
2829
s.creds.ProviderName = "secondStubProvider"
2930
return s.creds, s.err
@@ -54,7 +55,7 @@ func TestChainProviderWithNames(t *testing.T) {
5455
},
5556
}
5657

57-
creds, err := p.Retrieve()
58+
creds, err := p.Retrieve(context.Background())
5859
if err != nil {
5960
t.Errorf("Expect no error, got %v", err)
6061
}
@@ -90,7 +91,7 @@ func TestChainProviderGet(t *testing.T) {
9091
},
9192
}
9293

93-
creds, err := p.Retrieve()
94+
creds, err := p.Retrieve(context.Background())
9495
if err != nil {
9596
t.Errorf("Expect no error, got %v", err)
9697
}
@@ -113,10 +114,12 @@ func TestChainProviderIsExpired(t *testing.T) {
113114
},
114115
}
115116

117+
ctx := context.Background()
118+
116119
if !p.IsExpired() {
117120
t.Errorf("Expect expired to be true before any Retrieve")
118121
}
119-
_, err := p.Retrieve()
122+
_, err := p.Retrieve(ctx)
120123
if err != nil {
121124
t.Errorf("Expect no error, got %v", err)
122125
}
@@ -129,7 +132,7 @@ func TestChainProviderIsExpired(t *testing.T) {
129132
t.Errorf("Expect return of expired provider")
130133
}
131134

132-
_, err = p.Retrieve()
135+
_, err = p.Retrieve(ctx)
133136
if err != nil {
134137
t.Errorf("Expect no error, got %v", err)
135138
}
@@ -146,7 +149,7 @@ func TestChainProviderWithNoProvider(t *testing.T) {
146149
if !p.IsExpired() {
147150
t.Errorf("Expect expired with no providers")
148151
}
149-
_, err := p.Retrieve()
152+
_, err := p.Retrieve(context.Background())
150153
if err.Error() != "NoCredentialProviders: no valid providers in chain" {
151154
t.Errorf("Expect no providers error returned, got %v", err)
152155
}
@@ -167,7 +170,7 @@ func TestChainProviderWithNoValidProvider(t *testing.T) {
167170
if !p.IsExpired() {
168171
t.Errorf("Expect expired with no providers")
169172
}
170-
_, err := p.Retrieve()
173+
_, err := p.Retrieve(context.Background())
171174

172175
expectErr := awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs)
173176
if e, a := expectErr, err; !reflect.DeepEqual(e, a) {

internal/aws/credentials/credentials.go

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -52,20 +52,13 @@ func (v Value) HasKeys() bool {
5252
type Provider interface {
5353
// Retrieve returns nil if it successfully retrieved the value.
5454
// Error is returned if the value were not obtainable, or empty.
55-
Retrieve() (Value, error)
55+
Retrieve(context.Context) (Value, error)
5656

5757
// IsExpired returns if the credentials are no longer valid, and need
5858
// to be retrieved.
5959
IsExpired() bool
6060
}
6161

62-
// ProviderWithContext is a Provider that can retrieve credentials with a Context
63-
type ProviderWithContext interface {
64-
Provider
65-
66-
RetrieveWithContext(context.Context) (Value, error)
67-
}
68-
6962
// A Credentials provides concurrency safe retrieval of AWS credentials Value.
7063
//
7164
// A Credentials is also used to fetch Azure credentials Value.
@@ -143,13 +136,7 @@ func (c *Credentials) singleRetrieve(ctx context.Context) (interface{}, error) {
143136
return curCreds, nil
144137
}
145138

146-
var creds Value
147-
var err error
148-
if p, ok := c.provider.(ProviderWithContext); ok {
149-
creds, err = p.RetrieveWithContext(ctx)
150-
} else {
151-
creds, err = c.provider.Retrieve()
152-
}
139+
creds, err := c.provider.Retrieve(ctx)
153140
if err == nil {
154141
c.creds = creds
155142
}

internal/aws/credentials/credentials_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ type stubProvider struct {
3333
err error
3434
}
3535

36-
func (s *stubProvider) Retrieve() (Value, error) {
36+
func (s *stubProvider) Retrieve(_ context.Context) (Value, error) {
3737
s.retrievedCount++
3838
s.expired = false
3939
s.creds.ProviderName = "stubProvider"
@@ -133,7 +133,7 @@ func (e *MockProvider) IsExpired() bool {
133133
return e.expiration.Before(curTime())
134134
}
135135

136-
func (*MockProvider) Retrieve() (Value, error) {
136+
func (*MockProvider) Retrieve(_ context.Context) (Value, error) {
137137
return Value{}, nil
138138
}
139139

@@ -162,9 +162,9 @@ type stubProviderConcurrent struct {
162162
done chan struct{}
163163
}
164164

165-
func (s *stubProviderConcurrent) Retrieve() (Value, error) {
165+
func (s *stubProviderConcurrent) Retrieve(ctx context.Context) (Value, error) {
166166
<-s.done
167-
return s.stubProvider.Retrieve()
167+
return s.stubProvider.Retrieve(ctx)
168168
}
169169

170170
func TestCredentialsGetConcurrent(t *testing.T) {

internal/aws/types.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@ import (
1414
"io"
1515
)
1616

17+
// Credentials represents AWS credentials.
18+
type Credentials struct {
19+
AccessKeyID string
20+
SecretAccessKey string
21+
SessionToken string
22+
ExpirationCallback func() bool
23+
}
24+
1725
// ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Allows the
1826
// SDK to accept an io.Reader that is not also an io.Seeker for unsigned
1927
// streaming payload API operations.

internal/credproviders/assume_role_provider.go

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ func NewAssumeRoleProvider(httpClient *http.Client, expiryWindow time.Duration)
5757
}
5858
}
5959

60-
// RetrieveWithContext retrieves the keys from the AWS service.
61-
func (a *AssumeRoleProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
60+
// Retrieve retrieves the keys from the AWS service.
61+
func (a *AssumeRoleProvider) Retrieve(ctx context.Context) (credentials.Value, error) {
6262
const defaultHTTPTimeout = 10 * time.Second
6363

6464
v := credentials.Value{ProviderName: assumeRoleProviderName}
@@ -137,11 +137,6 @@ func (a *AssumeRoleProvider) RetrieveWithContext(ctx context.Context) (credentia
137137
return v, nil
138138
}
139139

140-
// Retrieve retrieves the keys from the AWS service.
141-
func (a *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
142-
return a.RetrieveWithContext(context.Background())
143-
}
144-
145140
// IsExpired returns true if the credentials are expired.
146141
func (a *AssumeRoleProvider) IsExpired() bool {
147142
return a.expiration.Before(time.Now())
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright (C) MongoDB, Inc. 2025-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package credproviders
8+
9+
import (
10+
"context"
11+
12+
"go.mongodb.org/mongo-driver/v2/internal/aws"
13+
"go.mongodb.org/mongo-driver/v2/internal/aws/credentials"
14+
)
15+
16+
const awsProviderName = "AwsProvider"
17+
18+
// AwsProvider retrieves credentials from the given AWS credentials provider.
19+
type AwsProvider struct {
20+
credentials *aws.Credentials
21+
Provider func(context.Context) (aws.Credentials, error)
22+
}
23+
24+
// Retrieve retrieves the keys from the given AWS credentials provider.
25+
func (a *AwsProvider) Retrieve(ctx context.Context) (credentials.Value, error) {
26+
var value credentials.Value
27+
if a.credentials == nil {
28+
creds, err := a.Provider(ctx)
29+
if err != nil {
30+
return value, err
31+
}
32+
a.credentials = &creds
33+
}
34+
value.AccessKeyID = a.credentials.AccessKeyID
35+
value.SecretAccessKey = a.credentials.SecretAccessKey
36+
value.SessionToken = a.credentials.SessionToken
37+
value.ProviderName = awsProviderName
38+
return value, nil
39+
}
40+
41+
// IsExpired returns true if the credentials have not been retrieved.
42+
func (a *AwsProvider) IsExpired() bool {
43+
if a.credentials == nil || a.credentials.ExpirationCallback == nil {
44+
return true
45+
}
46+
return a.credentials.ExpirationCallback()
47+
}

internal/credproviders/ec2_provider.go

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,8 @@ func (e *EC2Provider) getCredentials(ctx context.Context, token string, role str
146146
return v, ec2Resp.Expiration, nil
147147
}
148148

149-
// RetrieveWithContext retrieves the keys from the AWS service.
150-
func (e *EC2Provider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
149+
// Retrieve retrieves the keys from the AWS service.
150+
func (e *EC2Provider) Retrieve(ctx context.Context) (credentials.Value, error) {
151151
v := credentials.Value{ProviderName: ec2ProviderName}
152152

153153
token, err := e.getToken(ctx)
@@ -172,11 +172,6 @@ func (e *EC2Provider) RetrieveWithContext(ctx context.Context) (credentials.Valu
172172
return v, nil
173173
}
174174

175-
// Retrieve retrieves the keys from the AWS service.
176-
func (e *EC2Provider) Retrieve() (credentials.Value, error) {
177-
return e.RetrieveWithContext(context.Background())
178-
}
179-
180175
// IsExpired returns true if the credentials are expired.
181176
func (e *EC2Provider) IsExpired() bool {
182177
return e.expiration.Before(time.Now())

internal/credproviders/ecs_provider.go

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ func NewECSProvider(httpClient *http.Client, expiryWindow time.Duration) *ECSPro
4949
}
5050
}
5151

52-
// RetrieveWithContext retrieves the keys from the AWS service.
53-
func (e *ECSProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
52+
// Retrieve retrieves the keys from the AWS service.
53+
func (e *ECSProvider) Retrieve(ctx context.Context) (credentials.Value, error) {
5454
const defaultHTTPTimeout = 10 * time.Second
5555

5656
v := credentials.Value{ProviderName: ecsProviderName}
@@ -101,11 +101,6 @@ func (e *ECSProvider) RetrieveWithContext(ctx context.Context) (credentials.Valu
101101
return v, nil
102102
}
103103

104-
// Retrieve retrieves the keys from the AWS service.
105-
func (e *ECSProvider) Retrieve() (credentials.Value, error) {
106-
return e.RetrieveWithContext(context.Background())
107-
}
108-
109104
// IsExpired returns true if the credentials are expired.
110105
func (e *ECSProvider) IsExpired() bool {
111106
return e.expiration.Before(time.Now())

internal/credproviders/env_provider.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
package credproviders
88

99
import (
10+
"context"
1011
"os"
1112

1213
"go.mongodb.org/mongo-driver/v2/internal/aws/credentials"
@@ -46,7 +47,7 @@ func NewEnvProvider() *EnvProvider {
4647
}
4748

4849
// Retrieve retrieves the keys from the environment.
49-
func (e *EnvProvider) Retrieve() (credentials.Value, error) {
50+
func (e *EnvProvider) Retrieve(_ context.Context) (credentials.Value, error) {
5051
e.retrieved = false
5152

5253
v := credentials.Value{

0 commit comments

Comments
 (0)