Skip to content

Commit

Permalink
Update AWS SDK Go to v2
Browse files Browse the repository at this point in the history
Signed-off-by: Micah Hausler <[email protected]>
  • Loading branch information
micahhausler committed Feb 12, 2025
1 parent 1d552d5 commit 9ee7592
Show file tree
Hide file tree
Showing 19 changed files with 928 additions and 1,022 deletions.
108 changes: 62 additions & 46 deletions auth/auth.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Package responsible for returning an AWS SDK session with credentials
* Package responsible for returning an AWS SDK config with credentials
* given an AWS region, K8s namespace, and K8s service account.
*
* This package requries that the K8s service account be associated with an IAM
Expand All @@ -10,13 +10,13 @@ package auth
import (
"context"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/secrets-store-csi-driver-provider-aws/credential_provider"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
k8sv1 "k8s.io/client-go/kubernetes/typed/core/v1"
"k8s.io/klog/v2"
)
Expand All @@ -25,85 +25,101 @@ const (
ProviderName = "secrets-store-csi-driver-provider-aws"
)

// Auth is the main entry point to retrieve an AWS session. The caller
// Auth is the main entry point to retrieve an AWS config. The caller
// initializes a new Auth object with NewAuth passing the region, namespace, pod name,
// K8s service account and usePodIdentity flag (and request context). The caller can then obtain AWS
// sessions by calling GetAWSSession.
// config by calling GetAWSConfig.
type Auth struct {
region, nameSpace, svcAcc, podName, preferredAddressType string
usePodIdentity bool
k8sClient k8sv1.CoreV1Interface
stsClient stsiface.STSAPI
ctx context.Context
region string
namespace string
serviceAccount string
podName string
preferredAddressType string
usePodIdentity bool
k8sClient k8sv1.CoreV1Interface
stsClient stscreds.AssumeRoleWithWebIdentityAPIClient
}

// Factory method to create a new Auth object for an incomming mount request.
// NewAuth creates an Auth object for an incoming mount request.
func NewAuth(
ctx context.Context,
region, nameSpace, svcAcc, podName, preferredAddressType string,
region, namespace, serviceAccount, podName, preferredAddressType string,
usePodIdentity bool,
k8sClient k8sv1.CoreV1Interface,
) (auth *Auth, e error) {
var stsClient stsiface.STSAPI
var stsClient *sts.Client

if !usePodIdentity {
// Get an initial session to use for STS calls when using IRSA
sess, err := session.NewSession(aws.NewConfig().
WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint).
WithRegion(region),
// Get an initial config to use for STS calls when using IRSA
cfg, err := config.LoadDefaultConfig(context.Background(),
config.WithRegion(region),
config.WithDefaultsMode(aws.DefaultsModeStandard),
)
if err != nil {
return nil, err
}
stsClient = sts.New(sess)
stsClient = sts.NewFromConfig(cfg)
}

return &Auth{
region: region,
nameSpace: nameSpace,
svcAcc: svcAcc,
namespace: namespace,
serviceAccount: serviceAccount,
podName: podName,
preferredAddressType: preferredAddressType,
usePodIdentity: usePodIdentity,
k8sClient: k8sClient,
stsClient: stsClient,
ctx: ctx,
}, nil

}

