Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Go SDK v2 #429

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 63 additions & 46 deletions auth/auth.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Package responsible for returning an AWS SDK session with credentials
* Package responsible for returning an AWS SDK config with credentials
* given an AWS region, K8s namespace, and K8s service account.
*
* This package requries that the K8s service account be associated with an IAM
Expand All @@ -9,13 +9,14 @@ package auth

import (
"context"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/secrets-store-csi-driver-provider-aws/credential_provider"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
k8sv1 "k8s.io/client-go/kubernetes/typed/core/v1"
"k8s.io/klog/v2"
)
Expand All @@ -24,85 +25,101 @@ const (
ProviderName = "secrets-store-csi-driver-provider-aws"
)

// Auth is the main entry point to retrieve an AWS session. The caller
// Auth is the main entry point to retrieve an AWS config. The caller
// initializes a new Auth object with NewAuth passing the region, namespace, pod name,
// K8s service account and usePodIdentity flag (and request context). The caller can then obtain AWS
// sessions by calling GetAWSSession.
// config by calling GetAWSConfig.
type Auth struct {
region, nameSpace, svcAcc, podName, preferredAddressType string
usePodIdentity bool
k8sClient k8sv1.CoreV1Interface
stsClient stsiface.STSAPI
ctx context.Context
region string
namespace string
serviceAccount string
podName string
preferredAddressType string
usePodIdentity bool
k8sClient k8sv1.CoreV1Interface
stsClient stscreds.AssumeRoleWithWebIdentityAPIClient
}

// Factory method to create a new Auth object for an incomming mount request.
// NewAuth creates an Auth object for an incoming mount request.
func NewAuth(
ctx context.Context,
region, nameSpace, svcAcc, podName, preferredAddressType string,
region, namespace, serviceAccount, podName, preferredAddressType string,
usePodIdentity bool,
k8sClient k8sv1.CoreV1Interface,
) (auth *Auth, e error) {
var stsClient stsiface.STSAPI
var stsClient *sts.Client

if !usePodIdentity {
// Get an initial session to use for STS calls when using IRSA
sess, err := session.NewSession(aws.NewConfig().
WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint).
WithRegion(region),
// Get an initial config to use for STS calls when using IRSA
cfg, err := config.LoadDefaultConfig(context.Background(),
config.WithRegion(region),
config.WithDefaultsMode(aws.DefaultsModeStandard),
)
if err != nil {
return nil, err
}
stsClient = sts.New(sess)
stsClient = sts.NewFromConfig(cfg)
}

return &Auth{
region: region,
nameSpace: nameSpace,
svcAcc: svcAcc,
namespace: namespace,
serviceAccount: serviceAccount,
podName: podName,
preferredAddressType: preferredAddressType,
usePodIdentity: usePodIdentity,
k8sClient: k8sClient,
stsClient: stsClient,
ctx: ctx,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Context is meant to be passed through function calls. To quote from the context package documentation

Do not store Contexts inside a struct type; instead, pass a Context explicitly to each function that needs it. This is discussed further in https://go.dev/blog/context-and-structs.

Rather than putting it in the struct, it can be synchronously passed in GetAWSConfig(context.Context) below.

}, nil

}

