Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LocalStack support #52

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (p authTokenFetcher) FetchToken(ctx credentials.Context) ([]byte, error) {
// sessions by calling GetAWSSession.
//
type Auth struct {
region, nameSpace, svcAcc string
region, endpoint, nameSpace, svcAcc string
k8sClient k8sv1.CoreV1Interface
stsClient stsiface.STSAPI
ctx context.Context
Expand All @@ -77,21 +77,23 @@ type Auth struct {
//
func NewAuth(
ctx context.Context,
region, nameSpace, svcAcc string,
region, endpoint, nameSpace, svcAcc string,
k8sClient k8sv1.CoreV1Interface,
) (auth *Auth, e error) {

// Get an initial session to use for STS calls.
sess, err := session.NewSession(aws.NewConfig().
WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint).
WithRegion(region),
)
sess, err := session.NewSession(&aws.Config{
Endpoint: aws.String(endpoint),
Region: aws.String(region),
STSRegionalEndpoint: endpoints.RegionalSTSEndpoint,
})
if err != nil {
return nil, err
}

return &Auth{
region: region,
endpoint: endpoint,
nameSpace: nameSpace,
svcAcc: svcAcc,
k8sClient: k8sClient,
Expand Down Expand Up @@ -140,6 +142,7 @@ func (p Auth) GetAWSSession() (awsSession *session.Session, e error) {
fetcher := &authTokenFetcher{p.nameSpace, p.svcAcc, p.k8sClient}
ar := stscreds.NewWebIdentityRoleProviderWithToken(p.stsClient, *roleArn, ProviderName, fetcher)
config := aws.NewConfig().
WithEndpoint(p.endpoint).
WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint). // Use regional STS endpoint
WithRegion(p.region).
WithCredentials(credentials.NewCredentials(ar))
Expand Down
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ require (
github.com/aws/aws-sdk-go v1.37.0
github.com/jmespath/go-jmespath v0.4.0
google.golang.org/grpc v1.35.0
gopkg.in/yaml.v2 v2.3.0
k8s.io/api v0.20.2
k8s.io/apimachinery v0.20.2
k8s.io/client-go v0.20.2
Expand Down
878 changes: 878 additions & 0 deletions go.sum

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions provider/parameter_store_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ func NewParameterStoreProviderWithClient(client ssmiface.SSMAPI) *ParameterStore
client: client,
}
}
func NewParameterStoreProvider(region string, awsSession *session.Session) *ParameterStoreProvider {
parameterStoreClient := ssm.New(awsSession, aws.NewConfig().WithRegion(region))
func NewParameterStoreProvider(region string, endpoint string,awsSession *session.Session) *ParameterStoreProvider {
parameterStoreClient := ssm.New(awsSession, &aws.Config{Endpoint: aws.String(endpoint), Region: aws.String(region)})
return NewParameterStoreProviderWithClient(parameterStoreClient)
}

Expand Down
8 changes: 4 additions & 4 deletions provider/secret_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,19 @@ type SecretProviderFactory struct {

// The prototype for the provider factory fatory
//
type ProviderFactoryFactory func(region string, session *session.Session) (factory *SecretProviderFactory)
type ProviderFactoryFactory func(region string, endpoints map[string]string, session *session.Session) (factory *SecretProviderFactory)

// Creates the provider factory.
//
// This factory catagorizes the request and returns the correct concrete
// provider implementation using the secret type.
//
func NewSecretProviderFactory(region string, session *session.Session) (factory *SecretProviderFactory) {
func NewSecretProviderFactory(region string, endpoints map[string]string, session *session.Session) (factory *SecretProviderFactory) {

return &SecretProviderFactory{
Providers: map[SecretType]SecretProvider{
SSMParameter: NewParameterStoreProvider(region, session),
SecretsManager: NewSecretsManagerProvider(region, session),
SSMParameter: NewParameterStoreProvider(region, endpoints["SSMParameter"], session),
SecretsManager: NewSecretsManagerProvider(region, endpoints["SecretsManager"], session),
},
}

Expand Down
4 changes: 2 additions & 2 deletions provider/secrets_manager_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func NewSecretsManagerProviderWithClient(client secretsmanageriface.SecretsManag
client: client,
}
}
func NewSecretsManagerProvider(region string, awsSession *session.Session) *SecretsManagerProvider {
secretsManagerClient := secretsmanager.New(awsSession, aws.NewConfig().WithRegion(region))
func NewSecretsManagerProvider(region string, endpoint string, awsSession *session.Session) *SecretsManagerProvider {
secretsManagerClient := secretsmanager.New(awsSession, &aws.Config{Endpoint: aws.String(endpoint), Region: aws.String(region)})
return NewSecretsManagerProviderWithClient(secretsManagerClient)
}
40 changes: 31 additions & 9 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,17 @@ import (
var Version string

const (
namespaceAttrib = "csi.storage.k8s.io/pod.namespace"
acctAttrib = "csi.storage.k8s.io/serviceAccount.name"
podnameAttrib = "csi.storage.k8s.io/pod.name"
regionAttrib = "region" // The attribute name for the region in the SecretProviderClass
transAttrib = "pathTranslation" // Path translation char
regionLabel = "topology.kubernetes.io/region" // The node label giving the region
secProvAttrib = "objects" // The attributed used to pass the SecretProviderClass definition (with what to mount)
namespaceAttrib = "csi.storage.k8s.io/pod.namespace"
acctAttrib = "csi.storage.k8s.io/serviceAccount.name"
podnameAttrib = "csi.storage.k8s.io/pod.name"
regionAttrib = "region" // The attribute name for the region in the SecretProviderClass
baseEndpointAttrib = "baseEndpoint"
ssmEndpointAttrib = "ssmEndpoint"
secretsEndpointAttrib = "secretsManagerEndpoint"
stsEndpointAttrib = "stsEndpoint"
transAttrib = "pathTranslation" // Path translation char
regionLabel = "topology.kubernetes.io/region" // The node label giving the region
secProvAttrib = "objects" // The attributed used to pass the SecretProviderClass definition (with what to mount)
)

// A Secrets Store CSI Driver provider implementation for AWS Secrets Manager and SSM Parameter Store.
Expand Down Expand Up @@ -94,8 +98,26 @@ func (s *CSIDriverProviderServer) Mount(ctx context.Context, req *v1alpha1.Mount
svcAcct := attrib[acctAttrib]
podName := attrib[podnameAttrib]
region := attrib[regionAttrib]
baseEndpoint := attrib[baseEndpointAttrib]
secretsEndpoint := attrib[secretsEndpointAttrib]
ssmEndpoint := attrib[ssmEndpointAttrib]
stsEndpoint := attrib[stsEndpointAttrib]
translate := attrib[transAttrib]

if len(secretsEndpoint) <= 0 && len(baseEndpoint) > 0 {
secretsEndpoint = baseEndpoint
}

if len(ssmEndpoint) <= 0 && len(baseEndpoint) > 0 {
ssmEndpoint = baseEndpoint
}

if len(stsEndpoint) <= 0 && len(baseEndpoint) > 0 {
stsEndpoint = baseEndpoint
}

endpoints := map[string]string{ "SSMParameter": ssmEndpoint, "SecretsManager": secretsEndpoint }

// Lookup the region if one was not specified.
if len(region) <= 0 {
region, err = s.getRegionFromNode(ctx, nameSpace, podName)
Expand All @@ -121,7 +143,7 @@ func (s *CSIDriverProviderServer) Mount(ctx context.Context, req *v1alpha1.Mount
}

// Get the pod's AWS creds.
oidcAuth, err := auth.NewAuth(ctx, region, nameSpace, svcAcct, s.k8sClient)
oidcAuth, err := auth.NewAuth(ctx, region, stsEndpoint, nameSpace, svcAcct, s.k8sClient)
if err != nil {
return nil, err
}
Expand All @@ -140,7 +162,7 @@ func (s *CSIDriverProviderServer) Mount(ctx context.Context, req *v1alpha1.Mount
}

// Fetch all secrets before saving so we write nothing on failure.
providerFactory := s.secretProviderFactory(region, awsSession)
providerFactory := s.secretProviderFactory(region, endpoints, awsSession)
var fetchedSecrets []*provider.SecretValue
for sType := range descriptors { // Iterate over each secret type.

Expand Down