// Get the AWS session credentials associated with a given pod's service account.
// Get the AWS config associated with a given pod's service account.
//
// The returned session is capable of automatically refreshing creds as needed
// The returned config is capable of automatically refreshing creds as needed
// by using a private TokenFetcher helper.
func (p Auth) GetAWSSession() (awsSession *session.Session, e error) {
var credProvider credential_provider.CredentialProvider
func (p Auth) GetAWSConfig(ctx context.Context) (aws.Config, error) {
var credProvider credential_provider.ConfigProvider

if p.usePodIdentity {
klog.Infof("Using Pod Identity for authentication in namespace: %s, service account: %s", p.nameSpace, p.svcAcc)
klog.Infof("Using Pod Identity for authentication in namespace: %s, service account: %s", p.namespace, p.serviceAccount)
var err error
credProvider, err = credential_provider.NewPodIdentityCredentialProvider(p.region, p.nameSpace, p.svcAcc, p.podName, p.preferredAddressType, p.k8sClient)
credProvider, err = credential_provider.NewPodIdentityCredentialProvider(p.region, p.namespace, p.serviceAccount, p.podName, p.preferredAddressType, p.k8sClient)
if err != nil {
return nil, err
return aws.Config{}, err
}
} else {
klog.Infof("Using IAM Roles for Service Accounts for authentication in namespace: %s, service account: %s", p.nameSpace, p.svcAcc)
credProvider = credential_provider.NewIRSACredentialProvider(p.stsClient, p.region, p.nameSpace, p.svcAcc, p.k8sClient, p.ctx)
klog.Infof("Using IAM Roles for Service Accounts for authentication in namespace: %s, service account: %s", p.namespace, p.serviceAccount)
credProvider = credential_provider.NewIRSACredentialProvider(p.stsClient, p.region, p.namespace, p.serviceAccount, p.k8sClient)
}

config, err := credProvider.GetAWSConfig()
cfg, err := credProvider.GetAWSConfig(ctx)
if err != nil {
return nil, err
return aws.Config{}, err
}

// Include the provider in the user agent string.
sess, err := session.NewSession(config)
if err != nil {
return nil, err
}
sess.Handlers.Build.PushFront(func(r *request.Request) {
request.AddToUserAgent(r, ProviderName)
// add the user agent to the config
cfg.APIOptions = append(cfg.APIOptions, func(stack *middleware.Stack) error {
return stack.Build.Add(&userAgentMiddleware{
providerName: ProviderName,
}, middleware.After)
})

return session.Must(sess, err), nil
return cfg, nil
}

type userAgentMiddleware struct {
providerName string
}

func (m *userAgentMiddleware) ID() string {
return "UserAgent"
}

func (m *userAgentMiddleware) HandleBuild(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) (
out middleware.BuildOutput, metadata middleware.Metadata, err error) {
req, ok := in.Request.(*smithyhttp.Request)
if !ok {
return next.HandleBuild(ctx, in)
}
req.Header.Set("User-Agent", m.providerName+" "+req.Header.Get("User-Agent"))
return next.HandleBuild(ctx, in)
}
83 changes: 47 additions & 36 deletions auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,65 +2,76 @@ package auth

import (
"context"
"fmt"
"strings"
"testing"

"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/aws/aws-sdk-go-v2/service/sts"
"k8s.io/client-go/kubernetes/fake"
)

// Mock STS client
type mockSTS struct {
stsiface.STSAPI
sts.Client
}

func (m *mockSTS) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) {
return nil, fmt.Errorf("fake error for serviceaccounst")
}

type sessionTest struct {
testName string
testPodIdentity bool
expError string
}

var sessionTests []sessionTest = []sessionTest{
{
testName: "IRSA",
testPodIdentity: false,
expError: "serviceaccounts", // IRSA path will fail at getting service account since using fake client
},
{
testName: "Pod Identity",
testPodIdentity: true,
expError: "failed to fetch token", // Pod Identity path will fail fetching token since using fake client
},
cfgError string
}

func TestGetAWSSession(t *testing.T) {
for _, tstData := range sessionTests {
t.Run(tstData.testName, func(t *testing.T) {
cases := []sessionTest{
{
testName: "IRSA",
testPodIdentity: false,
cfgError: "serviceaccounts", // IRSA path will fail at getting creds since its in the hot path of the config

},
{
testName: "Pod Identity",
testPodIdentity: true,
cfgError: "", // Pod Identity path succeeds since token is lazy loaded
},
}
for _, tt := range cases {
t.Run(tt.testName, func(t *testing.T) {

auth := &Auth{
region: "someRegion",
nameSpace: "someNamespace",
svcAcc: "someSvcAcc",
podName: "somepod",
usePodIdentity: tstData.testPodIdentity,
k8sClient: fake.NewSimpleClientset().CoreV1(),
stsClient: &mockSTS{},
ctx: context.Background(),
auth, err := NewAuth(
"someRegion",
"someNamespace",
"someSvcAcc",
"somepod",
"",
tt.testPodIdentity,
fake.NewSimpleClientset().CoreV1(),
)
if err != nil {
t.Fatalf("%s case: failed to create auth: %v", tt.testName, err)
}
auth.stsClient = &mockSTS{}
auth.k8sClient = fake.NewSimpleClientset().CoreV1()

sess, err := auth.GetAWSSession()
cfg, err := auth.GetAWSConfig(context.Background())

if len(tstData.expError) == 0 && err != nil {
t.Errorf("%s case: got unexpected auth error: %s", tstData.testName, err)
if len(tt.cfgError) == 0 && err != nil {
t.Errorf("%s case: got unexpected auth error: %s", tt.testName, err)
}
if len(tstData.expError) == 0 && sess == nil {
t.Errorf("%s case: got empty session", tstData.testName)
if len(tt.cfgError) == 0 && cfg.Credentials == nil {
t.Errorf("%s case: got empty credentials", tt.testName)
}
if len(tstData.expError) != 0 && err == nil {
t.Errorf("%s case: expected error but got none", tstData.testName)
if len(tt.cfgError) != 0 && err == nil {
t.Errorf("%s case: expected error but got none", tt.testName)
}
if len(tstData.expError) != 0 && !strings.Contains(err.Error(), tstData.expError) {
t.Errorf("%s case: expected error prefix '%s' but got '%s'", tstData.testName, tstData.expError, err.Error())
if len(tt.cfgError) != 0 && err != nil {
if !strings.Contains(err.Error(), tt.cfgError) {
t.Errorf("%s case: expected error prefix '%s' but got '%s'", tt.testName, tt.cfgError, err.Error())
}
}
})
}
Expand Down
17 changes: 6 additions & 11 deletions credential_provider/credential_provider.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
package credential_provider

import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"context"

"github.com/aws/aws-sdk-go-v2/aws"
)

// CredentialProvider interface defines methods for obtaining AWS credentials configuration
type CredentialProvider interface {
// ConfigProvider interface defines methods for obtaining AWS credentials configuration
type ConfigProvider interface {
// GetAWSConfig returns an AWS configuration containing credentials obtained from the provider
GetAWSConfig() (*aws.Config, error)
}

// authTokenFetcher interface defines methods for fetching a token given a K8s namespace and service account.
// It matches stscreds.TokenFetcher interface.
type authTokenFetcher interface {
FetchToken(ctx credentials.Context) ([]byte, error)
GetAWSConfig(ctx context.Context) (aws.Config, error)
}
32 changes: 24 additions & 8 deletions credential_provider/credential_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"context"
"fmt"

"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/aws-sdk-go-v2/service/sts/types"
authv1 "k8s.io/api/authentication/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
k8sv1 "k8s.io/client-go/kubernetes/typed/core/v1"
Expand All @@ -18,23 +20,37 @@ const (

// Mock STS client
type mockSTS struct {
stsiface.STSAPI
sts.Client
}

func (m *mockSTS) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) {
return &sts.AssumeRoleWithWebIdentityOutput{
Credentials: &types.Credentials{
AccessKeyId: aws.String("TEST_ACCESS_KEY"),
SecretAccessKey: aws.String("TEST_SECRET"),
SessionToken: aws.String("TEST_TOKEN"),
},
}, nil
}

// Mock K8s client for creating tokens
type mockK8sV1 struct {
k8sv1.CoreV1Interface
k8CTOneShotError bool
k8sv1.CoreV1Interface // satisfy the interface
fake k8sv1.CoreV1Interface // plumb down to the ServiceAccounts method
k8CTOneShotError bool
}

func (m *mockK8sV1) ServiceAccounts(namespace string) k8sv1.ServiceAccountInterface {
return &mockK8sV1SA{v1mock: m}
return &mockK8sV1SA{
m.fake.ServiceAccounts(namespace),
m.k8CTOneShotError,
}
}

// Mock the K8s service account client
type mockK8sV1SA struct {
k8sv1.ServiceAccountInterface
v1mock *mockK8sV1
oneShotGetTokenError bool
}

func (ma *mockK8sV1SA) CreateToken(
Expand All @@ -44,8 +60,8 @@ func (ma *mockK8sV1SA) CreateToken(
opts metav1.CreateOptions,
) (*authv1.TokenRequest, error) {

if ma.v1mock.k8CTOneShotError {
ma.v1mock.k8CTOneShotError = false // Reset so other tests don't fail
if ma.oneShotGetTokenError {
ma.oneShotGetTokenError = false // Reset so other tests don't fail
return nil, fmt.Errorf("Fake create token error")
}

Expand Down
Loading

0 comments on commit 9ee7592

Please sign in to comment.