// Get the AWS session credentials associated with a given pod's service account.
// Get the AWS config associated with a given pod's service account.
//
// The returned session is capable of automatically refreshing creds as needed
// The returned config is capable of automatically refreshing creds as needed
// by using a private TokenFetcher helper.
func (p Auth) GetAWSSession() (awsSession *session.Session, e error) {
var credProvider credential_provider.CredentialProvider
func (p Auth) GetAWSConfig(ctx context.Context) (aws.Config, error) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no longer a "Session" type in V2 of the SDK, just aws.Config{}, so I've renamed that appropriately

var credProvider credential_provider.ConfigProvider

if p.usePodIdentity {
klog.Infof("Using Pod Identity for authentication in namespace: %s, service account: %s", p.nameSpace, p.svcAcc)
klog.Infof("Using Pod Identity for authentication in namespace: %s, service account: %s", p.namespace, p.serviceAccount)
var err error
credProvider, err = credential_provider.NewPodIdentityCredentialProvider(p.region, p.nameSpace, p.svcAcc, p.podName, p.preferredAddressType, p.k8sClient)
credProvider, err = credential_provider.NewPodIdentityCredentialProvider(p.region, p.namespace, p.serviceAccount, p.podName, p.preferredAddressType, p.k8sClient)
if err != nil {
return nil, err
return aws.Config{}, err
}
} else {
klog.Infof("Using IAM Roles for Service Accounts for authentication in namespace: %s, service account: %s", p.nameSpace, p.svcAcc)
credProvider = credential_provider.NewIRSACredentialProvider(p.stsClient, p.region, p.nameSpace, p.svcAcc, p.k8sClient, p.ctx)
klog.Infof("Using IAM Roles for Service Accounts for authentication in namespace: %s, service account: %s", p.namespace, p.serviceAccount)
credProvider = credential_provider.NewIRSACredentialProvider(p.stsClient, p.region, p.namespace, p.serviceAccount, p.k8sClient)
}

config, err := credProvider.GetAWSConfig()
cfg, err := credProvider.GetAWSConfig(ctx)
if err != nil {
return nil, err
return aws.Config{}, err
}

// Include the provider in the user agent string.
sess, err := session.NewSession(config)
if err != nil {
return nil, err
}
sess.Handlers.Build.PushFront(func(r *request.Request) {
request.AddToUserAgent(r, ProviderName)
// add the user agent to the config
cfg.APIOptions = append(cfg.APIOptions, func(stack *middleware.Stack) error {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Header/UserAgent functionality is different in V2, See the docs on request customization

return stack.Build.Add(&userAgentMiddleware{
providerName: ProviderName,
}, middleware.After)
})

return session.Must(sess, err), nil
return cfg, nil
}

type userAgentMiddleware struct {
providerName string
}

func (m *userAgentMiddleware) ID() string {
return "UserAgent"
}

func (m *userAgentMiddleware) HandleBuild(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) (
out middleware.BuildOutput, metadata middleware.Metadata, err error) {
req, ok := in.Request.(*smithyhttp.Request)
if !ok {
return next.HandleBuild(ctx, in)
}
req.Header.Set("User-Agent", m.providerName+" "+req.Header.Get("User-Agent"))
return next.HandleBuild(ctx, in)
}
83 changes: 47 additions & 36 deletions auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,65 +2,76 @@ package auth

import (
"context"
"fmt"
"strings"
"testing"

"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/aws/aws-sdk-go-v2/service/sts"
"k8s.io/client-go/kubernetes/fake"
)

// Mock STS client
type mockSTS struct {
stsiface.STSAPI
sts.Client
}

func (m *mockSTS) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) {
return nil, fmt.Errorf("fake error for serviceaccounst")
}

type sessionTest struct {
testName string
testPodIdentity bool
expError string
}

var sessionTests []sessionTest = []sessionTest{
{
testName: "IRSA",
testPodIdentity: false,
expError: "serviceaccounts", // IRSA path will fail at getting service account since using fake client
},
{
testName: "Pod Identity",
testPodIdentity: true,
expError: "failed to fetch token", // Pod Identity path will fail fetching token since using fake client
},
cfgError string
}

func TestGetAWSSession(t *testing.T) {
for _, tstData := range sessionTests {
t.Run(tstData.testName, func(t *testing.T) {
cases := []sessionTest{
{
testName: "IRSA",
testPodIdentity: false,
cfgError: "serviceaccounts", // IRSA path will fail at getting creds since its in the hot path of the config

},
{
testName: "Pod Identity",
testPodIdentity: true,
cfgError: "", // Pod Identity path succeeds since token is lazy loaded
},
}
for _, tt := range cases {
t.Run(tt.testName, func(t *testing.T) {

auth := &Auth{
region: "someRegion",
nameSpace: "someNamespace",
svcAcc: "someSvcAcc",
podName: "somepod",
usePodIdentity: tstData.testPodIdentity,
k8sClient: fake.NewSimpleClientset().CoreV1(),
stsClient: &mockSTS{},
ctx: context.Background(),
auth, err := NewAuth(
"someRegion",
"someNamespace",
"someSvcAcc",
"somepod",
"",
tt.testPodIdentity,
fake.NewSimpleClientset().CoreV1(),
)
if err != nil {
t.Fatalf("%s case: failed to create auth: %v", tt.testName, err)
}
auth.stsClient = &mockSTS{}
auth.k8sClient = fake.NewSimpleClientset().CoreV1()

sess, err := auth.GetAWSSession()
cfg, err := auth.GetAWSConfig(context.Background())

if len(tstData.expError) == 0 && err != nil {
t.Errorf("%s case: got unexpected auth error: %s", tstData.testName, err)
if len(tt.cfgError) == 0 && err != nil {
t.Errorf("%s case: got unexpected auth error: %s", tt.testName, err)
}
if len(tstData.expError) == 0 && sess == nil {
t.Errorf("%s case: got empty session", tstData.testName)
if len(tt.cfgError) == 0 && cfg.Credentials == nil {
t.Errorf("%s case: got empty credentials", tt.testName)
}
if len(tstData.expError) != 0 && err == nil {
t.Errorf("%s case: expected error but got none", tstData.testName)
if len(tt.cfgError) != 0 && err == nil {
t.Errorf("%s case: expected error but got none", tt.testName)
}
if len(tstData.expError) != 0 && !strings.Contains(err.Error(), tstData.expError) {
t.Errorf("%s case: expected error prefix '%s' but got '%s'", tstData.testName, tstData.expError, err.Error())
if len(tt.cfgError) != 0 && err != nil {
if !strings.Contains(err.Error(), tt.cfgError) {
t.Errorf("%s case: expected error prefix '%s' but got '%s'", tt.testName, tt.cfgError, err.Error())
}
}
})
}
Expand Down
17 changes: 6 additions & 11 deletions credential_provider/credential_provider.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
package credential_provider

import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"context"

"github.com/aws/aws-sdk-go-v2/aws"
)

// CredentialProvider interface defines methods for obtaining AWS credentials configuration
type CredentialProvider interface {
// ConfigProvider interface defines methods for obtaining AWS credentials configuration
type ConfigProvider interface {
// GetAWSConfig returns an AWS configuration containing credentials obtained from the provider
GetAWSConfig() (*aws.Config, error)
}

// authTokenFetcher interface defines methods for fetching a token given a K8s namespace and service account.
// It matches stscreds.TokenFetcher interface.
type authTokenFetcher interface {
FetchToken(ctx credentials.Context) ([]byte, error)
GetAWSConfig(ctx context.Context) (aws.Config, error)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than force a unifying method signature for getting a token and creating unnecessary wrappers, we can use the interface types provided in each credential method for retrieving tokens, and just make the ConfigProvider return an aws.Config{} with the correct AWS cred provider

}
33 changes: 25 additions & 8 deletions credential_provider/credential_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ package credential_provider
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go/service/sts/stsiface"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/aws-sdk-go-v2/service/sts/types"
authv1 "k8s.io/api/authentication/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
k8sv1 "k8s.io/client-go/kubernetes/typed/core/v1"
Expand All @@ -17,23 +20,37 @@ const (

// Mock STS client
type mockSTS struct {
stsiface.STSAPI
sts.Client
}

func (m *mockSTS) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) {
return &sts.AssumeRoleWithWebIdentityOutput{
Credentials: &types.Credentials{
AccessKeyId: aws.String("TEST_ACCESS_KEY"),
SecretAccessKey: aws.String("TEST_SECRET"),
SessionToken: aws.String("TEST_TOKEN"),
},
}, nil
}

// Mock K8s client for creating tokens
type mockK8sV1 struct {
k8sv1.CoreV1Interface
k8CTOneShotError bool
k8sv1.CoreV1Interface // satisfy the interface
fake k8sv1.CoreV1Interface // plumb down to the ServiceAccounts method
Comment on lines +38 to +39
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need both the interface satisfied, and an addressable field so that .ServiceAccounts().CreateToken() and .ServiceAccounts().Get() can be called from the same mock type.

k8CTOneShotError bool
}

func (m *mockK8sV1) ServiceAccounts(namespace string) k8sv1.ServiceAccountInterface {
return &mockK8sV1SA{v1mock: m}
return &mockK8sV1SA{
m.fake.ServiceAccounts(namespace),
m.k8CTOneShotError,
}
}

// Mock the K8s service account client
type mockK8sV1SA struct {
k8sv1.ServiceAccountInterface
v1mock *mockK8sV1
oneShotGetTokenError bool
}

func (ma *mockK8sV1SA) CreateToken(
Expand All @@ -43,8 +60,8 @@ func (ma *mockK8sV1SA) CreateToken(
opts metav1.CreateOptions,
) (*authv1.TokenRequest, error) {

if ma.v1mock.k8CTOneShotError {
ma.v1mock.k8CTOneShotError = false // Reset so other tests don't fail
if ma.oneShotGetTokenError {
ma.oneShotGetTokenError = false // Reset so other tests don't fail
return nil, fmt.Errorf("Fake create token error")
}

Expand Down
Loading