diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index 86654b9e..b09042c9 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -21,12 +21,11 @@ import ( "fmt" "net" "os" - "time" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/version" - "github.com/awslabs/aws-s3-csi-driver/pkg/util" "github.com/container-storage-interface/spec/lib/go/csi" "google.golang.org/grpc" "k8s.io/client-go/kubernetes" @@ -35,15 +34,11 @@ import ( ) const ( - driverName = "s3.csi.aws.com" - webIdentityTokenEnv = "AWS_WEB_IDENTITY_TOKEN_FILE" + driverName = "s3.csi.aws.com" grpcServerMaxReceiveMessageSize = 1024 * 1024 * 2 // 2MB unixSocketPerm = os.FileMode(0700) // only owner can write and read. - - // This is the plugin directory for CSI driver mounted in the container. - containerPluginDir = "/csi" ) type Driver struct { @@ -74,13 +69,13 @@ func NewDriver(endpoint string, mpVersion string, nodeID string) (*Driver, error klog.Infof("Driver version: %v, Git commit: %v, build date: %v, nodeID: %v, mount-s3 version: %v, kubernetes version: %v", version.DriverVersion, version.GitCommit, version.BuildDate, nodeID, mpVersion, kubernetesVersion) - systemd_mounter, err := mounter.NewSystemdMounter(mpVersion, kubernetesVersion) + credProvider := credentialprovider.New(clientset.CoreV1(), credentialprovider.RegionFromIMDSOnce) + systemdMounter, err := mounter.NewSystemdMounter(credProvider, mpVersion, kubernetesVersion) if err != nil { klog.Fatalln(err) } - credentialProvider := mounter.NewCredentialProvider(clientset.CoreV1(), containerPluginDir, mounter.RegionFromIMDSOnce) - nodeServer := node.NewS3NodeServer(nodeID, systemd_mounter, credentialProvider) + nodeServer := node.NewS3NodeServer(nodeID, systemdMounter) return &Driver{ Endpoint: endpoint, @@ -90,14 +85,6 @@ func NewDriver(endpoint string, mpVersion string, nodeID string) (*Driver, error } func (d *Driver) Run() error { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - tokenFile := os.Getenv(webIdentityTokenEnv) - if tokenFile != "" { - klog.Infof("Found AWS_WEB_IDENTITY_TOKEN_FILE, syncing token") - go tokenFileTender(ctx, tokenFile, "/csi/token") - } - scheme, addr, err := ParseEndpoint(d.Endpoint) if err != nil { return err @@ -150,22 +137,6 @@ func (d *Driver) Stop() { d.Srv.Stop() } -func tokenFileTender(ctx context.Context, sourcePath string, destPath string) { - for { - timer := time.After(10 * time.Second) - err := util.ReplaceFile(destPath, sourcePath, 0600) - if err != nil { - klog.Infof("Failed to sync AWS web token file: %v", err) - } - select { - case <-timer: - continue - case <-ctx.Done(): - return - } - } -} - func kubernetesVersion(clientset *kubernetes.Clientset) (string, error) { version, err := clientset.ServerVersion() if err != nil { diff --git a/pkg/driver/node/awsprofile/aws_profile.go b/pkg/driver/node/awsprofile/aws_profile.go deleted file mode 100644 index 8ecd4f78..00000000 --- a/pkg/driver/node/awsprofile/aws_profile.go +++ /dev/null @@ -1,120 +0,0 @@ -// Package awsprofile provides utilities for creating and deleting AWS Profile (i.e., credentials & config files). -package awsprofile - -import ( - "errors" - "fmt" - "io/fs" - "os" - "path/filepath" - "strings" - "unicode" -) - -const ( - awsProfileName = "s3-csi" - awsProfileConfigFilename = "s3-csi-config" - awsProfileCredentialsFilename = "s3-csi-credentials" - awsProfileFilePerm = fs.FileMode(0400) // only owner readable -) - -// ErrInvalidCredentials is returned when given AWS Credentials contains invalid characters. -var ErrInvalidCredentials = errors.New("aws-profile: Invalid AWS Credentials") - -// An AWSProfile represents an AWS profile with it's credentials and config files. -type AWSProfile struct { - Name string - ConfigPath string - CredentialsPath string -} - -// CreateAWSProfile creates an AWS Profile with credentials and config files from given credentials. -// Created credentials and config files can be clean up with `CleanupAWSProfile`. -func CreateAWSProfile(basepath string, accessKeyID string, secretAccessKey string, sessionToken string) (AWSProfile, error) { - if !isValidCredential(accessKeyID) || !isValidCredential(secretAccessKey) || !isValidCredential(sessionToken) { - return AWSProfile{}, ErrInvalidCredentials - } - - name := awsProfileName - - configPath := filepath.Join(basepath, awsProfileConfigFilename) - err := writeAWSProfileFile(configPath, configFileContents(name)) - if err != nil { - return AWSProfile{}, fmt.Errorf("aws-profile: Failed to create config file %s: %v", configPath, err) - } - - credentialsPath := filepath.Join(basepath, awsProfileCredentialsFilename) - err = writeAWSProfileFile(credentialsPath, credentialsFileContents(name, accessKeyID, secretAccessKey, sessionToken)) - if err != nil { - return AWSProfile{}, fmt.Errorf("aws-profile: Failed to create credentials file %s: %v", credentialsPath, err) - } - - return AWSProfile{ - Name: name, - ConfigPath: configPath, - CredentialsPath: credentialsPath, - }, nil -} - -// CleanupAWSProfile cleans up credentials and config files created in given `basepath` via `CreateAWSProfile`. -func CleanupAWSProfile(basepath string) error { - configPath := filepath.Join(basepath, awsProfileConfigFilename) - if err := os.Remove(configPath); err != nil { - if !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("aws-profile: Failed to remove config file %s: %v", configPath, err) - } - } - - credentialsPath := filepath.Join(basepath, awsProfileCredentialsFilename) - if err := os.Remove(credentialsPath); err != nil { - if !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("aws-profile: Failed to remove credentials file %s: %v", credentialsPath, err) - } - } - - return nil -} - -func writeAWSProfileFile(path string, content string) error { - err := os.WriteFile(path, []byte(content), awsProfileFilePerm) - if err != nil { - return err - } - // If the given file exists, `os.WriteFile` just truncates it without changing it's permissions, - // so we need to ensure it always has the correct permissions. - return os.Chmod(path, awsProfileFilePerm) -} - -func credentialsFileContents(profile string, accessKeyID string, secretAccessKey string, sessionToken string) string { - var b strings.Builder - b.Grow(128) - b.WriteRune('[') - b.WriteString(profile) - b.WriteRune(']') - b.WriteRune('\n') - - b.WriteString("aws_access_key_id=") - b.WriteString(accessKeyID) - b.WriteRune('\n') - - b.WriteString("aws_secret_access_key=") - b.WriteString(secretAccessKey) - b.WriteRune('\n') - - if sessionToken != "" { - b.WriteString("aws_session_token=") - b.WriteString(sessionToken) - b.WriteRune('\n') - } - - return b.String() -} - -func configFileContents(profile string) string { - return fmt.Sprintf("[profile %s]\n", profile) -} - -// isValidCredential checks whether given credential file contains any non-printable characters. -func isValidCredential(s string) bool { - return !strings.ContainsFunc(s, func(r rune) bool { return !unicode.IsPrint(r) }) -} diff --git a/pkg/driver/node/awsprofile/aws_profile_test.go b/pkg/driver/node/awsprofile/aws_profile_test.go deleted file mode 100644 index d64a9f0c..00000000 --- a/pkg/driver/node/awsprofile/aws_profile_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package awsprofile_test - -import ( - "context" - "errors" - "io/fs" - "os" - "testing" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/config" - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/awsprofile" -) - -const testAccessKeyId = "test-access-key-id" -const testSecretAccessKey = "test-secret-access-key" -const testSessionToken = "test-session-token" - -func TestCreatingAWSProfile(t *testing.T) { - t.Run("create config and credentials files", func(t *testing.T) { - profile, err := awsprofile.CreateAWSProfile(t.TempDir(), testAccessKeyId, testSecretAccessKey, testSessionToken) - assertNoError(t, err) - assertCredentialsFromAWSProfile(t, profile, testAccessKeyId, testSecretAccessKey, testSessionToken) - }) - - t.Run("create config and credentials files with empty session token", func(t *testing.T) { - profile, err := awsprofile.CreateAWSProfile(t.TempDir(), testAccessKeyId, testSecretAccessKey, "") - assertNoError(t, err) - assertCredentialsFromAWSProfile(t, profile, testAccessKeyId, testSecretAccessKey, "") - }) - - t.Run("ensure config and credentials files are owner readable only", func(t *testing.T) { - profile, err := awsprofile.CreateAWSProfile(t.TempDir(), testAccessKeyId, testSecretAccessKey, testSessionToken) - assertNoError(t, err) - assertCredentialsFromAWSProfile(t, profile, testAccessKeyId, testSecretAccessKey, testSessionToken) - - configStat, err := os.Stat(profile.ConfigPath) - assertNoError(t, err) - assertEquals(t, 0400, configStat.Mode()) - - credentialsStat, err := os.Stat(profile.CredentialsPath) - assertNoError(t, err) - assertEquals(t, 0400, credentialsStat.Mode()) - }) - - t.Run("fail if credentials contains non-ascii characters", func(t *testing.T) { - t.Run("access key ID", func(t *testing.T) { - _, err := awsprofile.CreateAWSProfile(t.TempDir(), testAccessKeyId+"\n\t\r credential_process=exit", testSecretAccessKey, testSessionToken) - assertEquals(t, true, errors.Is(err, awsprofile.ErrInvalidCredentials)) - }) - t.Run("secret access key", func(t *testing.T) { - _, err := awsprofile.CreateAWSProfile(t.TempDir(), testAccessKeyId, testSecretAccessKey+"\n", testSessionToken) - assertEquals(t, true, errors.Is(err, awsprofile.ErrInvalidCredentials)) - }) - t.Run("session token", func(t *testing.T) { - _, err := awsprofile.CreateAWSProfile(t.TempDir(), testAccessKeyId, testSecretAccessKey, testSessionToken+"\n\r") - assertEquals(t, true, errors.Is(err, awsprofile.ErrInvalidCredentials)) - }) - }) -} - -func TestCleaningUpAWSProfile(t *testing.T) { - t.Run("clean config and credentials files", func(t *testing.T) { - basepath := t.TempDir() - - profile, err := awsprofile.CreateAWSProfile(basepath, testAccessKeyId, testSecretAccessKey, testSessionToken) - assertNoError(t, err) - assertCredentialsFromAWSProfile(t, profile, testAccessKeyId, testSecretAccessKey, testSessionToken) - - err = awsprofile.CleanupAWSProfile(basepath) - assertNoError(t, err) - - _, err = os.Stat(profile.ConfigPath) - assertEquals(t, true, errors.Is(err, fs.ErrNotExist)) - - _, err = os.Stat(profile.CredentialsPath) - assertEquals(t, true, errors.Is(err, fs.ErrNotExist)) - }) - - t.Run("cleaning non-existent config and credentials files should not be an error", func(t *testing.T) { - err := awsprofile.CleanupAWSProfile(t.TempDir()) - assertNoError(t, err) - }) -} - -func assertCredentialsFromAWSProfile(t *testing.T, profile awsprofile.AWSProfile, accessKeyID string, secretAccessKey string, sessionToken string) { - credentials := parseAWSProfile(t, profile) - assertEquals(t, accessKeyID, credentials.AccessKeyID) - assertEquals(t, secretAccessKey, credentials.SecretAccessKey) - assertEquals(t, sessionToken, credentials.SessionToken) -} - -func parseAWSProfile(t *testing.T, profile awsprofile.AWSProfile) aws.Credentials { - sharedConfig, err := config.LoadSharedConfigProfile(context.Background(), profile.Name, func(c *config.LoadSharedConfigOptions) { - c.ConfigFiles = []string{profile.ConfigPath} - c.CredentialsFiles = []string{profile.CredentialsPath} - }) - assertNoError(t, err) - return sharedConfig.Credentials -} - -func assertEquals[T comparable](t *testing.T, expected T, got T) { - if expected != got { - t.Errorf("Expected %#v, Got %#v", expected, got) - } -} - -func assertNoError(t *testing.T, err error) { - if err != nil { - t.Errorf("Expected no error, but got: %s", err) - } -} diff --git a/pkg/driver/node/credentialprovider/awsprofile/aws_profile.go b/pkg/driver/node/credentialprovider/awsprofile/aws_profile.go new file mode 100644 index 00000000..c13f6519 --- /dev/null +++ b/pkg/driver/node/credentialprovider/awsprofile/aws_profile.go @@ -0,0 +1,162 @@ +// Package awsprofile provides utilities for creating and deleting AWS Profile (i.e., credentials & config files). +package awsprofile + +import ( + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + "unicode" + + "github.com/google/renameio" +) + +const ( + awsProfileNameSuffix = "s3-csi" + awsProfileConfigFilenameSuffix = "s3-csi-config" + awsProfileCredentialsFilenameSuffix = "s3-csi-credentials" +) + +// ErrInvalidCredentials is returned when given AWS Credentials contains invalid characters. +var ErrInvalidCredentials = errors.New("aws-profile: Invalid AWS Credentials") + +// Profile represents an AWS profile with it's credentials and config filenames. +type Profile struct { + // Name is the AWS profile name + Name string + // ConfigFilename is the name of the AWS config file + ConfigFilename string + // CredentialsFilename is the name of the AWS credentials file + CredentialsFilename string +} + +// Credentials represents long-term AWS credentials used to create an AWS Profile. +type Credentials struct { + AccessKeyID string + SecretAccessKey string + SessionToken string +} + +// isValid checks if all credential fields contain only printable characters +func (c *Credentials) isValid() bool { + return isValidCredential(c.AccessKeyID) && + isValidCredential(c.SecretAccessKey) && + isValidCredential(c.SessionToken) +} + +// Settings contains configuration for AWS profile creation and management. +type Settings struct { + // Basepath is the directory path where AWS profile files will be created + Basepath string + // Prefix is prepended to generated filenames for uniqueness + Prefix string + // FilePerm specifies the file permissions for created profile files + FilePerm fs.FileMode +} + +// prefixed prepends the Settings prefix to the given suffix +func (s *Settings) prefixed(suffix string) string { + return s.Prefix + suffix +} + +// path joins the basepath with the given filename +func (s *Settings) path(filename string) string { + return filepath.Join(s.Basepath, filename) +} + +// prefixedPath returns the full path for a prefixed filename +func (s *Settings) prefixedPath(filename string) string { + return s.path(s.prefixed(filename)) +} + +// Create creates an AWS Profile with credentials and config files from given credentials. +// Created credentials and config files can be clean up with [Cleanup]. +func Create(settings Settings, credentias Credentials) (Profile, error) { + if !credentias.isValid() { + return Profile{}, ErrInvalidCredentials + } + + name := settings.prefixed(awsProfileNameSuffix) + + configFilename := settings.prefixed(awsProfileConfigFilenameSuffix) + configPath := settings.path(configFilename) + err := writeAWSProfileFile(configPath, configFileContents(name), settings.FilePerm) + if err != nil { + return Profile{}, fmt.Errorf("aws-profile: Failed to create config file %s: %v", configPath, err) + } + + credentialsFilename := settings.prefixed(awsProfileCredentialsFilenameSuffix) + credentialsPath := settings.path(credentialsFilename) + err = writeAWSProfileFile(credentialsPath, credentialsFileContents(name, credentias), settings.FilePerm) + if err != nil { + return Profile{}, fmt.Errorf("aws-profile: Failed to create credentials file %s: %v", credentialsPath, err) + } + + return Profile{ + Name: name, + ConfigFilename: configFilename, + CredentialsFilename: credentialsFilename, + }, nil +} + +// Cleanup cleans up credentials and config files created via [Create]. +func Cleanup(settings Settings) error { + configPath := settings.prefixedPath(awsProfileConfigFilenameSuffix) + if err := os.Remove(configPath); err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("aws-profile: Failed to remove config file %s: %v", configPath, err) + } + } + + credentialsPath := settings.prefixedPath(awsProfileCredentialsFilenameSuffix) + if err := os.Remove(credentialsPath); err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("aws-profile: Failed to remove credentials file %s: %v", credentialsPath, err) + } + } + + return nil +} + +// writeAWSProfileFile safely writes AWS profile content to a file with given permissions +func writeAWSProfileFile(path string, content string, filePerm os.FileMode) error { + return renameio.WriteFile(path, []byte(content), filePerm) +} + +// credentialsFileContents generates the contents for an AWS credentials file +func credentialsFileContents(profile string, credentials Credentials) string { + var b strings.Builder + b.Grow(128) + b.WriteRune('[') + b.WriteString(profile) + b.WriteRune(']') + b.WriteRune('\n') + + b.WriteString("aws_access_key_id=") + b.WriteString(credentials.AccessKeyID) + b.WriteRune('\n') + + b.WriteString("aws_secret_access_key=") + b.WriteString(credentials.SecretAccessKey) + b.WriteRune('\n') + + if credentials.SessionToken != "" { + b.WriteString("aws_session_token=") + b.WriteString(credentials.SessionToken) + b.WriteRune('\n') + } + + return b.String() +} + +// configFileContents generates the contents for an AWS config file +func configFileContents(profile string) string { + return fmt.Sprintf("[profile %s]\n", profile) +} + +// isValidCredential checks whether given credential file contains any non-printable characters. +func isValidCredential(s string) bool { + return !strings.ContainsFunc(s, func(r rune) bool { return !unicode.IsPrint(r) }) +} diff --git a/pkg/driver/node/credentialprovider/awsprofile/aws_profile_test.go b/pkg/driver/node/credentialprovider/awsprofile/aws_profile_test.go new file mode 100644 index 00000000..adc6a4a8 --- /dev/null +++ b/pkg/driver/node/credentialprovider/awsprofile/aws_profile_test.go @@ -0,0 +1,142 @@ +package awsprofile_test + +import ( + "errors" + "io/fs" + "os" + "path/filepath" + "testing" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider/awsprofile" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider/awsprofile/awsprofiletest" + "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" +) + +const testAccessKeyId = "test-access-key-id" +const testSecretAccessKey = "test-secret-access-key" +const testSessionToken = "test-session-token" +const testFilePerm = fs.FileMode(0600) + +func TestCreatingAWSProfile(t *testing.T) { + defaultSettings := awsprofile.Settings{ + Basepath: t.TempDir(), + Prefix: "test-", + FilePerm: testFilePerm, + } + + t.Run("create config and credentials files", func(t *testing.T) { + creds := awsprofile.Credentials{ + AccessKeyID: testAccessKeyId, + SecretAccessKey: testSecretAccessKey, + SessionToken: testSessionToken, + } + profile, err := awsprofile.Create(defaultSettings, creds) + assert.NoError(t, err) + assertCredentialsFromAWSProfile(t, defaultSettings.Basepath, profile, testAccessKeyId, testSecretAccessKey, testSessionToken) + }) + + t.Run("create config and credentials files with empty session token", func(t *testing.T) { + creds := awsprofile.Credentials{ + AccessKeyID: testAccessKeyId, + SecretAccessKey: testSecretAccessKey, + } + profile, err := awsprofile.Create(defaultSettings, creds) + assert.NoError(t, err) + assertCredentialsFromAWSProfile(t, defaultSettings.Basepath, profile, testAccessKeyId, testSecretAccessKey, "") + }) + + t.Run("ensure config and credentials files are created with correct permissions", func(t *testing.T) { + creds := awsprofile.Credentials{ + AccessKeyID: testAccessKeyId, + SecretAccessKey: testSecretAccessKey, + SessionToken: testSessionToken, + } + profile, err := awsprofile.Create(defaultSettings, creds) + assert.NoError(t, err) + assertCredentialsFromAWSProfile(t, defaultSettings.Basepath, profile, testAccessKeyId, testSecretAccessKey, testSessionToken) + + configStat, err := os.Stat(filepath.Join(defaultSettings.Basepath, profile.ConfigFilename)) + assert.NoError(t, err) + assert.Equals(t, testFilePerm, configStat.Mode()) + + credentialsStat, err := os.Stat(filepath.Join(defaultSettings.Basepath, profile.CredentialsFilename)) + assert.NoError(t, err) + assert.Equals(t, testFilePerm, credentialsStat.Mode()) + }) + + t.Run("fail if credentials contains non-ascii characters", func(t *testing.T) { + t.Run("access key ID", func(t *testing.T) { + creds := awsprofile.Credentials{ + AccessKeyID: testAccessKeyId + "\n\t\r credential_process=exit", + SecretAccessKey: testSecretAccessKey, + SessionToken: testSessionToken, + } + _, err := awsprofile.Create(defaultSettings, creds) + assert.Equals(t, true, errors.Is(err, awsprofile.ErrInvalidCredentials)) + }) + t.Run("secret access key", func(t *testing.T) { + creds := awsprofile.Credentials{ + AccessKeyID: testAccessKeyId, + SecretAccessKey: testSecretAccessKey + "\n", + SessionToken: testSessionToken, + } + _, err := awsprofile.Create(defaultSettings, creds) + assert.Equals(t, true, errors.Is(err, awsprofile.ErrInvalidCredentials)) + }) + t.Run("session token", func(t *testing.T) { + creds := awsprofile.Credentials{ + AccessKeyID: testAccessKeyId, + SecretAccessKey: testSecretAccessKey, + SessionToken: testSessionToken + "\n\r", + } + _, err := awsprofile.Create(defaultSettings, creds) + assert.Equals(t, true, errors.Is(err, awsprofile.ErrInvalidCredentials)) + }) + }) +} + +func TestCleaningUpAWSProfile(t *testing.T) { + settings := awsprofile.Settings{ + Basepath: t.TempDir(), + Prefix: "test-", + FilePerm: testFilePerm, + } + + t.Run("clean config and credentials files", func(t *testing.T) { + creds := awsprofile.Credentials{ + AccessKeyID: testAccessKeyId, + SecretAccessKey: testSecretAccessKey, + SessionToken: testSessionToken, + } + + profile, err := awsprofile.Create(settings, creds) + assert.NoError(t, err) + assertCredentialsFromAWSProfile(t, settings.Basepath, profile, testAccessKeyId, testSecretAccessKey, testSessionToken) + + err = awsprofile.Cleanup(settings) + assert.NoError(t, err) + + _, err = os.Stat(filepath.Join(settings.Basepath, profile.ConfigFilename)) + assert.Equals(t, true, errors.Is(err, fs.ErrNotExist)) + + _, err = os.Stat(filepath.Join(settings.Basepath, profile.CredentialsFilename)) + assert.Equals(t, true, errors.Is(err, fs.ErrNotExist)) + }) + + t.Run("cleaning non-existent config and credentials files should not be an error", func(t *testing.T) { + err := awsprofile.Cleanup(settings) + assert.NoError(t, err) + }) +} + +func assertCredentialsFromAWSProfile(t *testing.T, basepath string, profile awsprofile.Profile, accessKeyID string, secretAccessKey string, sessionToken string) { + awsprofiletest.AssertCredentialsFromAWSProfile( + t, + profile.Name, + filepath.Join(basepath, profile.ConfigFilename), + filepath.Join(basepath, profile.CredentialsFilename), + accessKeyID, + secretAccessKey, + sessionToken, + ) +} diff --git a/pkg/driver/node/credentialprovider/awsprofile/awsprofiletest/aws_profile.go b/pkg/driver/node/credentialprovider/awsprofile/awsprofiletest/aws_profile.go new file mode 100644 index 00000000..5670eb6b --- /dev/null +++ b/pkg/driver/node/credentialprovider/awsprofile/awsprofiletest/aws_profile.go @@ -0,0 +1,30 @@ +// Package awsprofiletest provides testing utilities for AWS Profiles. +package awsprofiletest + +import ( + "context" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + + "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" +) + +func AssertCredentialsFromAWSProfile(t *testing.T, profileName, configFile, credentialsFile, accessKeyID, secretAccessKey, sessionToken string) { + t.Helper() + + credentials := parseAWSProfile(t, profileName, configFile, credentialsFile) + assert.Equals(t, accessKeyID, credentials.AccessKeyID) + assert.Equals(t, secretAccessKey, credentials.SecretAccessKey) + assert.Equals(t, sessionToken, credentials.SessionToken) +} + +func parseAWSProfile(t *testing.T, profileName, configFile, credentialsFile string) aws.Credentials { + sharedConfig, err := config.LoadSharedConfigProfile(context.Background(), profileName, func(c *config.LoadSharedConfigOptions) { + c.ConfigFiles = []string{configFile} + c.CredentialsFiles = []string{credentialsFile} + }) + assert.NoError(t, err) + return sharedConfig.Credentials +} diff --git a/pkg/driver/node/credentialprovider/provider.go b/pkg/driver/node/credentialprovider/provider.go new file mode 100644 index 00000000..58fae514 --- /dev/null +++ b/pkg/driver/node/credentialprovider/provider.go @@ -0,0 +1,125 @@ +// Package credentialprovider provides utilities for obtaining AWS credentials to use. +// Depending on the configuration, it either uses Pod-level or Driver-level credentials. +package credentialprovider + +import ( + "context" + "errors" + "fmt" + "io/fs" + "strings" + + k8sv1 "k8s.io/client-go/kubernetes/typed/core/v1" + k8sstrings "k8s.io/utils/strings" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" +) + +// CredentialFilePerm is the default permissions to be used for credential files. +// It's only readable and writeable by the owner. +const CredentialFilePerm = fs.FileMode(0600) + +// CredentialDirPerm is the default permissions to be used for credential directories. +// It's only readable, listable (execute bit), and writeable by the owner. +const CredentialDirPerm = fs.FileMode(0700) + +// An AuthenticationSource represents the source where the credentials was obtained. +type AuthenticationSource = string + +const ( + // This is when users don't provide a `authenticationSource` option in their volume attributes. + // We're defaulting to `driver` in this case. + AuthenticationSourceUnspecified AuthenticationSource = "" + AuthenticationSourceDriver AuthenticationSource = "driver" + AuthenticationSourcePod AuthenticationSource = "pod" +) + +// A Provider provides methods for accessing AWS credentials. +type Provider struct { + client k8sv1.CoreV1Interface + regionFromIMDS func() (string, error) +} + +// A ProvideContext contains parameters needed to provide credentials for a volume mount. +// +// Here, [WritePath] and [EnvPath] are used together to provide credential files to Mountpoint. +// The [Provider.Provide] method decides on filenames for credentials (e.g., `token` for driver-level service account token) +// and writes credentials with these filenames in [WritePath], and returns environment variables to pass Mountpoint +// with these filenames in [EnvPath]. +// This is due to fact that Mountpoint and the CSI Driver Node Pod - caller of this method - runs with different +// filesystems, and to communicate with each other, the CSI Driver Node Pod uses `hostPath` volume to gain +// access some path visible from both the CSI Driver Node Pod and Mountpoint, and setups files in that volume +// using [WritePath] and returns paths to these files in [EnvPath], so Mountpoint can correctly read these files. +type ProvideContext struct { + // WritePath is basepath to write credentials into. + WritePath string + // EnvPath is basepath to use while creating environment variables to pass Mountpoint. + EnvPath string + + PodID string + VolumeID string + + // The following values are provided from CSI volume context. + AuthenticationSource AuthenticationSource + PodNamespace string + ServiceAccountTokens string + ServiceAccountName string + // StsRegion is the `stsRegion` parameter passed via volume attribute. + StsRegion string + // BucketRegion is the `--region` parameter passed via mount options. + BucketRegion string +} + +// A CleanupContext contains parameters needed to clean up credentials after volume unmount. +type CleanupContext struct { + // WritePath is basepath where credentials previously written into. + WritePath string + PodID string + VolumeID string +} + +// New creates a new [Provider] with given client. +func New(client k8sv1.CoreV1Interface, regionFromIMDS func() (string, error)) *Provider { + // `regionFromIMDS` is a `sync.OnceValues` and it only makes request to IMDS once, + // this call is basically here to pre-warm the cache of IMDS call. + go func() { + _, _ = regionFromIMDS() + }() + + return &Provider{client, regionFromIMDS} +} + +// Provide provides credentials for given context. +// Depending on the configuration, it either returns driver-level or pod-level credentials. +func (c *Provider) Provide(ctx context.Context, provideCtx ProvideContext) (envprovider.Environment, AuthenticationSource, error) { + authenticationSource := provideCtx.AuthenticationSource + switch authenticationSource { + case AuthenticationSourcePod: + env, err := c.provideFromPod(ctx, provideCtx) + return env, AuthenticationSourcePod, err + case AuthenticationSourceUnspecified, AuthenticationSourceDriver: + env, err := c.provideFromDriver(provideCtx) + return env, AuthenticationSourceDriver, err + default: + return nil, AuthenticationSourceUnspecified, fmt.Errorf("unknown `authenticationSource`: %s, only `driver` (default option if not specified) and `pod` supported", authenticationSource) + } +} + +// Cleanup cleans any previously created credential files for given context. +func (c *Provider) Cleanup(cleanupCtx CleanupContext) error { + errPod := c.cleanupFromPod(cleanupCtx) + errDriver := c.cleanupFromDriver(cleanupCtx) + return errors.Join(errPod, errDriver) +} + +// escapedVolumeIdentifier returns "{podID}-{volumeID}" as a unique identifier for this volume. +// It also escapes slashes to make this identifier path-safe. +func escapedVolumeIdentifier(podID string, volumeID string) string { + var filename strings.Builder + // `podID` is a UUID, but escape it to ensure it doesn't contain `/` + filename.WriteString(k8sstrings.EscapeQualifiedName(podID)) + filename.WriteRune('-') + // `volumeID` might contain `/`, we need to escape it + filename.WriteString(k8sstrings.EscapeQualifiedName(volumeID)) + return filename.String() +} diff --git a/pkg/driver/node/credentialprovider/provider_driver.go b/pkg/driver/node/credentialprovider/provider_driver.go new file mode 100644 index 00000000..33ec43e5 --- /dev/null +++ b/pkg/driver/node/credentialprovider/provider_driver.go @@ -0,0 +1,122 @@ +package credentialprovider + +import ( + "fmt" + "os" + "path/filepath" + + "k8s.io/klog/v2" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider/awsprofile" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/util" +) + +const ( + driverLevelServiceAccountTokenName = "token" +) + +// provideFromDriver provides driver-level AWS credentials. +func (c *Provider) provideFromDriver(provideCtx ProvideContext) (envprovider.Environment, error) { + klog.V(4).Infof("credentialprovider: Using driver identity") + + env := envprovider.Environment{} + + // Long-term AWS credentials + accessKeyID := os.Getenv(envprovider.EnvAccessKeyID) + secretAccessKey := os.Getenv(envprovider.EnvSecretAccessKey) + if accessKeyID != "" && secretAccessKey != "" { + sessionToken := os.Getenv(envprovider.EnvSessionToken) + longTermCredsEnv, err := provideLongTermCredentialsFromDriver(provideCtx, accessKeyID, secretAccessKey, sessionToken) + if err != nil { + klog.V(4).ErrorS(err, "credentialprovider: Failed to provide long-term AWS credentials") + return nil, err + } + + env.Merge(longTermCredsEnv) + } else { + // Profile provider + // TODO: This is not officially supported and won't work by default with containerization. + configFile := os.Getenv(envprovider.EnvConfigFile) + sharedCredentialsFile := os.Getenv(envprovider.EnvSharedCredentialsFile) + if configFile != "" && sharedCredentialsFile != "" { + env.Set(envprovider.EnvConfigFile, configFile) + env.Set(envprovider.EnvSharedCredentialsFile, sharedCredentialsFile) + } + } + + // STS Web Identity provider + webIdentityTokenFile := os.Getenv(envprovider.EnvWebIdentityTokenFile) + roleARN := os.Getenv(envprovider.EnvRoleARN) + if webIdentityTokenFile != "" && roleARN != "" { + stsWebIdentityCredsEnv, err := provideStsWebIdentityCredentialsFromDriver(provideCtx) + if err != nil { + klog.V(4).ErrorS(err, "credentialprovider: Failed to provide STS Web Identity credentials from driver") + return nil, err + } + + env.Merge(stsWebIdentityCredsEnv) + } + + return env, nil +} + +// cleanupFromDriver removes any credential files that were created for driver-level authentication via [Provider.provideFromDriver]. +func (c *Provider) cleanupFromDriver(cleanupCtx CleanupContext) error { + prefix := driverLevelLongTermCredentialsProfilePrefix(cleanupCtx.PodID, cleanupCtx.VolumeID) + return awsprofile.Cleanup(awsprofile.Settings{ + Basepath: cleanupCtx.WritePath, + Prefix: prefix, + }) +} + +// provideStsWebIdentityCredentialsFromDriver provides credentials for STS Web Identity from the driver's service account. +// It basically copies driver's injected service account token to [provideCtx.WritePath]. +func provideStsWebIdentityCredentialsFromDriver(provideCtx ProvideContext) (envprovider.Environment, error) { + driverServiceAccountTokenFile := os.Getenv(envprovider.EnvWebIdentityTokenFile) + tokenFile := filepath.Join(provideCtx.WritePath, driverLevelServiceAccountTokenName) + err := util.ReplaceFile(tokenFile, driverServiceAccountTokenFile, CredentialFilePerm) + if err != nil { + return nil, fmt.Errorf("fcredentialprovider: sts-web-identity: failed to copy driver's service account token: %w", err) + } + + return envprovider.Environment{ + envprovider.EnvRoleARN: os.Getenv(envprovider.EnvRoleARN), + envprovider.EnvWebIdentityTokenFile: filepath.Join(provideCtx.EnvPath, driverLevelServiceAccountTokenName), + }, nil +} + +// provideLongTermCredentialsFromDriver provides long-term AWS credentials from the driver's environment variables. +// These variables injected to driver's Pod from a configured Kubernetes secret if configured, here it basically +// created a AWS Profile from these credentials in [provideCtx.WritePath]. +func provideLongTermCredentialsFromDriver(provideCtx ProvideContext, accessKeyID, secretAccessKey, sessionToken string) (envprovider.Environment, error) { + prefix := driverLevelLongTermCredentialsProfilePrefix(provideCtx.PodID, provideCtx.VolumeID) + awsProfile, err := awsprofile.Create(awsprofile.Settings{ + Basepath: provideCtx.WritePath, + Prefix: prefix, + FilePerm: CredentialFilePerm, + }, awsprofile.Credentials{ + AccessKeyID: accessKeyID, + SecretAccessKey: secretAccessKey, + SessionToken: sessionToken, + }) + if err != nil { + return nil, fmt.Errorf("credentialprovider: long-term: failed to create aws profile: %w", err) + } + + profile := awsProfile.Name + configFile := filepath.Join(provideCtx.EnvPath, awsProfile.ConfigFilename) + credentialsFile := filepath.Join(provideCtx.EnvPath, awsProfile.CredentialsFilename) + + return envprovider.Environment{ + envprovider.EnvProfile: profile, + envprovider.EnvConfigFile: configFile, + envprovider.EnvSharedCredentialsFile: credentialsFile, + }, nil +} + +// driverLevelLongTermCredentialsProfilePrefix generates a prefix for AWS credential profile names +// when using driver-level authentication. The prefix includes both pod and volume IDs to ensure uniqueness. +func driverLevelLongTermCredentialsProfilePrefix(podID, volumeID string) string { + return escapedVolumeIdentifier(podID, volumeID) + "-" +} diff --git a/pkg/driver/node/credentialprovider/provider_pod.go b/pkg/driver/node/credentialprovider/provider_pod.go new file mode 100644 index 00000000..d011a942 --- /dev/null +++ b/pkg/driver/node/credentialprovider/provider_pod.go @@ -0,0 +1,142 @@ +package credentialprovider + +import ( + "context" + "encoding/json" + "errors" + "io/fs" + "os" + "path/filepath" + "time" + + "github.com/google/renameio" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/klog/v2" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" +) + +const ( + serviceAccountTokenAudienceSTS = "sts.amazonaws.com" + serviceAccountRoleAnnotation = "eks.amazonaws.com/role-arn" +) + +const podLevelCredentialsDocsPage = "https://github.com/awslabs/mountpoint-s3-csi-driver/blob/main/docs/CONFIGURATION.md#pod-level-credentials" +const stsConfigDocsPage = "https://github.com/awslabs/mountpoint-s3-csi-driver/blob/main/docs/CONFIGURATION.md#configuring-the-sts-region" + +type serviceAccountToken struct { + Token string `json:"token"` + ExpirationTimestamp time.Time `json:"expirationTimestamp"` +} + +// provideFromPod provides pod-level AWS credentials. +func (c *Provider) provideFromPod(ctx context.Context, provideCtx ProvideContext) (envprovider.Environment, error) { + klog.V(4).Infof("credentialprovider: Using pod identity") + + tokensJson := provideCtx.ServiceAccountTokens + if tokensJson == "" { + klog.Error("credentialprovider: `authenticationSource` configured to `pod` but no service account tokens are received. Please make sure to enable `podInfoOnMountCompat`, see " + podLevelCredentialsDocsPage) + return nil, status.Error(codes.InvalidArgument, "Missing service account tokens. Please make sure to enable `podInfoOnMountCompat`, see "+podLevelCredentialsDocsPage) + } + + var tokens map[string]*serviceAccountToken + if err := json.Unmarshal([]byte(tokensJson), &tokens); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "Failed to parse service account tokens: %v", err) + } + + stsToken := tokens[serviceAccountTokenAudienceSTS] + if stsToken == nil { + klog.Errorf("credentialprovider: `authenticationSource` configured to `pod` but no service account tokens for %s received. Please make sure to enable `podInfoOnMountCompat`, see "+podLevelCredentialsDocsPage, serviceAccountTokenAudienceSTS) + return nil, status.Errorf(codes.InvalidArgument, "Missing service account token for %s", serviceAccountTokenAudienceSTS) + } + + roleARN, err := c.findPodServiceAccountRole(ctx, provideCtx) + if err != nil { + return nil, err + } + + region, err := c.stsRegion(provideCtx) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "Failed to detect STS AWS Region, please explicitly set the AWS Region, see "+stsConfigDocsPage) + } + + defaultRegion := os.Getenv(envprovider.EnvDefaultRegion) + if defaultRegion == "" { + defaultRegion = region + } + + podID := provideCtx.PodID + volumeID := provideCtx.VolumeID + if podID == "" { + return nil, status.Error(codes.InvalidArgument, "Missing Pod info. Please make sure to enable `podInfoOnMountCompat`, see "+podLevelCredentialsDocsPage) + } + + tokenName := podLevelServiceAccountTokenName(podID, volumeID) + + err = renameio.WriteFile(filepath.Join(provideCtx.WritePath, tokenName), []byte(stsToken.Token), CredentialFilePerm) + if err != nil { + return nil, status.Errorf(codes.Internal, "Failed to write service account token: %v", err) + } + + podNamespace := provideCtx.PodNamespace + podServiceAccount := provideCtx.ServiceAccountName + cacheKey := podNamespace + "/" + podServiceAccount + + return envprovider.Environment{ + envprovider.EnvRoleARN: roleARN, + envprovider.EnvWebIdentityTokenFile: filepath.Join(provideCtx.EnvPath, tokenName), + + envprovider.EnvRegion: region, + envprovider.EnvDefaultRegion: defaultRegion, + + envprovider.EnvEC2MetadataDisabled: "true", + + // TODO: These were needed with `systemd` but probably won't be necessary with containerization. + envprovider.EnvMountpointCacheKey: cacheKey, + envprovider.EnvConfigFile: filepath.Join(provideCtx.EnvPath, "disable-config"), + envprovider.EnvSharedCredentialsFile: filepath.Join(provideCtx.EnvPath, "disable-credentials"), + }, nil +} + +// cleanupFromPod removes any credential files that were created for pod-level authentication authentication via [Provider.provideFromPod]. +func (c *Provider) cleanupFromPod(cleanupCtx CleanupContext) error { + tokenName := podLevelServiceAccountTokenName(cleanupCtx.PodID, cleanupCtx.VolumeID) + tokenPath := filepath.Join(cleanupCtx.WritePath, tokenName) + err := os.Remove(tokenPath) + if err != nil && errors.Is(err, fs.ErrNotExist) { + return nil + } + return err +} + +// findPodServiceAccountRole tries to provide associated AWS IAM role for service account specified in the volume context. +func (c *Provider) findPodServiceAccountRole(ctx context.Context, provideCtx ProvideContext) (string, error) { + podNamespace := provideCtx.PodNamespace + podServiceAccount := provideCtx.ServiceAccountName + if podNamespace == "" || podServiceAccount == "" { + klog.Error("credentialprovider: `authenticationSource` configured to `pod` but no pod info found. Please make sure to enable `podInfoOnMountCompat`, see " + podLevelCredentialsDocsPage) + return "", status.Error(codes.InvalidArgument, "Missing Pod info. Please make sure to enable `podInfoOnMountCompat`, see "+podLevelCredentialsDocsPage) + } + + response, err := c.client.ServiceAccounts(podNamespace).Get(ctx, podServiceAccount, metav1.GetOptions{}) + if err != nil { + return "", status.Errorf(codes.InvalidArgument, "Failed to get pod's service account %s/%s: %v", podNamespace, podServiceAccount, err) + } + + roleArn := response.Annotations[serviceAccountRoleAnnotation] + if roleArn == "" { + klog.Error("credentialprovider: `authenticationSource` configured to `pod` but pod's service account is not annotated with a role, see " + podLevelCredentialsDocsPage) + return "", status.Errorf(codes.InvalidArgument, "Missing role annotation on pod's service account %s/%s", podNamespace, podServiceAccount) + } + + return roleArn, nil +} + +// podLevelServiceAccountTokenName returns service account token name for Pod-level identity. +// It escapes from slashes to make this token name path-safe. +func podLevelServiceAccountTokenName(podID string, volumeID string) string { + id := escapedVolumeIdentifier(podID, volumeID) + return id + ".token" +} diff --git a/pkg/driver/node/credentialprovider/provider_test.go b/pkg/driver/node/credentialprovider/provider_test.go new file mode 100644 index 00000000..df28ce04 --- /dev/null +++ b/pkg/driver/node/credentialprovider/provider_test.go @@ -0,0 +1,864 @@ +package credentialprovider_test + +import ( + "context" + "encoding/json" + "errors" + "io/fs" + "os" + "path/filepath" + "testing" + "time" + + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider/awsprofile/awsprofiletest" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" +) + +const testAccessKeyID = "test-access-key-id" +const testSecretAccessKey = "test-secret-access-key" +const testSessionToken = "test-session-token" + +const testRoleARN = "arn:aws:iam::111122223333:role/pod-a-role" +const testWebIdentityToken = "test-web-identity-token" + +const testPodID = "2a17db00-0bf3-4052-9b3f-6c89dcee5d79" +const testVolumeID = "test-vol" +const testProfilePrefix = testPodID + "-" + testVolumeID + "-" + +const testPodLevelServiceAccountToken = testPodID + "-" + testVolumeID + ".token" +const testDriverLevelServiceAccountToken = "token" + +const testPodServiceAccount = "test-sa" +const testPodNamespace = "test-ns" + +const testIMDSRegion = "us-east-1" + +func dummyRegionProvider() (string, error) { + return "us-east-1", nil +} + +const testEnvPath = "/test-env" + +func TestProvidingDriverLevelCredentials(t *testing.T) { + provider := credentialprovider.New(nil, dummyRegionProvider) + + authenticationSourceVariants := []string{ + credentialprovider.AuthenticationSourceDriver, + // It should fallback to Driver-level if authentication source is unspecified. + credentialprovider.AuthenticationSourceUnspecified, + } + + t.Run("only long-term credentials", func(t *testing.T) { + for _, authSource := range authenticationSourceVariants { + setEnvForLongTermCredentials(t) + + writePath := t.TempDir() + provideCtx := credentialprovider.ProvideContext{ + AuthenticationSource: authSource, + WritePath: writePath, + EnvPath: testEnvPath, + PodID: testPodID, + VolumeID: testVolumeID, + } + + env, source, err := provider.Provide(context.Background(), provideCtx) + assert.NoError(t, err) + assert.Equals(t, credentialprovider.AuthenticationSourceDriver, source) + assert.Equals(t, envprovider.Environment{ + "AWS_PROFILE": testProfilePrefix + "s3-csi", + "AWS_CONFIG_FILE": "/test-env/" + testProfilePrefix + "s3-csi-config", + "AWS_SHARED_CREDENTIALS_FILE": "/test-env/" + testProfilePrefix + "s3-csi-credentials", + }, env) + assertLongTermCredentials(t, writePath) + } + }) + + t.Run("only sts web identity credentials", func(t *testing.T) { + for _, authSource := range authenticationSourceVariants { + setEnvForStsWebIdentityCredentials(t) + + writePath := t.TempDir() + provideCtx := credentialprovider.ProvideContext{ + AuthenticationSource: authSource, + WritePath: writePath, + EnvPath: testEnvPath, + PodID: testPodID, + VolumeID: testVolumeID, + } + + env, source, err := provider.Provide(context.Background(), provideCtx) + assert.NoError(t, err) + assert.Equals(t, credentialprovider.AuthenticationSourceDriver, source) + assert.Equals(t, envprovider.Environment{ + "AWS_ROLE_ARN": testRoleARN, + "AWS_WEB_IDENTITY_TOKEN_FILE": filepath.Join(testEnvPath, testDriverLevelServiceAccountToken), + }, env) + assertWebIdentityTokenFile(t, filepath.Join(writePath, testDriverLevelServiceAccountToken)) + } + }) + + t.Run("only profile provider", func(t *testing.T) { + basepath := t.TempDir() + t.Setenv("AWS_CONFIG_FILE", filepath.Join(basepath, "config")) + t.Setenv("AWS_SHARED_CREDENTIALS_FILE", filepath.Join(basepath, "credentials")) + + provideCtx := credentialprovider.ProvideContext{ + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, + WritePath: t.TempDir(), + EnvPath: testEnvPath, + PodID: testPodID, + VolumeID: testVolumeID, + } + + env, source, err := provider.Provide(context.Background(), provideCtx) + assert.NoError(t, err) + assert.Equals(t, credentialprovider.AuthenticationSourceDriver, source) + assert.Equals(t, envprovider.Environment{ + "AWS_CONFIG_FILE": filepath.Join(basepath, "config"), + "AWS_SHARED_CREDENTIALS_FILE": filepath.Join(basepath, "credentials"), + }, env) + }) + + t.Run("long-term and sts web identity credentials", func(t *testing.T) { + for _, authSource := range authenticationSourceVariants { + setEnvForLongTermCredentials(t) + setEnvForStsWebIdentityCredentials(t) + + writePath := t.TempDir() + provideCtx := credentialprovider.ProvideContext{ + AuthenticationSource: authSource, + WritePath: writePath, + EnvPath: testEnvPath, + PodID: testPodID, + VolumeID: testVolumeID, + } + + env, source, err := provider.Provide(context.Background(), provideCtx) + assert.NoError(t, err) + assert.Equals(t, credentialprovider.AuthenticationSourceDriver, source) + assert.Equals(t, envprovider.Environment{ + "AWS_PROFILE": testProfilePrefix + "s3-csi", + "AWS_CONFIG_FILE": "/test-env/" + testProfilePrefix + "s3-csi-config", + "AWS_SHARED_CREDENTIALS_FILE": "/test-env/" + testProfilePrefix + "s3-csi-credentials", + "AWS_ROLE_ARN": testRoleARN, + "AWS_WEB_IDENTITY_TOKEN_FILE": filepath.Join(testEnvPath, testDriverLevelServiceAccountToken), + }, env) + assertLongTermCredentials(t, writePath) + assertWebIdentityTokenFile(t, filepath.Join(writePath, testDriverLevelServiceAccountToken)) + } + }) + + t.Run("incomplete long-term credentials", func(t *testing.T) { + // Only set access key without secret + t.Setenv("AWS_ACCESS_KEY_ID", testAccessKeyID) + + provider := credentialprovider.New(nil, dummyRegionProvider) + writePath := t.TempDir() + provideCtx := credentialprovider.ProvideContext{ + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, + WritePath: writePath, + EnvPath: testEnvPath, + PodID: testPodID, + VolumeID: testVolumeID, + } + + env, source, err := provider.Provide(context.Background(), provideCtx) + assert.NoError(t, err) + assert.Equals(t, credentialprovider.AuthenticationSourceDriver, source) + assert.Equals(t, envprovider.Environment{}, env) + + // Only set secret key without access key + t.Setenv("AWS_ACCESS_KEY_ID", "") + t.Setenv("AWS_SECRET_ACCESS_KEY", testSecretAccessKey) + + env, source, err = provider.Provide(context.Background(), provideCtx) + assert.NoError(t, err) + assert.Equals(t, credentialprovider.AuthenticationSourceDriver, source) + assert.Equals(t, envprovider.Environment{}, env) + }) + + t.Run("incomplete sts web identity credentials", func(t *testing.T) { + // Only set role ARN without token file + t.Setenv("AWS_ROLE_ARN", testRoleARN) + + provider := credentialprovider.New(nil, dummyRegionProvider) + writePath := t.TempDir() + provideCtx := credentialprovider.ProvideContext{ + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, + WritePath: writePath, + EnvPath: testEnvPath, + PodID: testPodID, + VolumeID: testVolumeID, + } + + env, source, err := provider.Provide(context.Background(), provideCtx) + assert.NoError(t, err) + assert.Equals(t, credentialprovider.AuthenticationSourceDriver, source) + assert.Equals(t, envprovider.Environment{}, env) + + // Only set token file without role ARN + tokenPath := filepath.Join(t.TempDir(), "token") + assert.NoError(t, os.WriteFile(tokenPath, []byte(testWebIdentityToken), 0600)) + t.Setenv("AWS_ROLE_ARN", "") + t.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", tokenPath) + + env, source, err = provider.Provide(context.Background(), provideCtx) + assert.NoError(t, err) + assert.Equals(t, credentialprovider.AuthenticationSourceDriver, source) + assert.Equals(t, envprovider.Environment{}, env) + }) + + t.Run("no credentials", func(t *testing.T) { + for _, authSource := range authenticationSourceVariants { + writePath := t.TempDir() + provideCtx := credentialprovider.ProvideContext{ + AuthenticationSource: authSource, + WritePath: writePath, + EnvPath: testEnvPath, + PodID: testPodID, + VolumeID: testVolumeID, + } + + env, source, err := provider.Provide(context.Background(), provideCtx) + assert.NoError(t, err) + assert.Equals(t, credentialprovider.AuthenticationSourceDriver, source) + assert.Equals(t, envprovider.Environment{}, env) + } + }) +} + +func TestProvidingPodLevelCredentials(t *testing.T) { + t.Run("correct values", func(t *testing.T) { + clientset := fake.NewSimpleClientset(serviceAccount(testPodServiceAccount, testPodNamespace, map[string]string{ + "eks.amazonaws.com/role-arn": testRoleARN, + })) + provider := credentialprovider.New(clientset.CoreV1(), dummyRegionProvider) + + writePath := t.TempDir() + provideCtx := credentialprovider.ProvideContext{ + AuthenticationSource: credentialprovider.AuthenticationSourcePod, + WritePath: writePath, + EnvPath: testEnvPath, + PodID: testPodID, + VolumeID: testVolumeID, + PodNamespace: testPodNamespace, + ServiceAccountName: testPodServiceAccount, + ServiceAccountTokens: serviceAccountTokens(t, tokens{ + "sts.amazonaws.com": { + Token: testWebIdentityToken, + }, + }), + } + + env, source, err := provider.Provide(context.Background(), provideCtx) + assert.NoError(t, err) + assert.Equals(t, credentialprovider.AuthenticationSourcePod, source) + assert.Equals(t, envprovider.Environment{ + "AWS_ROLE_ARN": testRoleARN, + "AWS_WEB_IDENTITY_TOKEN_FILE": filepath.Join(testEnvPath, testPodLevelServiceAccountToken), + + // Having a unique cache key for namespace/serviceaccount pair + "UNSTABLE_MOUNTPOINT_CACHE_KEY": testPodNamespace + "/" + testPodServiceAccount, + + // Disable long-term credentials + "AWS_CONFIG_FILE": "/test-env/disable-config", + "AWS_SHARED_CREDENTIALS_FILE": "/test-env/disable-credentials", + + // Disable EC2 credentials + "AWS_EC2_METADATA_DISABLED": "true", + + "AWS_REGION": testIMDSRegion, + "AWS_DEFAULT_REGION": testIMDSRegion, + }, env) + assertWebIdentityTokenFile(t, filepath.Join(writePath, testPodLevelServiceAccountToken)) + }) + + t.Run("missing information", func(t *testing.T) { + clientset := fake.NewSimpleClientset( + serviceAccount(testPodServiceAccount, testPodNamespace, map[string]string{ + "eks.amazonaws.com/role-arn": testRoleARN, + }), + serviceAccount("test-sa-missing-role", testPodNamespace, map[string]string{}), + ) + + for name, provideCtx := range map[string]credentialprovider.ProvideContext{ + "unknown service account": { + AuthenticationSource: credentialprovider.AuthenticationSourcePod, + WritePath: t.TempDir(), + EnvPath: testEnvPath, + PodID: testPodID, + VolumeID: testVolumeID, + PodNamespace: testPodNamespace, + ServiceAccountName: "test-unknown-sa", + ServiceAccountTokens: serviceAccountTokens(t, tokens{ + "sts.amazonaws.com": { + Token: testWebIdentityToken, + }, + }), + }, + "missing role arn in service account": { + AuthenticationSource: credentialprovider.AuthenticationSourcePod, + WritePath: t.TempDir(), + EnvPath: testEnvPath, + PodID: testPodID, + VolumeID: testVolumeID, + PodNamespace: testPodNamespace, + ServiceAccountName: "test-sa-missing-role", + ServiceAccountTokens: serviceAccountTokens(t, tokens{ + "sts.amazonaws.com": { + Token: testWebIdentityToken, + }, + }), + }, + "missing service account token": { + AuthenticationSource: credentialprovider.AuthenticationSourcePod, + WritePath: t.TempDir(), + EnvPath: testEnvPath, + PodID: testPodID, + VolumeID: testVolumeID, + PodNamespace: testPodNamespace, + ServiceAccountName: testPodServiceAccount, + }, + "missing sts audience in service account tokens": { + AuthenticationSource: credentialprovider.AuthenticationSourcePod, + WritePath: t.TempDir(), + EnvPath: testEnvPath, + PodID: testPodID, + VolumeID: testVolumeID, + PodNamespace: testPodNamespace, + ServiceAccountName: testPodServiceAccount, + ServiceAccountTokens: serviceAccountTokens(t, tokens{ + "unknown": { + Token: testWebIdentityToken, + }, + }), + }, + "missing service account name": { + AuthenticationSource: credentialprovider.AuthenticationSourcePod, + WritePath: t.TempDir(), + EnvPath: testEnvPath, + PodID: testPodID, + VolumeID: testVolumeID, + PodNamespace: testPodNamespace, + ServiceAccountTokens: serviceAccountTokens(t, tokens{ + "sts.amazonaws.com": { + Token: testWebIdentityToken, + }, + }), + }, + "missing pod namespace": { + AuthenticationSource: credentialprovider.AuthenticationSourcePod, + WritePath: t.TempDir(), + EnvPath: testEnvPath, + PodID: testPodID, + VolumeID: testVolumeID, + ServiceAccountName: testPodServiceAccount, + ServiceAccountTokens: serviceAccountTokens(t, tokens{ + "sts.amazonaws.com": { + Token: testWebIdentityToken, + }, + }), + }, + } { + t.Run(name, func(t *testing.T) { + provider := credentialprovider.New(clientset.CoreV1(), dummyRegionProvider) + _, _, err := provider.Provide(context.Background(), provideCtx) + if err == nil { + t.Error("it should fail with missing information") + } + }) + } + }) +} + +func TestDetectingRegionToUseForPodLevelCredentials(t *testing.T) { + clientset := fake.NewSimpleClientset(serviceAccount(testPodServiceAccount, testPodNamespace, map[string]string{ + "eks.amazonaws.com/role-arn": testRoleARN, + })) + + baseProvideCtx := credentialprovider.ProvideContext{ + AuthenticationSource: credentialprovider.AuthenticationSourcePod, + WritePath: t.TempDir(), + EnvPath: testEnvPath, + PodID: testPodID, + VolumeID: testVolumeID, + PodNamespace: testPodNamespace, + ServiceAccountName: testPodServiceAccount, + ServiceAccountTokens: serviceAccountTokens(t, tokens{ + "sts.amazonaws.com": { + Token: testWebIdentityToken, + }, + }), + } + + t.Run("no region", func(t *testing.T) { + provider := credentialprovider.New(clientset.CoreV1(), func() (string, error) { + return "", errors.New("unknown region") + }) + + _, _, err := provider.Provide(context.Background(), baseProvideCtx) + if err == nil { + t.Error("it should fail if there is not any region information") + } + }) + + t.Run("region from imds", func(t *testing.T) { + provider := credentialprovider.New(clientset.CoreV1(), func() (string, error) { + return "us-east-2", nil + }) + + env, _, err := provider.Provide(context.Background(), baseProvideCtx) + assert.NoError(t, err) + assert.Equals(t, envprovider.Environment{ + "AWS_ROLE_ARN": testRoleARN, + "AWS_WEB_IDENTITY_TOKEN_FILE": filepath.Join(testEnvPath, testPodLevelServiceAccountToken), + "UNSTABLE_MOUNTPOINT_CACHE_KEY": testPodNamespace + "/" + testPodServiceAccount, + "AWS_CONFIG_FILE": "/test-env/disable-config", + "AWS_SHARED_CREDENTIALS_FILE": "/test-env/disable-credentials", + "AWS_EC2_METADATA_DISABLED": "true", + "AWS_REGION": "us-east-2", + "AWS_DEFAULT_REGION": "us-east-2", + }, env) + }) + + t.Run("region from env", func(t *testing.T) { + t.Setenv("AWS_REGION", "eu-west-1") + provider := credentialprovider.New(clientset.CoreV1(), dummyRegionProvider) + + env, _, err := provider.Provide(context.Background(), baseProvideCtx) + assert.NoError(t, err) + assert.Equals(t, envprovider.Environment{ + "AWS_ROLE_ARN": testRoleARN, + "AWS_WEB_IDENTITY_TOKEN_FILE": filepath.Join(testEnvPath, testPodLevelServiceAccountToken), + "UNSTABLE_MOUNTPOINT_CACHE_KEY": testPodNamespace + "/" + testPodServiceAccount, + "AWS_CONFIG_FILE": "/test-env/disable-config", + "AWS_SHARED_CREDENTIALS_FILE": "/test-env/disable-credentials", + "AWS_EC2_METADATA_DISABLED": "true", + "AWS_REGION": "eu-west-1", + "AWS_DEFAULT_REGION": "eu-west-1", + }, env) + }) + + t.Run("default region from env", func(t *testing.T) { + t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") + provider := credentialprovider.New(clientset.CoreV1(), dummyRegionProvider) + + env, _, err := provider.Provide(context.Background(), baseProvideCtx) + assert.NoError(t, err) + assert.Equals(t, envprovider.Environment{ + "AWS_ROLE_ARN": testRoleARN, + "AWS_WEB_IDENTITY_TOKEN_FILE": filepath.Join(testEnvPath, testPodLevelServiceAccountToken), + "UNSTABLE_MOUNTPOINT_CACHE_KEY": testPodNamespace + "/" + testPodServiceAccount, + "AWS_CONFIG_FILE": "/test-env/disable-config", + "AWS_SHARED_CREDENTIALS_FILE": "/test-env/disable-credentials", + "AWS_EC2_METADATA_DISABLED": "true", + "AWS_REGION": "eu-north-1", + "AWS_DEFAULT_REGION": "eu-north-1", + }, env) + }) + + t.Run("default and regular region from env", func(t *testing.T) { + t.Setenv("AWS_REGION", "eu-west-1") + t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") + provider := credentialprovider.New(clientset.CoreV1(), dummyRegionProvider) + + env, _, err := provider.Provide(context.Background(), baseProvideCtx) + assert.NoError(t, err) + assert.Equals(t, envprovider.Environment{ + "AWS_ROLE_ARN": testRoleARN, + "AWS_WEB_IDENTITY_TOKEN_FILE": filepath.Join(testEnvPath, testPodLevelServiceAccountToken), + "UNSTABLE_MOUNTPOINT_CACHE_KEY": testPodNamespace + "/" + testPodServiceAccount, + "AWS_CONFIG_FILE": "/test-env/disable-config", + "AWS_SHARED_CREDENTIALS_FILE": "/test-env/disable-credentials", + "AWS_EC2_METADATA_DISABLED": "true", + "AWS_REGION": "eu-west-1", + "AWS_DEFAULT_REGION": "eu-north-1", + }, env) + }) + + t.Run("region from options", func(t *testing.T) { + provider := credentialprovider.New(clientset.CoreV1(), dummyRegionProvider) + + provideCtx := baseProvideCtx + provideCtx.BucketRegion = "us-west-1" + env, _, err := provider.Provide(context.Background(), provideCtx) + assert.NoError(t, err) + assert.Equals(t, envprovider.Environment{ + "AWS_ROLE_ARN": testRoleARN, + "AWS_WEB_IDENTITY_TOKEN_FILE": filepath.Join(testEnvPath, testPodLevelServiceAccountToken), + "UNSTABLE_MOUNTPOINT_CACHE_KEY": testPodNamespace + "/" + testPodServiceAccount, + "AWS_CONFIG_FILE": "/test-env/disable-config", + "AWS_SHARED_CREDENTIALS_FILE": "/test-env/disable-credentials", + "AWS_EC2_METADATA_DISABLED": "true", + "AWS_REGION": "us-west-1", + "AWS_DEFAULT_REGION": "us-west-1", + }, env) + }) + + t.Run("region from options with default region from env", func(t *testing.T) { + t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") + provider := credentialprovider.New(clientset.CoreV1(), dummyRegionProvider) + + provideCtx := baseProvideCtx + provideCtx.BucketRegion = "us-west-1" + env, _, err := provider.Provide(context.Background(), provideCtx) + assert.NoError(t, err) + assert.Equals(t, envprovider.Environment{ + "AWS_ROLE_ARN": testRoleARN, + "AWS_WEB_IDENTITY_TOKEN_FILE": filepath.Join(testEnvPath, testPodLevelServiceAccountToken), + "UNSTABLE_MOUNTPOINT_CACHE_KEY": testPodNamespace + "/" + testPodServiceAccount, + "AWS_CONFIG_FILE": "/test-env/disable-config", + "AWS_SHARED_CREDENTIALS_FILE": "/test-env/disable-credentials", + "AWS_EC2_METADATA_DISABLED": "true", + "AWS_REGION": "us-west-1", + "AWS_DEFAULT_REGION": "eu-north-1", + }, env) + }) + + t.Run("region from volume context", func(t *testing.T) { + provider := credentialprovider.New(clientset.CoreV1(), dummyRegionProvider) + + provideCtx := baseProvideCtx + provideCtx.StsRegion = "ap-south-1" + env, _, err := provider.Provide(context.Background(), provideCtx) + assert.NoError(t, err) + assert.Equals(t, envprovider.Environment{ + "AWS_ROLE_ARN": testRoleARN, + "AWS_WEB_IDENTITY_TOKEN_FILE": filepath.Join(testEnvPath, testPodLevelServiceAccountToken), + "UNSTABLE_MOUNTPOINT_CACHE_KEY": testPodNamespace + "/" + testPodServiceAccount, + "AWS_CONFIG_FILE": "/test-env/disable-config", + "AWS_SHARED_CREDENTIALS_FILE": "/test-env/disable-credentials", + "AWS_EC2_METADATA_DISABLED": "true", + "AWS_REGION": "ap-south-1", + "AWS_DEFAULT_REGION": "ap-south-1", + }, env) + }) + + t.Run("region from volume context with default region from env", func(t *testing.T) { + t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") + provider := credentialprovider.New(clientset.CoreV1(), dummyRegionProvider) + + provideCtx := baseProvideCtx + provideCtx.StsRegion = "ap-south-1" + env, _, err := provider.Provide(context.Background(), provideCtx) + assert.NoError(t, err) + assert.Equals(t, envprovider.Environment{ + "AWS_ROLE_ARN": testRoleARN, + "AWS_WEB_IDENTITY_TOKEN_FILE": filepath.Join(testEnvPath, testPodLevelServiceAccountToken), + "UNSTABLE_MOUNTPOINT_CACHE_KEY": testPodNamespace + "/" + testPodServiceAccount, + "AWS_CONFIG_FILE": "/test-env/disable-config", + "AWS_SHARED_CREDENTIALS_FILE": "/test-env/disable-credentials", + "AWS_EC2_METADATA_DISABLED": "true", + "AWS_REGION": "ap-south-1", + "AWS_DEFAULT_REGION": "eu-north-1", + }, env) + }) +} + +func TestProvidingPodLevelCredentialsForDifferentPods(t *testing.T) { + clientset := fake.NewSimpleClientset( + serviceAccount("test-sa-1", testPodNamespace, map[string]string{ + "eks.amazonaws.com/role-arn": "arn:aws:iam::123456789012:role/Test1", + }), + serviceAccount("test-sa-2", testPodNamespace, map[string]string{ + "eks.amazonaws.com/role-arn": "arn:aws:iam::123456789012:role/Test2", + }), + ) + provider := credentialprovider.New(clientset.CoreV1(), dummyRegionProvider) + + baseProvideCtx := credentialprovider.ProvideContext{ + AuthenticationSource: credentialprovider.AuthenticationSourcePod, + WritePath: t.TempDir(), + EnvPath: testEnvPath, + PodNamespace: testPodNamespace, + VolumeID: testVolumeID, + } + + provideCtxOne := baseProvideCtx + provideCtxOne.PodID = "pod1" + provideCtxOne.ServiceAccountName = "test-sa-1" + provideCtxOne.ServiceAccountTokens = serviceAccountTokens(t, tokens{ + "sts.amazonaws.com": {Token: "token1"}, + }) + + provideCtxTwo := baseProvideCtx + provideCtxTwo.PodID = "pod2" + provideCtxTwo.ServiceAccountName = "test-sa-2" + provideCtxTwo.ServiceAccountTokens = serviceAccountTokens(t, tokens{ + "sts.amazonaws.com": {Token: "token2"}, + }) + + envOne, sourceOne, err := provider.Provide(context.Background(), provideCtxOne) + assert.NoError(t, err) + assert.Equals(t, credentialprovider.AuthenticationSourcePod, sourceOne) + assert.Equals(t, envprovider.Environment{ + "AWS_ROLE_ARN": "arn:aws:iam::123456789012:role/Test1", + "AWS_WEB_IDENTITY_TOKEN_FILE": filepath.Join(testEnvPath, "pod1-"+testVolumeID+".token"), + "UNSTABLE_MOUNTPOINT_CACHE_KEY": testPodNamespace + "/test-sa-1", + "AWS_CONFIG_FILE": "/test-env/disable-config", + "AWS_SHARED_CREDENTIALS_FILE": "/test-env/disable-credentials", + "AWS_EC2_METADATA_DISABLED": "true", + "AWS_REGION": testIMDSRegion, + "AWS_DEFAULT_REGION": testIMDSRegion, + }, envOne) + + tokenOneContent, err := os.ReadFile(filepath.Join(provideCtxOne.WritePath, "pod1-"+testVolumeID+".token")) + assert.NoError(t, err) + assert.Equals(t, []byte("token1"), tokenOneContent) + + envTwo, sourceTwo, err := provider.Provide(context.Background(), provideCtxTwo) + assert.NoError(t, err) + assert.Equals(t, credentialprovider.AuthenticationSourcePod, sourceTwo) + assert.Equals(t, envprovider.Environment{ + "AWS_ROLE_ARN": "arn:aws:iam::123456789012:role/Test2", + "AWS_WEB_IDENTITY_TOKEN_FILE": filepath.Join(testEnvPath, "pod2-"+testVolumeID+".token"), + "UNSTABLE_MOUNTPOINT_CACHE_KEY": testPodNamespace + "/test-sa-2", + "AWS_CONFIG_FILE": "/test-env/disable-config", + "AWS_SHARED_CREDENTIALS_FILE": "/test-env/disable-credentials", + "AWS_EC2_METADATA_DISABLED": "true", + "AWS_REGION": testIMDSRegion, + "AWS_DEFAULT_REGION": testIMDSRegion, + }, envTwo) + + tokenContent2, err := os.ReadFile(filepath.Join(provideCtxTwo.WritePath, "pod2-"+testVolumeID+".token")) + assert.NoError(t, err) + assert.Equals(t, []byte("token2"), tokenContent2) +} + +func TestProvidingPodLevelCredentialsWithSlashInIDs(t *testing.T) { + clientset := fake.NewSimpleClientset(serviceAccount(testPodServiceAccount, testPodNamespace, map[string]string{ + "eks.amazonaws.com/role-arn": testRoleARN, + })) + provider := credentialprovider.New(clientset.CoreV1(), dummyRegionProvider) + + baseProvideCtx := credentialprovider.ProvideContext{ + AuthenticationSource: credentialprovider.AuthenticationSourcePod, + WritePath: t.TempDir(), + EnvPath: testEnvPath, + PodID: testPodID, + VolumeID: testVolumeID, + PodNamespace: testPodNamespace, + ServiceAccountName: testPodServiceAccount, + ServiceAccountTokens: serviceAccountTokens(t, tokens{ + "sts.amazonaws.com": {Token: testWebIdentityToken}, + }), + } + + t.Run("slash in volume id", func(t *testing.T) { + provideCtx := baseProvideCtx + provideCtx.VolumeID = "vol/1" + + env, source, err := provider.Provide(context.Background(), provideCtx) + assert.NoError(t, err) + assert.Equals(t, credentialprovider.AuthenticationSourcePod, source) + assert.Equals(t, envprovider.Environment{ + "AWS_ROLE_ARN": testRoleARN, + "AWS_WEB_IDENTITY_TOKEN_FILE": filepath.Join(testEnvPath, testPodID+"-vol~1.token"), + "UNSTABLE_MOUNTPOINT_CACHE_KEY": testPodNamespace + "/" + testPodServiceAccount, + "AWS_CONFIG_FILE": "/test-env/disable-config", + "AWS_SHARED_CREDENTIALS_FILE": "/test-env/disable-credentials", + "AWS_EC2_METADATA_DISABLED": "true", + "AWS_REGION": testIMDSRegion, + "AWS_DEFAULT_REGION": testIMDSRegion, + }, env) + + tokenContent, err := os.ReadFile(filepath.Join(provideCtx.WritePath, testPodID+"-vol~1.token")) + assert.NoError(t, err) + assert.Equals(t, []byte(testWebIdentityToken), tokenContent) + }) + + t.Run("slash in pod id", func(t *testing.T) { + provideCtx := baseProvideCtx + provideCtx.PodID = "pod/123" + + env, source, err := provider.Provide(context.Background(), provideCtx) + assert.NoError(t, err) + assert.Equals(t, credentialprovider.AuthenticationSourcePod, source) + assert.Equals(t, envprovider.Environment{ + "AWS_ROLE_ARN": testRoleARN, + "AWS_WEB_IDENTITY_TOKEN_FILE": filepath.Join(testEnvPath, "pod~123-"+testVolumeID+".token"), + "UNSTABLE_MOUNTPOINT_CACHE_KEY": testPodNamespace + "/" + testPodServiceAccount, + "AWS_CONFIG_FILE": "/test-env/disable-config", + "AWS_SHARED_CREDENTIALS_FILE": "/test-env/disable-credentials", + "AWS_EC2_METADATA_DISABLED": "true", + "AWS_REGION": testIMDSRegion, + "AWS_DEFAULT_REGION": testIMDSRegion, + }, env) + + tokenContent, err := os.ReadFile(filepath.Join(provideCtx.WritePath, "pod~123-"+testVolumeID+".token")) + assert.NoError(t, err) + assert.Equals(t, []byte(testWebIdentityToken), tokenContent) + }) +} + +func TestCleanup(t *testing.T) { + t.Run("cleanup driver level", func(t *testing.T) { + // Provide/create long-term credentials first + setEnvForLongTermCredentials(t) + + provider := credentialprovider.New(nil, dummyRegionProvider) + writePath := t.TempDir() + provideCtx := credentialprovider.ProvideContext{ + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, + WritePath: writePath, + EnvPath: testEnvPath, + PodID: testPodID, + VolumeID: testVolumeID, + } + + env, source, err := provider.Provide(context.Background(), provideCtx) + assert.NoError(t, err) + assert.Equals(t, credentialprovider.AuthenticationSourceDriver, source) + assert.Equals(t, envprovider.Environment{ + "AWS_PROFILE": testProfilePrefix + "s3-csi", + "AWS_CONFIG_FILE": "/test-env/" + testProfilePrefix + "s3-csi-config", + "AWS_SHARED_CREDENTIALS_FILE": "/test-env/" + testProfilePrefix + "s3-csi-credentials", + }, env) + assertLongTermCredentials(t, writePath) + + // Perform cleanup + err = provider.Cleanup(credentialprovider.CleanupContext{ + WritePath: writePath, + PodID: testPodID, + VolumeID: testVolumeID, + }) + assert.NoError(t, err) + + // Verify files were removed + _, err = os.Stat(filepath.Join(writePath, testProfilePrefix+"s3-csi-config")) + if err == nil { + t.Fatalf("AWS Config file should be cleaned up") + } + assert.Equals(t, fs.ErrNotExist, err) + + _, err = os.Stat(filepath.Join(writePath, testProfilePrefix+"s3-csi-credentials")) + if err == nil { + t.Fatalf("AWS Credentials file should be cleaned up") + } + assert.Equals(t, fs.ErrNotExist, err) + }) + + t.Run("cleanup pod level", func(t *testing.T) { + // Provide/create STS Web Identity credentials first + clientset := fake.NewSimpleClientset(serviceAccount(testPodServiceAccount, testPodNamespace, map[string]string{ + "eks.amazonaws.com/role-arn": testRoleARN, + })) + provider := credentialprovider.New(clientset.CoreV1(), dummyRegionProvider) + + writePath := t.TempDir() + provideCtx := credentialprovider.ProvideContext{ + AuthenticationSource: credentialprovider.AuthenticationSourcePod, + WritePath: writePath, + EnvPath: testEnvPath, + PodID: testPodID, + VolumeID: testVolumeID, + PodNamespace: testPodNamespace, + ServiceAccountName: testPodServiceAccount, + ServiceAccountTokens: serviceAccountTokens(t, tokens{ + "sts.amazonaws.com": { + Token: testWebIdentityToken, + }, + }), + } + + env, source, err := provider.Provide(context.Background(), provideCtx) + assert.NoError(t, err) + assert.Equals(t, credentialprovider.AuthenticationSourcePod, source) + assert.Equals(t, testRoleARN, env["AWS_ROLE_ARN"]) + assert.Equals(t, filepath.Join(testEnvPath, testPodLevelServiceAccountToken), env["AWS_WEB_IDENTITY_TOKEN_FILE"]) + assertWebIdentityTokenFile(t, filepath.Join(writePath, testPodLevelServiceAccountToken)) + + // Perform cleanup + err = provider.Cleanup(credentialprovider.CleanupContext{ + WritePath: writePath, + PodID: testPodID, + VolumeID: testVolumeID, + }) + assert.NoError(t, err) + + // Verify file was removed + _, err = os.Stat(filepath.Join(writePath, testPodLevelServiceAccountToken)) + if err == nil { + t.Fatalf("Service Account Token should be cleaned up") + } + assert.Equals(t, fs.ErrNotExist, err) + }) + + t.Run("cleanup with non-existent files", func(t *testing.T) { + writePath := t.TempDir() + provider := credentialprovider.New(nil, dummyRegionProvider) + + // Cleanup should not fail if files don't exist + err := provider.Cleanup(credentialprovider.CleanupContext{ + WritePath: writePath, + PodID: testPodID, + VolumeID: testVolumeID, + }) + assert.NoError(t, err) + }) +} + +//-- Utilities for tests + +func setEnvForLongTermCredentials(t *testing.T) { + t.Setenv("AWS_ACCESS_KEY_ID", testAccessKeyID) + t.Setenv("AWS_SECRET_ACCESS_KEY", testSecretAccessKey) + t.Setenv("AWS_SESSION_TOKEN", testSessionToken) +} + +func assertLongTermCredentials(t *testing.T, basepath string) { + t.Helper() + + awsprofiletest.AssertCredentialsFromAWSProfile( + t, + testProfilePrefix+"s3-csi", + filepath.Join(basepath, testProfilePrefix+"s3-csi-config"), + filepath.Join(basepath, testProfilePrefix+"s3-csi-credentials"), + testAccessKeyID, + testSecretAccessKey, + testSessionToken, + ) +} + +func setEnvForStsWebIdentityCredentials(t *testing.T) { + t.Helper() + + tokenPath := filepath.Join(t.TempDir(), "token") + assert.NoError(t, os.WriteFile(tokenPath, []byte(testWebIdentityToken), 0600)) + + t.Setenv("AWS_ROLE_ARN", testRoleARN) + t.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", tokenPath) +} + +func assertWebIdentityTokenFile(t *testing.T, path string) { + t.Helper() + + got, err := os.ReadFile(path) + assert.NoError(t, err) + assert.Equals(t, []byte(testWebIdentityToken), got) +} + +type tokens = map[string]struct { + Token string `json:"token"` + ExpirationTimestamp time.Time +} + +func serviceAccountTokens(t *testing.T, tokens tokens) string { + buf, err := json.Marshal(&tokens) + assert.NoError(t, err) + return string(buf) +} + +func serviceAccount(name, namespace string, annotations map[string]string) *v1.ServiceAccount { + return &v1.ServiceAccount{ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + Annotations: annotations, + }} +} diff --git a/pkg/driver/node/credentialprovider/sts_region.go b/pkg/driver/node/credentialprovider/sts_region.go new file mode 100644 index 00000000..fa6fb3a2 --- /dev/null +++ b/pkg/driver/node/credentialprovider/sts_region.go @@ -0,0 +1,84 @@ +package credentialprovider + +import ( + "context" + "errors" + "fmt" + "os" + "sync" + "time" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + "k8s.io/klog/v2" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" +) + +var errUnknownRegion = errors.New("credentialprovider: pod-level: unknown region") + +// stsRegion tries to detect AWS region to use for STS. +// +// It looks for the following (in-order): +// 1. `stsRegion` passed via volume context +// 2. Region set for S3 bucket via mount options +// 3. `AWS_REGION` or `AWS_DEFAULT_REGION` env variables +// 4. Calling IMDS to detect region +// +// It returns an error if all of them fails. +func (p *Provider) stsRegion(provideCtx ProvideContext) (string, error) { + region := provideCtx.StsRegion + if region != "" { + klog.V(5).Infof("credentialprovider: pod-level: Detected STS region %s from volume context", region) + return region, nil + } + + region = provideCtx.BucketRegion + if region != "" { + klog.V(5).Infof("credentialprovider: pod-level: Detected STS region %s from S3 bucket region", region) + return region, nil + } + + region = os.Getenv(envprovider.EnvRegion) + if region != "" { + klog.V(5).Infof("credentialprovider: pod-level: Detected STS region %s from `AWS_REGION` env variable", region) + return region, nil + } + + region = os.Getenv(envprovider.EnvDefaultRegion) + if region != "" { + klog.V(5).Infof("credentialprovider: pod-level: Detected STS region %s from `AWS_DEFAULT_REGION` env variable", region) + return region, nil + } + + // We're ignoring the error here, makes a call to IMDS only once and logs the error in case of error + region, _ = p.regionFromIMDS() + if region != "" { + klog.V(5).Infof("credentialprovider: pod-level: Detected STS region %s from IMDS", region) + return region, nil + } + + return "", errUnknownRegion +} + +// RegionFromIMDSOnce tries to detect AWS region by making a request to IMDS. +// It only makes request to IMDS once and caches the value. +var RegionFromIMDSOnce = sync.OnceValues(func() (string, error) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + cfg, err := config.LoadDefaultConfig(ctx) + if err != nil { + klog.V(5).Infof("credentialprovider: pod-level: Failed to create config for IMDS client: %v", err) + return "", fmt.Errorf("could not create config for imds client: %w", err) + } + + client := imds.NewFromConfig(cfg) + output, err := client.GetRegion(ctx, &imds.GetRegionInput{}) + if err != nil { + klog.V(5).Infof("credentialprovider: pod-level: Failed to get region from IMDS: %v", err) + return "", fmt.Errorf("failed to get region from imds: %w", err) + } + + return output.Region, nil +}) diff --git a/pkg/driver/node/mounter/credential_provider.go b/pkg/driver/node/mounter/credential_provider.go deleted file mode 100644 index 31a19fe1..00000000 --- a/pkg/driver/node/mounter/credential_provider.go +++ /dev/null @@ -1,312 +0,0 @@ -package mounter - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "os" - "path" - "strings" - "sync" - "time" - - "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" - "github.com/google/renameio" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - k8sv1 "k8s.io/client-go/kubernetes/typed/core/v1" - "k8s.io/klog/v2" - k8sstrings "k8s.io/utils/strings" - - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/volumecontext" - "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" -) - -const hostPluginDirEnv = "HOST_PLUGIN_DIR" - -type AuthenticationSource = string - -const ( - // This is when users don't provide a `authenticationSource` option in their volume attributes. - // We're defaulting to `driver` in this case. - AuthenticationSourceUnspecified AuthenticationSource = "" - AuthenticationSourceDriver AuthenticationSource = "driver" - AuthenticationSourcePod AuthenticationSource = "pod" -) - -const ( - // This is to ensure only owner/group can read the file and no one else. - serviceAccountTokenPerm = 0440 -) - -const defaultHostPluginDir = "/var/lib/kubelet/plugins/s3.csi.aws.com/" - -const serviceAccountTokenAudienceSTS = "sts.amazonaws.com" - -const serviceAccountRoleAnnotation = "eks.amazonaws.com/role-arn" - -const podLevelCredentialsDocsPage = "https://github.com/awslabs/mountpoint-s3-csi-driver/blob/main/docs/CONFIGURATION.md#pod-level-credentials" -const stsConfigDocsPage = "https://github.com/awslabs/mountpoint-s3-csi-driver/blob/main/docs/CONFIGURATION.md#configuring-the-sts-region" - -var errUnknownRegion = errors.New("NodePublishVolume: Pod-level: unknown region") - -type Token struct { - Token string `json:"token"` - ExpirationTimestamp time.Time `json:"expirationTimestamp"` -} - -type CredentialProvider struct { - client k8sv1.CoreV1Interface - containerPluginDir string - regionFromIMDS func() (string, error) -} - -func NewCredentialProvider(client k8sv1.CoreV1Interface, containerPluginDir string, regionFromIMDS func() (string, error)) *CredentialProvider { - // `regionFromIMDS` is a `sync.OnceValues` and it only makes request to IMDS once, - // this call is basically here to pre-warm the cache of IMDS call. - go func() { - _, _ = regionFromIMDS() - }() - - return &CredentialProvider{client, containerPluginDir, regionFromIMDS} -} - -// CleanupToken cleans any created service token files for given volume and pod. -func (c *CredentialProvider) CleanupToken(volumeID string, podID string) error { - err := os.Remove(c.tokenPathContainer(podID, volumeID)) - if err != nil && os.IsNotExist(err) { - return nil - } - return err -} - -// Provide provides mount credentials for given volume and volume context. -// Depending on the configuration, it either returns driver-level or pod-level credentials. -func (c *CredentialProvider) Provide(ctx context.Context, volumeID string, volumeCtx map[string]string, args mountpoint.Args) (*MountCredentials, error) { - if volumeCtx == nil { - return nil, status.Error(codes.InvalidArgument, "Missing volume context") - } - - authenticationSource := volumeCtx[volumecontext.AuthenticationSource] - switch authenticationSource { - case AuthenticationSourcePod: - return c.provideFromPod(ctx, volumeID, volumeCtx, args) - case AuthenticationSourceUnspecified, AuthenticationSourceDriver: - return c.provideFromDriver() - default: - return nil, fmt.Errorf("unknown `authenticationSource`: %s, only `driver` (default option if not specified) and `pod` supported", authenticationSource) - } -} - -func (c *CredentialProvider) provideFromDriver() (*MountCredentials, error) { - klog.V(4).Infof("NodePublishVolume: Using driver identity") - - hostPluginDir := hostPluginDirWithDefault() - hostTokenPath := path.Join(hostPluginDir, "token") - - return &MountCredentials{ - AuthenticationSource: AuthenticationSourceDriver, - AccessKeyID: os.Getenv(envprovider.EnvAccessKeyID), - SecretAccessKey: os.Getenv(envprovider.EnvSecretAccessKey), - SessionToken: os.Getenv(envprovider.EnvSessionToken), - Region: os.Getenv(envprovider.EnvRegion), - DefaultRegion: os.Getenv(envprovider.EnvDefaultRegion), - WebTokenPath: hostTokenPath, - StsEndpoints: os.Getenv(envprovider.EnvSTSRegionalEndpoints), - AwsRoleArn: os.Getenv(envprovider.EnvRoleARN), - }, nil -} - -func (c *CredentialProvider) provideFromPod(ctx context.Context, volumeID string, volumeCtx map[string]string, args mountpoint.Args) (*MountCredentials, error) { - klog.V(4).Infof("NodePublishVolume: Using pod identity") - - tokensJson := volumeCtx[volumecontext.CSIServiceAccountTokens] - if tokensJson == "" { - klog.Error("`authenticationSource` configured to `pod` but no service account tokens are received. Please make sure to enable `podInfoOnMountCompat`, see " + podLevelCredentialsDocsPage) - return nil, status.Error(codes.InvalidArgument, "Missing service account tokens") - } - - var tokens map[string]*Token - if err := json.Unmarshal([]byte(tokensJson), &tokens); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "Failed to parse service account tokens: %v", err) - } - - stsToken := tokens[serviceAccountTokenAudienceSTS] - if stsToken == nil { - klog.Errorf("`authenticationSource` configured to `pod` but no service account tokens for %s received. Please make sure to enable `podInfoOnMountCompat`, see "+podLevelCredentialsDocsPage, serviceAccountTokenAudienceSTS) - return nil, status.Errorf(codes.InvalidArgument, "Missing service account token for %s", serviceAccountTokenAudienceSTS) - } - - awsRoleARN, err := c.findPodServiceAccountRole(ctx, volumeCtx) - if err != nil { - return nil, err - } - - region, err := c.stsRegion(volumeCtx, args) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "Failed to detect STS AWS Region, please explicitly set the AWS Region, see "+stsConfigDocsPage) - } - - defaultRegion := os.Getenv(envprovider.EnvDefaultRegion) - if defaultRegion == "" { - defaultRegion = region - } - - podID := volumeCtx[volumecontext.CSIPodUID] - if podID == "" { - return nil, status.Error(codes.InvalidArgument, "Missing Pod info. Please make sure to enable `podInfoOnMountCompat`, see "+podLevelCredentialsDocsPage) - } - - err = c.writeToken(podID, volumeID, stsToken) - if err != nil { - return nil, status.Errorf(codes.Internal, "Failed to write service account token: %v", err) - } - - hostPluginDir := hostPluginDirWithDefault() - hostTokenPath := path.Join(hostPluginDir, c.tokenFilename(podID, volumeID)) - - podNamespace := volumeCtx[volumecontext.CSIPodNamespace] - podServiceAccount := volumeCtx[volumecontext.CSIServiceAccountName] - cacheKey := podNamespace + "/" + podServiceAccount - - return &MountCredentials{ - AuthenticationSource: AuthenticationSourcePod, - - Region: region, - DefaultRegion: defaultRegion, - StsEndpoints: os.Getenv(envprovider.EnvSTSRegionalEndpoints), - WebTokenPath: hostTokenPath, - AwsRoleArn: awsRoleARN, - - // Ensure to disable env credential provider - AccessKeyID: "", - SecretAccessKey: "", - - // Ensure to disable profile provider - ConfigFilePath: path.Join(hostPluginDir, "disable-config"), - SharedCredentialsFilePath: path.Join(hostPluginDir, "disable-credentials"), - - // Ensure to disable IMDS provider - DisableIMDSProvider: true, - - MountpointCacheKey: cacheKey, - }, nil -} - -func (c *CredentialProvider) writeToken(podID string, volumeID string, token *Token) error { - return renameio.WriteFile(c.tokenPathContainer(podID, volumeID), []byte(token.Token), serviceAccountTokenPerm) -} - -func (c *CredentialProvider) tokenPathContainer(podID string, volumeID string) string { - return path.Join(c.containerPluginDir, c.tokenFilename(podID, volumeID)) -} - -func (c *CredentialProvider) tokenFilename(podID string, volumeID string) string { - var filename strings.Builder - // `podID` is a UUID, but escape it to ensure it doesn't contain `/` - filename.WriteString(k8sstrings.EscapeQualifiedName(podID)) - filename.WriteRune('-') - // `volumeID` might contain `/`, we need to escape it - filename.WriteString(k8sstrings.EscapeQualifiedName(volumeID)) - filename.WriteString(".token") - return filename.String() -} - -func (c *CredentialProvider) findPodServiceAccountRole(ctx context.Context, volumeCtx map[string]string) (string, error) { - podNamespace := volumeCtx[volumecontext.CSIPodNamespace] - podServiceAccount := volumeCtx[volumecontext.CSIServiceAccountName] - if podNamespace == "" || podServiceAccount == "" { - klog.Error("`authenticationSource` configured to `pod` but no pod info found. Please make sure to enable `podInfoOnMountCompat`, see " + podLevelCredentialsDocsPage) - return "", status.Error(codes.InvalidArgument, "Missing Pod info. Please make sure to enable `podInfoOnMountCompat`, see "+podLevelCredentialsDocsPage) - } - - response, err := c.client.ServiceAccounts(podNamespace).Get(ctx, podServiceAccount, metav1.GetOptions{}) - if err != nil { - return "", status.Errorf(codes.InvalidArgument, "Failed to get pod's service account %s/%s: %v", podNamespace, podServiceAccount, err) - } - - roleArn := response.Annotations[serviceAccountRoleAnnotation] - if roleArn == "" { - klog.Error("`authenticationSource` configured to `pod` but pod's service account is not annotated with a role, see " + podLevelCredentialsDocsPage) - return "", status.Errorf(codes.InvalidArgument, "Missing role annotation on pod's service account %s/%s", podNamespace, podServiceAccount) - } - - return roleArn, nil -} - -// stsRegion tries to detect AWS region to use for STS. -// -// It looks for the following (in-order): -// 1. `stsRegion` passed via volume context -// 2. Region set for S3 bucket via mount options -// 3. `AWS_REGION` or `AWS_DEFAULT_REGION` env variables -// 4. Calling IMDS to detect region -// -// It returns an error if all of them fails. -func (c *CredentialProvider) stsRegion(volumeCtx map[string]string, args mountpoint.Args) (string, error) { - region := volumeCtx[volumecontext.STSRegion] - if region != "" { - klog.V(5).Infof("NodePublishVolume: Pod-level: Detected STS region %s from volume context", region) - return region, nil - } - - if region, ok := args.Value(mountpoint.ArgRegion); ok { - klog.V(5).Infof("NodePublishVolume: Pod-level: Detected STS region %s from S3 bucket region", region) - return region, nil - } - - region = os.Getenv(envprovider.EnvRegion) - if region != "" { - klog.V(5).Infof("NodePublishVolume: Pod-level: Detected STS region %s from `AWS_REGION` env variable", region) - return region, nil - } - - region = os.Getenv(envprovider.EnvDefaultRegion) - if region != "" { - klog.V(5).Infof("NodePublishVolume: Pod-level: Detected STS region %s from `AWS_DEFAULT_REGION` env variable", region) - return region, nil - } - - // We're ignoring the error here, makes a call to IMDS only once and logs the error in case of error - region, _ = c.regionFromIMDS() - if region != "" { - klog.V(5).Infof("NodePublishVolume: Pod-level: Detected STS region %s from IMDS", region) - return region, nil - } - - return "", errUnknownRegion -} - -func hostPluginDirWithDefault() string { - hostPluginDir := os.Getenv(hostPluginDirEnv) - if hostPluginDir == "" { - hostPluginDir = defaultHostPluginDir - } - return hostPluginDir -} - -// RegionFromIMDSOnce tries to detect AWS region by making a request to IMDS. -// It only makes request to IMDS once and caches the value. -var RegionFromIMDSOnce = sync.OnceValues(func() (string, error) { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - cfg, err := config.LoadDefaultConfig(ctx) - if err != nil { - klog.V(5).Infof("NodePublishVolume: Pod-level: Failed to create config for IMDS client: %v", err) - return "", fmt.Errorf("could not create config for imds client: %w", err) - } - - client := imds.NewFromConfig(cfg) - output, err := client.GetRegion(ctx, &imds.GetRegionInput{}) - if err != nil { - klog.V(5).Infof("NodePublishVolume: Pod-level: Failed to get region from IMDS: %v", err) - return "", fmt.Errorf("failed to get region from imds: %w", err) - } - - return output.Region, nil -}) diff --git a/pkg/driver/node/mounter/credential_provider_test.go b/pkg/driver/node/mounter/credential_provider_test.go deleted file mode 100644 index ac8c23e2..00000000 --- a/pkg/driver/node/mounter/credential_provider_test.go +++ /dev/null @@ -1,599 +0,0 @@ -package mounter_test - -import ( - "context" - "encoding/json" - "errors" - "os" - "path" - "testing" - "time" - - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" - "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" - - v1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/client-go/kubernetes/fake" -) - -func TestProvidingDriverLevelCredentials(t *testing.T) { - t.Setenv("AWS_ACCESS_KEY_ID", "test-access-key") - t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key") - t.Setenv("AWS_SESSION_TOKEN", "test-session-token") - t.Setenv("AWS_REGION", "eu-west-1") - t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") - t.Setenv("HOST_PLUGIN_DIR", "/test/csi/plugin/dir") - t.Setenv("AWS_STS_REGIONAL_ENDPOINTS", "regional") - t.Setenv("AWS_ROLE_ARN", "arn:aws:iam::123456789012:role/Test") - - for _, test := range []struct { - volumeID string - volumeContext map[string]string - }{ - { - volumeID: "test-vol-id", - volumeContext: map[string]string{"authenticationSource": "driver"}, - }, - { - volumeID: "test-vol-id", - // It should default to `driver` if `authenticationSource` is not explicitly set - volumeContext: map[string]string{}, - }, - } { - - provider := mounter.NewCredentialProvider(nil, "", mounter.RegionFromIMDSOnce) - credentials, err := provider.Provide(context.Background(), test.volumeID, test.volumeContext, mountpoint.ParseArgs(nil)) - assertEquals(t, nil, err) - - assertEquals(t, credentials.AccessKeyID, "test-access-key") - assertEquals(t, credentials.SecretAccessKey, "test-secret-key") - assertEquals(t, credentials.SessionToken, "test-session-token") - assertEquals(t, credentials.Region, "eu-west-1") - assertEquals(t, credentials.DefaultRegion, "eu-north-1") - assertEquals(t, credentials.WebTokenPath, "/test/csi/plugin/dir/token") - assertEquals(t, credentials.StsEndpoints, "regional") - assertEquals(t, credentials.AwsRoleArn, "arn:aws:iam::123456789012:role/Test") - } -} - -func TestProvidingDriverLevelCredentialsWithEmptyEnv(t *testing.T) { - provider := mounter.NewCredentialProvider(nil, "", mounter.RegionFromIMDSOnce) - credentials, err := provider.Provide(context.Background(), "test-vol-id", map[string]string{"authenticationSource": "driver"}, mountpoint.ParseArgs(nil)) - assertEquals(t, nil, err) - - assertEquals(t, credentials.AccessKeyID, "") - assertEquals(t, credentials.SecretAccessKey, "") - assertEquals(t, credentials.SessionToken, "") - assertEquals(t, credentials.Region, "") - assertEquals(t, credentials.DefaultRegion, "") - assertEquals(t, credentials.WebTokenPath, "/var/lib/kubelet/plugins/s3.csi.aws.com/token") - assertEquals(t, credentials.StsEndpoints, "") - assertEquals(t, credentials.AwsRoleArn, "") -} - -func TestProvidingPodLevelCredentials(t *testing.T) { - pluginDir := t.TempDir() - clientset := fake.NewSimpleClientset(serviceAccount("test-sa", "test-ns", map[string]string{ - "eks.amazonaws.com/role-arn": "arn:aws:iam::123456789012:role/Test", - })) - t.Setenv("AWS_REGION", "eu-west-1") - t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") - t.Setenv("HOST_PLUGIN_DIR", "/test/csi/plugin/dir") - t.Setenv("AWS_STS_REGIONAL_ENDPOINTS", "regional") - - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, mounter.RegionFromIMDSOnce) - - credentials, err := provider.Provide(context.Background(), "test-vol-id", map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.uid": "test-pod", - "csi.storage.k8s.io/pod.namespace": "test-ns", - "csi.storage.k8s.io/serviceAccount.name": "test-sa", - "csi.storage.k8s.io/serviceAccount.tokens": serviceAccountTokens(t, tokens{ - "sts.amazonaws.com": { - Token: "test-service-account-token", - }, - }), - }, mountpoint.ParseArgs(nil)) - assertEquals(t, nil, err) - - // Should disable env variable provider - assertEquals(t, credentials.AccessKeyID, "") - assertEquals(t, credentials.SecretAccessKey, "") - assertEquals(t, credentials.SessionToken, "") - - // Should disable profile provider - assertEquals(t, credentials.ConfigFilePath, "/test/csi/plugin/dir/disable-config") - assertEquals(t, credentials.SharedCredentialsFilePath, "/test/csi/plugin/dir/disable-credentials") - - // Should disable IMDS provider - assertEquals(t, credentials.DisableIMDSProvider, true) - - // Should populate env variables for STS Web Identity provider - assertEquals(t, credentials.WebTokenPath, "/test/csi/plugin/dir/test-pod-test-vol-id.token") - assertEquals(t, credentials.AwsRoleArn, "arn:aws:iam::123456789012:role/Test") - - assertEquals(t, credentials.Region, "eu-west-1") - assertEquals(t, credentials.DefaultRegion, "eu-north-1") - assertEquals(t, credentials.StsEndpoints, "regional") - - assertEquals(t, credentials.MountpointCacheKey, "test-ns/test-sa") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) -} - -func TestProvidingPodLevelCredentialsWithMissingInformation(t *testing.T) { - pluginDir := t.TempDir() - clientset := fake.NewSimpleClientset( - serviceAccount("test-sa", "test-ns", map[string]string{ - "eks.amazonaws.com/role-arn": "arn:aws:iam::123456789012:role/Test", - }), - serviceAccount("test-sa-missing-role", "test-ns", map[string]string{}), - ) - - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, mounter.RegionFromIMDSOnce) - - for name, test := range map[string]struct { - volumeID string - volumeContext map[string]string - }{ - "unknown service account": { - volumeID: "test-vol-id", - volumeContext: map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.uid": "test-pod", - "csi.storage.k8s.io/pod.namespace": "test-ns", - "csi.storage.k8s.io/serviceAccount.name": "test-unknown-sa", - "csi.storage.k8s.io/serviceAccount.tokens": serviceAccountTokens(t, tokens{ - "sts.amazonaws.com": { - Token: "test-service-account-token", - }, - }), - }, - }, - "missing service account token": { - volumeID: "test-vol-id", - volumeContext: map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.uid": "test-pod", - "csi.storage.k8s.io/pod.namespace": "test-ns", - "csi.storage.k8s.io/serviceAccount.name": "test-sa", - }, - }, - "missing sts audience in service account tokens": { - volumeID: "test-vol-id", - volumeContext: map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.uid": "test-pod", - "csi.storage.k8s.io/pod.namespace": "test-ns", - "csi.storage.k8s.io/serviceAccount.name": "test-sa", - "csi.storage.k8s.io/serviceAccount.tokens": serviceAccountTokens(t, tokens{ - "unknown": { - Token: "test-service-account-token", - }, - }), - }, - }, - "missing service account name": { - volumeID: "test-vol-id", - volumeContext: map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.uid": "test-pod", - "csi.storage.k8s.io/pod.namespace": "test-ns", - "csi.storage.k8s.io/serviceAccount.tokens": serviceAccountTokens(t, tokens{ - "sts.amazonaws.com": { - Token: "test-service-account-token", - }, - }), - }, - }, - "missing pod namespace": { - volumeID: "test-vol-id", - volumeContext: map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.uid": "test-pod", - "csi.storage.k8s.io/serviceAccount.name": "test-sa", - "csi.storage.k8s.io/serviceAccount.tokens": serviceAccountTokens(t, tokens{ - "sts.amazonaws.com": { - Token: "test-service-account-token", - }, - }), - }, - }, - "missing pod id": { - volumeID: "test-vol-id", - volumeContext: map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.namespace": "test-ns", - "csi.storage.k8s.io/serviceAccount.name": "test-sa", - "csi.storage.k8s.io/serviceAccount.tokens": serviceAccountTokens(t, tokens{ - "sts.amazonaws.com": { - Token: "test-service-account-token", - }, - }), - }, - }, - } { - t.Run(name, func(t *testing.T) { - credentials, err := provider.Provide(context.Background(), test.volumeID, test.volumeContext, mountpoint.ParseArgs(nil)) - assertEquals(t, nil, credentials) - if err == nil { - t.Error("it should fail with missing information") - } - - _, err = os.ReadFile(path.Join(pluginDir, "test-pod-test-vol-id.token")) - assertEquals(t, true, os.IsNotExist(err)) - }) - } -} - -func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { - clientset := fake.NewSimpleClientset(serviceAccount("test-sa", "test-ns", map[string]string{ - "eks.amazonaws.com/role-arn": "arn:aws:iam::123456789012:role/Test", - })) - - volumeID := "test-vol-id" - volumeContext := map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.uid": "test-pod", - "csi.storage.k8s.io/pod.namespace": "test-ns", - "csi.storage.k8s.io/serviceAccount.name": "test-sa", - "csi.storage.k8s.io/serviceAccount.tokens": serviceAccountTokens(t, tokens{ - "sts.amazonaws.com": { - Token: "test-service-account-token", - }, - }), - } - - t.Run("no region", func(t *testing.T) { - pluginDir := t.TempDir() - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, func() (string, error) { - return "", errors.New("unknown region") - }) - - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs(nil)) - assertEquals(t, nil, credentials) - if err == nil { - t.Error("it should fail if there is not any region information") - } - - _, err = os.ReadFile(path.Join(pluginDir, "test-pod-test-vol-id.token")) - assertEquals(t, true, os.IsNotExist(err)) - }) - - t.Run("region from imds", func(t *testing.T) { - pluginDir := t.TempDir() - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, func() (string, error) { - return "us-east-1", nil - }) - - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs(nil)) - assertEquals(t, nil, err) - assertEquals(t, credentials.Region, "us-east-1") - assertEquals(t, credentials.DefaultRegion, "us-east-1") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) - }) - - t.Run("region from env", func(t *testing.T) { - pluginDir := t.TempDir() - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, func() (string, error) { - return "us-east-1", nil - }) - - t.Setenv("AWS_REGION", "eu-west-1") - - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs(nil)) - assertEquals(t, nil, err) - assertEquals(t, credentials.Region, "eu-west-1") - assertEquals(t, credentials.DefaultRegion, "eu-west-1") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) - }) - - t.Run("default region from env", func(t *testing.T) { - pluginDir := t.TempDir() - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, func() (string, error) { - return "us-east-1", nil - }) - - t.Setenv("AWS_DEFAULT_REGION", "eu-west-1") - - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs(nil)) - assertEquals(t, nil, err) - assertEquals(t, credentials.Region, "eu-west-1") - assertEquals(t, credentials.DefaultRegion, "eu-west-1") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) - }) - - t.Run("default and regular region from env", func(t *testing.T) { - pluginDir := t.TempDir() - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, func() (string, error) { - return "us-east-1", nil - }) - - t.Setenv("AWS_REGION", "eu-west-1") - t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") - - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs(nil)) - assertEquals(t, nil, err) - assertEquals(t, credentials.Region, "eu-west-1") - assertEquals(t, credentials.DefaultRegion, "eu-north-1") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) - }) - - t.Run("region from mountpoint options", func(t *testing.T) { - pluginDir := t.TempDir() - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, func() (string, error) { - return "us-east-1", nil - }) - - t.Setenv("AWS_REGION", "eu-west-1") - - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs([]string{"--region=us-west-1"})) - assertEquals(t, nil, err) - assertEquals(t, credentials.Region, "us-west-1") - assertEquals(t, credentials.DefaultRegion, "us-west-1") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) - }) - - t.Run("missing region from mountpoint options", func(t *testing.T) { - pluginDir := t.TempDir() - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, func() (string, error) { - return "us-east-1", nil - }) - - t.Setenv("AWS_REGION", "eu-west-1") - - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs([]string{"--read-only"})) - assertEquals(t, nil, err) - assertEquals(t, credentials.Region, "eu-west-1") - assertEquals(t, credentials.DefaultRegion, "eu-west-1") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) - }) - - t.Run("region from mountpoint options with default region from env", func(t *testing.T) { - pluginDir := t.TempDir() - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, func() (string, error) { - return "us-east-1", nil - }) - - t.Setenv("AWS_REGION", "eu-west-1") - t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") - - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs([]string{"--region=us-west-1"})) - assertEquals(t, nil, err) - assertEquals(t, credentials.Region, "us-west-1") - assertEquals(t, credentials.DefaultRegion, "eu-north-1") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) - }) - - t.Run("region from volume context", func(t *testing.T) { - pluginDir := t.TempDir() - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, func() (string, error) { - return "us-east-1", nil - }) - - t.Setenv("AWS_REGION", "eu-west-1") - - volumeContext["stsRegion"] = "ap-south-1" - - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs([]string{"--region=us-west-1"})) - assertEquals(t, nil, err) - assertEquals(t, credentials.Region, "ap-south-1") - assertEquals(t, credentials.DefaultRegion, "ap-south-1") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) - }) - - t.Run("region from volume context with default region from env", func(t *testing.T) { - pluginDir := t.TempDir() - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, func() (string, error) { - return "us-east-1", nil - }) - - t.Setenv("AWS_REGION", "eu-west-1") - t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") - - volumeContext["stsRegion"] = "ap-south-1" - - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs([]string{"--region=us-west-1"})) - assertEquals(t, nil, err) - assertEquals(t, credentials.Region, "ap-south-1") - assertEquals(t, credentials.DefaultRegion, "eu-north-1") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) - }) -} - -func TestProvidingPodLevelCredentialsForDifferentPodsWithDifferentRoles(t *testing.T) { - pluginDir := t.TempDir() - clientset := fake.NewSimpleClientset( - serviceAccount("test-sa-1", "test-ns", map[string]string{ - "eks.amazonaws.com/role-arn": "arn:aws:iam::123456789012:role/Test1", - }), - serviceAccount("test-sa-2", "test-ns", map[string]string{ - "eks.amazonaws.com/role-arn": "arn:aws:iam::123456789012:role/Test2", - }), - ) - t.Setenv("AWS_REGION", "eu-west-1") - t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") - t.Setenv("HOST_PLUGIN_DIR", "/test/csi/plugin/dir") - t.Setenv("AWS_STS_REGIONAL_ENDPOINTS", "regional") - - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, mounter.RegionFromIMDSOnce) - - credentialsPodOne, err := provider.Provide(context.Background(), "test-vol-id", map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.uid": "test-pod-1", - "csi.storage.k8s.io/pod.namespace": "test-ns", - "csi.storage.k8s.io/serviceAccount.name": "test-sa-1", - "csi.storage.k8s.io/serviceAccount.tokens": serviceAccountTokens(t, tokens{ - "sts.amazonaws.com": { - Token: "test-service-account-token-1", - }, - }), - }, mountpoint.ParseArgs(nil)) - assertEquals(t, nil, err) - - credentialsPodTwo, err := provider.Provide(context.Background(), "test-vol-id", map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.uid": "test-pod-2", - "csi.storage.k8s.io/pod.namespace": "test-ns", - "csi.storage.k8s.io/serviceAccount.name": "test-sa-2", - "csi.storage.k8s.io/serviceAccount.tokens": serviceAccountTokens(t, tokens{ - "sts.amazonaws.com": { - Token: "test-service-account-token-2", - }, - }), - }, mountpoint.ParseArgs(nil)) - assertEquals(t, nil, err) - - // PodOne - assertEquals(t, credentialsPodOne.AccessKeyID, "") - assertEquals(t, credentialsPodOne.SecretAccessKey, "") - assertEquals(t, credentialsPodOne.SessionToken, "") - assertEquals(t, credentialsPodOne.Region, "eu-west-1") - assertEquals(t, credentialsPodOne.DefaultRegion, "eu-north-1") - assertEquals(t, credentialsPodOne.WebTokenPath, "/test/csi/plugin/dir/test-pod-1-test-vol-id.token") - assertEquals(t, credentialsPodOne.StsEndpoints, "regional") - assertEquals(t, credentialsPodOne.AwsRoleArn, "arn:aws:iam::123456789012:role/Test1") - assertEquals(t, credentialsPodOne.MountpointCacheKey, "test-ns/test-sa-1") - - token, err := os.ReadFile(tokenFilePath(credentialsPodOne, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token-1", string(token)) - - // PodTwo - assertEquals(t, credentialsPodTwo.AccessKeyID, "") - assertEquals(t, credentialsPodTwo.SecretAccessKey, "") - assertEquals(t, credentialsPodTwo.SessionToken, "") - assertEquals(t, credentialsPodTwo.Region, "eu-west-1") - assertEquals(t, credentialsPodTwo.DefaultRegion, "eu-north-1") - assertEquals(t, credentialsPodTwo.WebTokenPath, "/test/csi/plugin/dir/test-pod-2-test-vol-id.token") - assertEquals(t, credentialsPodTwo.StsEndpoints, "regional") - assertEquals(t, credentialsPodTwo.AwsRoleArn, "arn:aws:iam::123456789012:role/Test2") - assertEquals(t, credentialsPodTwo.MountpointCacheKey, "test-ns/test-sa-2") - - token, err = os.ReadFile(tokenFilePath(credentialsPodTwo, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token-2", string(token)) -} - -func TestProvidingPodLevelCredentialsWithSlashInVolumeID(t *testing.T) { - pluginDir := t.TempDir() - clientset := fake.NewSimpleClientset(serviceAccount("test-sa", "test-ns", map[string]string{ - "eks.amazonaws.com/role-arn": "arn:aws:iam::123456789012:role/Test", - })) - t.Setenv("AWS_REGION", "eu-west-1") - t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") - t.Setenv("HOST_PLUGIN_DIR", "/test/csi/plugin/dir") - t.Setenv("AWS_STS_REGIONAL_ENDPOINTS", "regional") - - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, mounter.RegionFromIMDSOnce) - - credentials, err := provider.Provide(context.Background(), "test-vol-id/1", map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.uid": "test-pod", - "csi.storage.k8s.io/pod.namespace": "test-ns", - "csi.storage.k8s.io/serviceAccount.name": "test-sa", - "csi.storage.k8s.io/serviceAccount.tokens": serviceAccountTokens(t, tokens{ - "sts.amazonaws.com": { - Token: "test-service-account-token", - }, - }), - }, mountpoint.ParseArgs(nil)) - assertEquals(t, nil, err) - - assertEquals(t, credentials.AccessKeyID, "") - assertEquals(t, credentials.SecretAccessKey, "") - assertEquals(t, credentials.SessionToken, "") - assertEquals(t, credentials.Region, "eu-west-1") - assertEquals(t, credentials.DefaultRegion, "eu-north-1") - assertEquals(t, credentials.WebTokenPath, "/test/csi/plugin/dir/test-pod-test-vol-id~1.token") - assertEquals(t, credentials.StsEndpoints, "regional") - assertEquals(t, credentials.AwsRoleArn, "arn:aws:iam::123456789012:role/Test") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) -} - -func TestCleaningUpTokenFileForAVolume(t *testing.T) { - t.Run("existing token", func(t *testing.T) { - pluginDir := t.TempDir() - volumeID := "test-vol-id" - podID := "test-pod-id" - tokenPath := path.Join(pluginDir, podID+"-"+volumeID+".token") - err := os.WriteFile(tokenPath, []byte("test-service-account-token"), 0400) - assertEquals(t, nil, err) - - provider := mounter.NewCredentialProvider(nil, pluginDir, mounter.RegionFromIMDSOnce) - err = provider.CleanupToken(volumeID, podID) - assertEquals(t, nil, err) - - _, err = os.ReadFile(tokenPath) - assertEquals(t, true, os.IsNotExist(err)) - }) - - t.Run("non-existing token", func(t *testing.T) { - provider := mounter.NewCredentialProvider(nil, t.TempDir(), mounter.RegionFromIMDSOnce) - - err := provider.CleanupToken("non-existing-vol-id", "non-existing-pod-id") - assertEquals(t, nil, err) - }) -} - -type tokens = map[string]struct { - Token string `json:"token"` - ExpirationTimestamp time.Time -} - -func serviceAccountTokens(t *testing.T, tokens tokens) string { - buf, err := json.Marshal(&tokens) - assertEquals(t, nil, err) - return string(buf) -} - -func serviceAccount(name, namespace string, annotations map[string]string) *v1.ServiceAccount { - return &v1.ServiceAccount{ObjectMeta: metav1.ObjectMeta{ - Name: name, - Namespace: namespace, - Annotations: annotations, - }} -} - -func tokenFilePath(credentials *mounter.MountCredentials, pluginDir string) string { - return path.Join(pluginDir, path.Base(credentials.WebTokenPath)) -} - -func assertEquals[T comparable](t *testing.T, expected T, got T) { - if expected != got { - t.Errorf("Expected %#v, Got %#v", expected, got) - } -} diff --git a/pkg/driver/node/mounter/fake_mounter.go b/pkg/driver/node/mounter/fake_mounter.go index c97c40a0..ab0119a0 100644 --- a/pkg/driver/node/mounter/fake_mounter.go +++ b/pkg/driver/node/mounter/fake_mounter.go @@ -1,15 +1,20 @@ package mounter -import "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" +import ( + "context" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" +) type FakeMounter struct{} -func (m *FakeMounter) Mount(bucketName string, target string, - credentials *MountCredentials, args mountpoint.Args) error { +func (m *FakeMounter) Mount(ctx context.Context, bucketName string, target string, + credentialCtx credentialprovider.ProvideContext, args mountpoint.Args) error { return nil } -func (m *FakeMounter) Unmount(target string) error { +func (m *FakeMounter) Unmount(target string, credentialCtx credentialprovider.CleanupContext) error { return nil } diff --git a/pkg/driver/node/mounter/mocks/mock_mount.go b/pkg/driver/node/mounter/mocks/mock_mount.go index fdfe8880..1e4c93fb 100644 --- a/pkg/driver/node/mounter/mocks/mock_mount.go +++ b/pkg/driver/node/mounter/mocks/mock_mount.go @@ -8,7 +8,7 @@ import ( context "context" reflect "reflect" - mounter "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" + credentialprovider "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" mountpoint "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" system "github.com/awslabs/aws-s3-csi-driver/pkg/system" gomock "github.com/golang/mock/gomock" @@ -106,29 +106,29 @@ func (mr *MockMounterMockRecorder) IsMountPoint(target interface{}) *gomock.Call } // Mount mocks base method. -func (m *MockMounter) Mount(bucketName, target string, credentials *mounter.MountCredentials, args mountpoint.Args) error { +func (m *MockMounter) Mount(ctx context.Context, bucketName, target string, credentialCtx credentialprovider.ProvideContext, args mountpoint.Args) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Mount", bucketName, target, credentials, args) + ret := m.ctrl.Call(m, "Mount", ctx, bucketName, target, credentialCtx, args) ret0, _ := ret[0].(error) return ret0 } // Mount indicates an expected call of Mount. -func (mr *MockMounterMockRecorder) Mount(bucketName, target, credentials, args interface{}) *gomock.Call { +func (mr *MockMounterMockRecorder) Mount(ctx, bucketName, target, credentialCtx, args interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Mount", reflect.TypeOf((*MockMounter)(nil).Mount), bucketName, target, credentials, args) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Mount", reflect.TypeOf((*MockMounter)(nil).Mount), ctx, bucketName, target, credentialCtx, args) } // Unmount mocks base method. -func (m *MockMounter) Unmount(target string) error { +func (m *MockMounter) Unmount(target string, credentialCtx credentialprovider.CleanupContext) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Unmount", target) + ret := m.ctrl.Call(m, "Unmount", target, credentialCtx) ret0, _ := ret[0].(error) return ret0 } // Unmount indicates an expected call of Unmount. -func (mr *MockMounterMockRecorder) Unmount(target interface{}) *gomock.Call { +func (mr *MockMounterMockRecorder) Unmount(target, credentialCtx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unmount", reflect.TypeOf((*MockMounter)(nil).Unmount), target) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unmount", reflect.TypeOf((*MockMounter)(nil).Unmount), target, credentialCtx) } diff --git a/pkg/driver/node/mounter/mount_credentials.go b/pkg/driver/node/mounter/mount_credentials.go deleted file mode 100644 index 450d9bd5..00000000 --- a/pkg/driver/node/mounter/mount_credentials.go +++ /dev/null @@ -1,83 +0,0 @@ -package mounter - -import ( - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/awsprofile" - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" -) - -type MountCredentials struct { - // Identifies how these credentials are obtained. - AuthenticationSource AuthenticationSource - - // -- Env variable provider - AccessKeyID string - SecretAccessKey string - SessionToken string - - // -- Profile provider - ConfigFilePath string - SharedCredentialsFilePath string - - // -- STS provider - WebTokenPath string - AwsRoleArn string - - // -- IMDS provider - DisableIMDSProvider bool - - // -- Generic - Region string - DefaultRegion string - StsEndpoints string - - // -- TODO - Move somewhere better - MountpointCacheKey string -} - -// Get environment variables to pass to mount-s3 for authentication. -func (mc *MountCredentials) Env(awsProfile awsprofile.AWSProfile) envprovider.Environment { - env := envprovider.Environment{} - - // For profile provider from long-term credentials - if awsProfile.Name != "" { - env.Set(envprovider.EnvProfile, awsProfile.Name) - env.Set(envprovider.EnvConfigFile, awsProfile.ConfigPath) - env.Set(envprovider.EnvSharedCredentialsFile, awsProfile.CredentialsPath) - } else { - // For profile provider - if mc.ConfigFilePath != "" { - env.Set(envprovider.EnvConfigFile, mc.ConfigFilePath) - } - if mc.SharedCredentialsFilePath != "" { - env.Set(envprovider.EnvSharedCredentialsFile, mc.SharedCredentialsFilePath) - } - } - - // For STS Web Identity provider - if mc.WebTokenPath != "" { - env.Set(envprovider.EnvWebIdentityTokenFile, mc.WebTokenPath) - env.Set(envprovider.EnvRoleARN, mc.AwsRoleArn) - } - - // For disabling IMDS provider - if mc.DisableIMDSProvider { - env.Set(envprovider.EnvEC2MetadataDisabled, "true") - } - - // Generic variables - if mc.Region != "" { - env.Set(envprovider.EnvRegion, mc.Region) - } - if mc.DefaultRegion != "" { - env.Set(envprovider.EnvDefaultRegion, mc.DefaultRegion) - } - if mc.StsEndpoints != "" { - env.Set(envprovider.EnvSTSRegionalEndpoints, mc.StsEndpoints) - } - - if mc.MountpointCacheKey != "" { - env.Set(envprovider.EnvMountpointCacheKey, mc.MountpointCacheKey) - } - - return env -} diff --git a/pkg/driver/node/mounter/mounter.go b/pkg/driver/node/mounter/mounter.go index d63f6a77..8942b957 100644 --- a/pkg/driver/node/mounter/mounter.go +++ b/pkg/driver/node/mounter/mounter.go @@ -5,6 +5,7 @@ import ( "context" "os" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" "github.com/awslabs/aws-s3-csi-driver/pkg/system" ) @@ -16,8 +17,8 @@ type ServiceRunner interface { // Mounter is an interface for mount operations type Mounter interface { - Mount(bucketName string, target string, credentials *MountCredentials, args mountpoint.Args) error - Unmount(target string) error + Mount(ctx context.Context, bucketName string, target string, credentialCtx credentialprovider.ProvideContext, args mountpoint.Args) error + Unmount(target string, credentialCtx credentialprovider.CleanupContext) error IsMountPoint(target string) (bool, error) } diff --git a/pkg/driver/node/mounter/systemd_mounter.go b/pkg/driver/node/mounter/systemd_mounter.go index 6c60c863..e884bece 100644 --- a/pkg/driver/node/mounter/systemd_mounter.go +++ b/pkg/driver/node/mounter/systemd_mounter.go @@ -4,16 +4,16 @@ import ( "context" "fmt" "os" - "path/filepath" "time" - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/awsprofile" - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" - "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" - "github.com/awslabs/aws-s3-csi-driver/pkg/system" "github.com/google/uuid" "k8s.io/klog/v2" "k8s.io/mount-utils" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" + "github.com/awslabs/aws-s3-csi-driver/pkg/system" ) // https://github.com/awslabs/mountpoint-s3/blob/9ed8b6243f4511e2013b2f4303a9197c3ddd4071/mountpoint-s3/src/cli.rs#L421 @@ -26,9 +26,10 @@ type SystemdMounter struct { MpVersion string MountS3Path string kubernetesVersion string + credProvider *credentialprovider.Provider } -func NewSystemdMounter(mpVersion string, kubernetesVersion string) (*SystemdMounter, error) { +func NewSystemdMounter(credProvider *credentialprovider.Provider, mpVersion string, kubernetesVersion string) (*SystemdMounter, error) { ctx := context.Background() runner, err := system.StartOsSystemdSupervisor() if err != nil { @@ -41,6 +42,7 @@ func NewSystemdMounter(mpVersion string, kubernetesVersion string) (*SystemdMoun MpVersion: mpVersion, MountS3Path: MountS3Path(), kubernetesVersion: kubernetesVersion, + credProvider: credProvider, }, nil } @@ -80,7 +82,7 @@ func (m *SystemdMounter) IsMountPoint(target string) (bool, error) { // // This method will create the target path if it does not exist and if there is an existing corrupt // mount, it will attempt an unmount before attempting the mount. -func (m *SystemdMounter) Mount(bucketName string, target string, credentials *MountCredentials, args mountpoint.Args) error { +func (m *SystemdMounter) Mount(ctx context.Context, bucketName string, target string, credentialCtx credentialprovider.ProvideContext, args mountpoint.Args) error { if bucketName == "" { return fmt.Errorf("bucket name is empty") } @@ -90,6 +92,8 @@ func (m *SystemdMounter) Mount(bucketName string, target string, credentials *Mo timeoutCtx, cancel := context.WithTimeout(m.Ctx, 30*time.Second) defer cancel() + credentialCtx.WritePath, credentialCtx.EnvPath = m.credentialWriteAndEnvPath() + cleanupDir := false // check if the target path exists @@ -112,7 +116,11 @@ func (m *SystemdMounter) Mount(bucketName string, target string, credentials *Mo // Corrupted mount, try unmounting if mount.IsCorruptedMnt(statErr) { klog.V(4).Infof("Mount: Target path %q is a corrupted mount. Trying to unmount.", target) - if mntErr := m.Unmount(target); mntErr != nil { + if mntErr := m.Unmount(target, credentialprovider.CleanupContext{ + WritePath: credentialCtx.WritePath, + PodID: credentialCtx.PodID, + VolumeID: credentialCtx.VolumeID, + }); mntErr != nil { return fmt.Errorf("Unable to unmount the target %q : %v, %v", target, statErr, mntErr) } } @@ -123,30 +131,19 @@ func (m *SystemdMounter) Mount(bucketName string, target string, credentials *Mo return fmt.Errorf("Could not check if %q is a mount point: %v, %v", target, statErr, err) } - if isMountPoint { - klog.V(4).Infof("NodePublishVolume: Target path %q is already mounted", target) - return nil - } - env := envprovider.Default() - var authenticationSource AuthenticationSource - if credentials != nil { - var awsProfile awsprofile.AWSProfile - if credentials.AccessKeyID != "" && credentials.SecretAccessKey != "" { - // Kubernetes creates target path in the form of "/var/lib/kubelet/pods//volumes/kubernetes.io~csi//mount". - // So the directory of the target path is unique for this mount, and we can use it to write credentials and config files. - // These files will be cleaned up in `Unmount`. - basepath := filepath.Dir(target) - awsProfile, err = awsprofile.CreateAWSProfile(basepath, credentials.AccessKeyID, credentials.SecretAccessKey, credentials.SessionToken) - if err != nil { - klog.V(4).Infof("Mount: Failed to create AWS Profile in %s: %v", basepath, err) - return fmt.Errorf("Mount: Failed to create AWS Profile in %s: %v", basepath, err) - } - } - authenticationSource = credentials.AuthenticationSource + credEnv, authenticationSource, err := m.credProvider.Provide(ctx, credentialCtx) + if err != nil { + klog.V(4).Infof("NodePublishVolume: Failed to provide credentials for %s: %v", target, err) + return err + } + + env.Merge(credEnv) - env = credentials.Env(awsProfile) + if isMountPoint { + klog.V(4).Infof("NodePublishVolume: Target path %q is already mounted", target) + return nil } // Move `--aws-max-attempts` to env if provided @@ -174,15 +171,11 @@ func (m *SystemdMounter) Mount(bucketName string, target string, credentials *Mo return nil } -func (m *SystemdMounter) Unmount(target string) error { +func (m *SystemdMounter) Unmount(target string, credentialCtx credentialprovider.CleanupContext) error { timeoutCtx, cancel := context.WithTimeout(m.Ctx, 30*time.Second) defer cancel() - basepath := filepath.Dir(target) - err := awsprofile.CleanupAWSProfile(basepath) - if err != nil { - klog.V(4).Infof("Unmount: Failed to clean up AWS Profile in %s: %v", basepath, err) - } + credentialCtx.WritePath, _ = m.credentialWriteAndEnvPath() output, err := m.Runner.RunOneshot(timeoutCtx, &system.ExecConfig{ Name: "mount-s3-umount-" + uuid.New().String() + ".service", @@ -196,5 +189,27 @@ func (m *SystemdMounter) Unmount(target string) error { if output != "" { klog.V(5).Infof("umount output: %s", output) } + + err = m.credProvider.Cleanup(credentialCtx) + if err != nil { + klog.V(4).Infof("Unmount: Failed to clean up credentials for %s: %v", target, err) + } + return nil } + +func (m *SystemdMounter) credentialWriteAndEnvPath() (writePath string, envPath string) { + // This is the plugin directory for CSI driver mounted in the container. + writePath = "/csi" + // This is the plugin directory for CSI driver in the host. + envPath = hostPluginDirWithDefault() + return writePath, envPath +} + +func hostPluginDirWithDefault() string { + hostPluginDir := os.Getenv("HOST_PLUGIN_DIR") + if hostPluginDir == "" { + hostPluginDir = "/var/lib/kubelet/plugins/s3.csi.aws.com/" + } + return hostPluginDir +} diff --git a/pkg/driver/node/mounter/systemd_mounter_test.go b/pkg/driver/node/mounter/systemd_mounter_test.go index 9e0c3194..7a9818bb 100644 --- a/pkg/driver/node/mounter/systemd_mounter_test.go +++ b/pkg/driver/node/mounter/systemd_mounter_test.go @@ -5,16 +5,15 @@ import ( "errors" "os" "path/filepath" - "reflect" - "slices" "strings" "testing" - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/awsprofile" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" mock_driver "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter/mocks" "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" "github.com/awslabs/aws-s3-csi-driver/pkg/system" + "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" "github.com/golang/mock/gomock" "k8s.io/mount-utils" ) @@ -50,51 +49,47 @@ func initMounterTestEnv(t *testing.T) *mounterTestEnv { func TestS3MounterMount(t *testing.T) { testBucketName := "test-bucket" testTargetPath := filepath.Join(t.TempDir(), "mount") - testCredentials := &mounter.MountCredentials{ - AccessKeyID: "test-access-key", - SecretAccessKey: "test-secret-key", - Region: "test-region", - DefaultRegion: "test-region", - WebTokenPath: "test-web-token-path", - StsEndpoints: "test-sts-endpoint", - AwsRoleArn: "test-aws-role", + testProvideCtx := credentialprovider.ProvideContext{ + PodID: "test-pod", + VolumeID: "test-volume", + WritePath: t.TempDir(), } testCases := []struct { name string bucketName string targetPath string - credentials *mounter.MountCredentials + provideCtx credentialprovider.ProvideContext options []string expectedErr bool before func(*testing.T, *mounterTestEnv) }{ { - name: "success: mounts with empty options", - bucketName: testBucketName, - targetPath: testTargetPath, - credentials: testCredentials, - options: []string{}, + name: "success: mounts with empty options", + bucketName: testBucketName, + targetPath: testTargetPath, + provideCtx: testProvideCtx, + options: []string{}, before: func(t *testing.T, env *mounterTestEnv) { env.mockRunner.EXPECT().StartService(gomock.Any(), gomock.Any()).Return("success", nil) }, }, { - name: "success: mounts with nil credentials", - bucketName: testBucketName, - targetPath: testTargetPath, - credentials: nil, - options: []string{}, + name: "success: mounts with nil credentials", + bucketName: testBucketName, + targetPath: testTargetPath, + provideCtx: credentialprovider.ProvideContext{}, + options: []string{}, before: func(t *testing.T, env *mounterTestEnv) { env.mockRunner.EXPECT().StartService(gomock.Any(), gomock.Any()).Return("success", nil) }, }, { - name: "success: replaces user agent prefix", - bucketName: testBucketName, - targetPath: testTargetPath, - credentials: nil, - options: []string{"--user-agent-prefix=mycustomuseragent"}, + name: "success: replaces user agent prefix", + bucketName: testBucketName, + targetPath: testTargetPath, + provideCtx: credentialprovider.ProvideContext{}, + options: []string{"--user-agent-prefix=mycustomuseragent"}, before: func(t *testing.T, env *mounterTestEnv) { env.mockRunner.EXPECT().StartService(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, config *system.ExecConfig) (string, error) { for _, a := range config.Args { @@ -107,11 +102,11 @@ func TestS3MounterMount(t *testing.T) { }, }, { - name: "success: aws max attempts", - bucketName: testBucketName, - targetPath: testTargetPath, - credentials: nil, - options: []string{"--aws-max-attempts=10"}, + name: "success: aws max attempts", + bucketName: testBucketName, + targetPath: testTargetPath, + provideCtx: credentialprovider.ProvideContext{}, + options: []string{"--aws-max-attempts=10"}, before: func(t *testing.T, env *mounterTestEnv) { env.mockRunner.EXPECT().StartService(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, config *system.ExecConfig) (string, error) { for _, e := range config.Env { @@ -128,7 +123,7 @@ func TestS3MounterMount(t *testing.T) { name: "failure: fails on mount failure", bucketName: testBucketName, targetPath: testTargetPath, - credentials: nil, + provideCtx: credentialprovider.ProvideContext{}, options: []string{}, expectedErr: true, before: func(t *testing.T, env *mounterTestEnv) { @@ -138,14 +133,14 @@ func TestS3MounterMount(t *testing.T) { { name: "failure: won't mount empty bucket name", targetPath: testTargetPath, - credentials: testCredentials, + provideCtx: testProvideCtx, options: []string{}, expectedErr: true, }, { name: "failure: won't mount empty target", bucketName: testBucketName, - credentials: testCredentials, + provideCtx: testProvideCtx, options: []string{}, expectedErr: true, }, @@ -156,8 +151,8 @@ func TestS3MounterMount(t *testing.T) { if testCase.before != nil { testCase.before(t, env) } - err := env.mounter.Mount(testCase.bucketName, testCase.targetPath, - testCase.credentials, mountpoint.ParseArgs(testCase.options)) + err := env.mounter.Mount(env.ctx, testCase.bucketName, testCase.targetPath, + testCase.provideCtx, mountpoint.ParseArgs(testCase.options)) env.mockCtl.Finish() if err != nil && !testCase.expectedErr { t.Fatal(err) @@ -166,119 +161,6 @@ func TestS3MounterMount(t *testing.T) { } } -func TestProvidingEnvVariablesForMountpointProcess(t *testing.T) { - tests := map[string]struct { - profile awsprofile.AWSProfile - credentials *mounter.MountCredentials - expected []string - }{ - "Profile Provider for long-term credentials": { - profile: awsprofile.AWSProfile{ - Name: "profile", - ConfigPath: "~/.aws/s3-csi-config", - CredentialsPath: "~/.aws/s3-csi-credentials", - }, - credentials: &mounter.MountCredentials{}, - expected: []string{ - "AWS_PROFILE=profile", - "AWS_CONFIG_FILE=~/.aws/s3-csi-config", - "AWS_SHARED_CREDENTIALS_FILE=~/.aws/s3-csi-credentials", - }, - }, - "Profile Provider": { - credentials: &mounter.MountCredentials{ - ConfigFilePath: "~/.aws/config", - SharedCredentialsFilePath: "~/.aws/credentials", - }, - expected: []string{ - "AWS_CONFIG_FILE=~/.aws/config", - "AWS_SHARED_CREDENTIALS_FILE=~/.aws/credentials", - }, - }, - "Disabling IMDS Provider": { - credentials: &mounter.MountCredentials{ - DisableIMDSProvider: true, - }, - expected: []string{ - "AWS_EC2_METADATA_DISABLED=true", - }, - }, - "STS Web Identity Provider": { - credentials: &mounter.MountCredentials{ - WebTokenPath: "/path/to/web/token", - AwsRoleArn: "arn:aws:iam::123456789012:role/Role", - }, - expected: []string{ - "AWS_WEB_IDENTITY_TOKEN_FILE=/path/to/web/token", - "AWS_ROLE_ARN=arn:aws:iam::123456789012:role/Role", - }, - }, - "Region and Default Region": { - credentials: &mounter.MountCredentials{ - Region: "us-west-2", - DefaultRegion: "us-east-1", - }, - expected: []string{ - "AWS_REGION=us-west-2", - "AWS_DEFAULT_REGION=us-east-1", - }, - }, - "STS Endpoints": { - credentials: &mounter.MountCredentials{ - StsEndpoints: "regional", - }, - expected: []string{ - "AWS_STS_REGIONAL_ENDPOINTS=regional", - }, - }, - "Mountpoint Cache Key": { - credentials: &mounter.MountCredentials{ - MountpointCacheKey: "test_cache_key", - }, - expected: []string{ - "UNSTABLE_MOUNTPOINT_CACHE_KEY=test_cache_key", - }, - }, - "All Combined": { - credentials: &mounter.MountCredentials{ - WebTokenPath: "/path/to/web/token", - AwsRoleArn: "arn:aws:iam::123456789012:role/Role", - Region: "us-west-2", - DefaultRegion: "us-east-1", - StsEndpoints: "legacy", - ConfigFilePath: "~/.aws/config", - SharedCredentialsFilePath: "~/.aws/credentials", - DisableIMDSProvider: true, - MountpointCacheKey: "test/cache/key", - }, - expected: []string{ - "AWS_WEB_IDENTITY_TOKEN_FILE=/path/to/web/token", - "AWS_ROLE_ARN=arn:aws:iam::123456789012:role/Role", - "AWS_REGION=us-west-2", - "AWS_DEFAULT_REGION=us-east-1", - "AWS_STS_REGIONAL_ENDPOINTS=legacy", - "AWS_EC2_METADATA_DISABLED=true", - "AWS_CONFIG_FILE=~/.aws/config", - "AWS_SHARED_CREDENTIALS_FILE=~/.aws/credentials", - "UNSTABLE_MOUNTPOINT_CACHE_KEY=test/cache/key", - }, - }, - } - - for name, test := range tests { - t.Run(name, func(t *testing.T) { - actual := test.credentials.Env(test.profile).List() - - slices.Sort(test.expected) - slices.Sort(actual) - - if !reflect.DeepEqual(actual, test.expected) { - t.Errorf("Expected %v, but got %v", test.expected, actual) - } - }) - } -} - func TestIsMountPoint(t *testing.T) { testDir := t.TempDir() mountpointS3MountPath := filepath.Join(testDir, "/var/lib/kubelet/pods/46efe8aa-75d9-4b12-8fdd-0ce0c2cabd99/volumes/kubernetes.io~csi/s3-mp-csi-pv/mount") @@ -357,8 +239,8 @@ func TestIsMountPoint(t *testing.T) { t.Run(name, func(t *testing.T) { mounter := &mounter.SystemdMounter{Mounter: mount.NewFakeMounter(test.procMountsContent)} isMountPoint, err := mounter.IsMountPoint(test.target) - assertEquals(t, test.isMountPoint, isMountPoint) - assertEquals(t, test.expectErr, err != nil) + assert.Equals(t, test.isMountPoint, isMountPoint) + assert.Equals(t, test.expectErr, err != nil) }) } } diff --git a/pkg/driver/node/mounter/user_agent_test.go b/pkg/driver/node/mounter/user_agent_test.go index 5308c92f..cd308ae2 100644 --- a/pkg/driver/node/mounter/user_agent_test.go +++ b/pkg/driver/node/mounter/user_agent_test.go @@ -18,6 +18,8 @@ package mounter import ( "testing" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" ) func TestUserAgent(t *testing.T) { @@ -39,12 +41,12 @@ func TestUserAgent(t *testing.T) { }, "driver authentication source": { k8sVersion: "v1.30.2-eks-db838b0", - authenticationSource: AuthenticationSourceDriver, + authenticationSource: credentialprovider.AuthenticationSourceDriver, result: "s3-csi-driver/ credential-source#driver k8s/v1.30.2-eks-db838b0", }, "pod authentication source": { k8sVersion: "v1.30.2-eks-db838b0", - authenticationSource: AuthenticationSourcePod, + authenticationSource: credentialprovider.AuthenticationSourcePod, result: "s3-csi-driver/ credential-source#pod k8s/v1.30.2-eks-db838b0", }, } diff --git a/pkg/driver/node/node.go b/pkg/driver/node/node.go index 9ee0a01d..2582e6a4 100644 --- a/pkg/driver/node/node.go +++ b/pkg/driver/node/node.go @@ -25,10 +25,10 @@ import ( "github.com/container-storage-interface/spec/lib/go/csi" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "k8s.io/apimachinery/pkg/util/sets" "k8s.io/klog/v2" "k8s.io/mount-utils" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/targetpath" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/volumecontext" @@ -58,28 +58,15 @@ var ( // S3NodeServer is the implementation of the csi.NodeServer interface type S3NodeServer struct { - NodeID string - Mounter mounter.Mounter - credentialProvider *mounter.CredentialProvider + NodeID string + Mounter mounter.Mounter } -func NewS3NodeServer(nodeID string, mounter mounter.Mounter, credentialProvider *mounter.CredentialProvider) *S3NodeServer { - return &S3NodeServer{NodeID: nodeID, Mounter: mounter, credentialProvider: credentialProvider} +func NewS3NodeServer(nodeID string, mounter mounter.Mounter) *S3NodeServer { + return &S3NodeServer{NodeID: nodeID, Mounter: mounter} } func (ns *S3NodeServer) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) { - volumeCtx := req.GetVolumeContext() - if volumeCtx[volumecontext.AuthenticationSource] == mounter.AuthenticationSourcePod { - podID := volumeCtx[volumecontext.CSIPodUID] - volumeID := req.GetVolumeId() - if podID != "" && volumeID != "" { - err := ns.credentialProvider.CleanupToken(volumeID, podID) - if err != nil { - klog.V(4).Infof("NodeStageVolume: Failed to cleanup token for pod/volume %s/%s: %v", podID, volumeID, err) - } - } - } - return nil, status.Error(codes.Unimplemented, "") } @@ -131,16 +118,11 @@ func (ns *S3NodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePubl } args := mountpoint.ParseArgs(mountpointArgs) - - credentials, err := ns.credentialProvider.Provide(ctx, req.VolumeId, req.VolumeContext, args) - if err != nil { - klog.Errorf("NodePublishVolume: failed to provide credentials: %v", err) - return nil, err - } - klog.V(4).Infof("NodePublishVolume: mounting %s at %s with options %v", bucket, target, args.SortedList()) - if err := ns.Mounter.Mount(bucket, target, credentials, args); err != nil { + credentialCtx := credentialProvideContextFromPublishRequest(req, args) + + if err := ns.Mounter.Mount(ctx, bucket, target, credentialCtx, args); err != nil { os.Remove(target) return nil, status.Errorf(codes.Internal, "Could not mount %q at %q: %v", bucket, target, err) } @@ -149,30 +131,6 @@ func (ns *S3NodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePubl return &csi.NodePublishVolumeResponse{}, nil } -/** - * Compile mounting options into a singular set - */ -func compileMountOptions(currentOptions []string, newOptions []string) []string { - allMountOptions := sets.NewString() - - for _, currentMountOptions := range currentOptions { - if len(currentMountOptions) > 0 { - allMountOptions.Insert(currentMountOptions) - } - } - - for _, mountOption := range newOptions { - // disallow options that don't make sense in CSI - switch mountOption { - case "--foreground", "-f", "--help", "-h", "--version", "-v": - continue - } - allMountOptions.Insert(mountOption) - } - - return allMountOptions.List() -} - func getKubeletPath() string { kubeletPath := os.Getenv("KUBELET_PATH") if kubeletPath == "" { @@ -209,26 +167,14 @@ func (ns *S3NodeServer) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUn return &csi.NodeUnpublishVolumeResponse{}, nil } + credentialCtx := credentialCleanupContextFromUnpublishRequest(req) + klog.V(4).Infof("NodeUnpublishVolume: unmounting %s", target) - err = ns.Mounter.Unmount(target) + err = ns.Mounter.Unmount(target, credentialCtx) if err != nil { return nil, status.Errorf(codes.Internal, "Could not unmount %q: %v", target, err) } - targetPath, err := targetpath.Parse(target) - if err == nil { - if targetPath.VolumeID != volumeID { - klog.V(4).Infof("NodeUnpublishVolume: Volume ID from parsed target path differs from Volume ID passed: %s (parsed) != %s (passed)", targetPath.VolumeID, volumeID) - } else { - err := ns.credentialProvider.CleanupToken(targetPath.VolumeID, targetPath.PodID) - if err != nil { - klog.V(4).Infof("NodeUnpublishVolume: Failed to cleanup token for pod/volume %s/%s: %v", targetPath.PodID, volumeID, err) - } - } - } else { - klog.V(4).Infof("NodeUnpublishVolume: Failed to parse target path %s: %v", target, err) - } - return &csi.NodeUnpublishVolumeResponse{}, nil } @@ -283,6 +229,45 @@ func (ns *S3NodeServer) isValidVolumeCapabilities(volCaps []*csi.VolumeCapabilit return foundAll } +func credentialProvideContextFromPublishRequest(req *csi.NodePublishVolumeRequest, args mountpoint.Args) credentialprovider.ProvideContext { + volumeCtx := req.GetVolumeContext() + + podID := volumeCtx[volumecontext.CSIPodUID] + if podID == "" { + podID, _ = podIDFromTargetPath(req.GetTargetPath()) + } + + bucketRegion, _ := args.Value(mountpoint.ArgRegion) + + return credentialprovider.ProvideContext{ + PodID: podID, + VolumeID: req.GetVolumeId(), + AuthenticationSource: volumeCtx[volumecontext.AuthenticationSource], + PodNamespace: volumeCtx[volumecontext.CSIPodNamespace], + ServiceAccountTokens: volumeCtx[volumecontext.CSIServiceAccountTokens], + ServiceAccountName: volumeCtx[volumecontext.CSIServiceAccountName], + StsRegion: volumeCtx[volumecontext.STSRegion], + BucketRegion: bucketRegion, + } +} + +func credentialCleanupContextFromUnpublishRequest(req *csi.NodeUnpublishVolumeRequest) credentialprovider.CleanupContext { + podID, _ := podIDFromTargetPath(req.GetTargetPath()) + return credentialprovider.CleanupContext{ + VolumeID: req.GetVolumeId(), + PodID: podID, + } +} + +func podIDFromTargetPath(target string) (string, bool) { + targetPath, err := targetpath.Parse(target) + if err != nil { + klog.V(4).Infof("Failed to parse target path %s: %v", target, err) + return "", false + } + return targetPath.PodID, true +} + // logSafeNodePublishVolumeRequest returns a copy of given `csi.NodePublishVolumeRequest` // with sensitive fields removed. func logSafeNodePublishVolumeRequest(req *csi.NodePublishVolumeRequest) *csi.NodePublishVolumeRequest { diff --git a/pkg/driver/node/node_test.go b/pkg/driver/node/node_test.go index 040f1635..58cf909c 100644 --- a/pkg/driver/node/node_test.go +++ b/pkg/driver/node/node_test.go @@ -2,22 +2,18 @@ package node_test import ( "errors" - "fmt" "io/fs" - "os" - "path/filepath" "testing" + csi "github.com/container-storage-interface/spec/lib/go/csi" + "github.com/golang/mock/gomock" + "golang.org/x/net/context" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" mock_driver "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter/mocks" "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" - "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" - csi "github.com/container-storage-interface/spec/lib/go/csi" - "github.com/golang/mock/gomock" - "github.com/google/go-cmp/cmp/cmpopts" - "github.com/google/uuid" - "golang.org/x/net/context" ) type nodeServerTestEnv struct { @@ -28,14 +24,8 @@ type nodeServerTestEnv struct { func initNodeServerTestEnv(t *testing.T) *nodeServerTestEnv { mockCtl := gomock.NewController(t) - defer mockCtl.Finish() mockMounter := mock_driver.NewMockMounter(mockCtl) - credentialProvider := mounter.NewCredentialProvider(nil, t.TempDir(), mounter.RegionFromIMDSOnce) - server := node.NewS3NodeServer( - "test-nodeID", - mockMounter, - credentialProvider, - ) + server := node.NewS3NodeServer("test-nodeID", mockMounter) return &nodeServerTestEnv{ mockCtl: mockCtl, mockMounter: mockMounter, @@ -73,7 +63,14 @@ func TestNodePublishVolume(t *testing.T) { VolumeContext: map[string]string{"bucketName": bucketName}, } - nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Any()) + nodeTestEnv.mockMounter.EXPECT().Mount( + gomock.Eq(context.Background()), + gomock.Eq(bucketName), + gomock.Eq(targetPath), + gomock.Eq(credentialprovider.ProvideContext{ + VolumeID: volumeId, + }), + gomock.Any()) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -101,7 +98,14 @@ func TestNodePublishVolume(t *testing.T) { VolumeContext: map[string]string{"bucketName": bucketName}, } - nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Eq(mountpoint.ParseArgs([]string{"--read-only"}))) + nodeTestEnv.mockMounter.EXPECT().Mount( + gomock.Eq(context.Background()), + gomock.Eq(bucketName), + gomock.Eq(targetPath), + gomock.Eq(credentialprovider.ProvideContext{ + VolumeID: volumeId, + }), + gomock.Eq(mountpoint.ParseArgs([]string{"--read-only"}))) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -132,7 +136,14 @@ func TestNodePublishVolume(t *testing.T) { Readonly: true, } - nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Eq(mountpoint.ParseArgs([]string{"--bar", "--foo", "--read-only", "--test=123"}))) + nodeTestEnv.mockMounter.EXPECT().Mount( + gomock.Eq(context.Background()), + gomock.Eq(bucketName), + gomock.Eq(targetPath), + gomock.Eq(credentialprovider.ProvideContext{ + VolumeID: volumeId, + }), + gomock.Eq(mountpoint.ParseArgs([]string{"--bar", "--foo", "--read-only", "--test=123"}))) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -164,7 +175,12 @@ func TestNodePublishVolume(t *testing.T) { } nodeTestEnv.mockMounter.EXPECT().Mount( - gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), + gomock.Eq(context.Background()), + gomock.Eq(bucketName), + gomock.Eq(targetPath), + gomock.Eq(credentialprovider.ProvideContext{ + VolumeID: volumeId, + }), gomock.Eq(mountpoint.ParseArgs([]string{"--read-only", "--test=123"}))).Return(nil) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { @@ -219,7 +235,7 @@ func TestNodeUnpublishVolume(t *testing.T) { } nodeTestEnv.mockMounter.EXPECT().IsMountPoint(gomock.Eq(targetPath)).Return(true, nil) - nodeTestEnv.mockMounter.EXPECT().Unmount(gomock.Eq(targetPath)).Return(nil) + nodeTestEnv.mockMounter.EXPECT().Unmount(gomock.Eq(targetPath), gomock.Any()) _, err := nodeTestEnv.server.NodeUnpublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume failed: %v", err) @@ -258,7 +274,12 @@ func TestNodeUnpublishVolume(t *testing.T) { } nodeTestEnv.mockMounter.EXPECT().IsMountPoint(gomock.Eq(targetPath)).Return(true, nil) - nodeTestEnv.mockMounter.EXPECT().Unmount(gomock.Eq(targetPath)).Return(errors.New("")) + nodeTestEnv.mockMounter.EXPECT().Unmount( + gomock.Eq(targetPath), + gomock.Eq(credentialprovider.CleanupContext{ + VolumeID: volumeId, + }), + ).Return(errors.New("")) _, err := nodeTestEnv.server.NodeUnpublishVolume(ctx, req) if err == nil { t.Fatalf("NodePublishVolume must fail") @@ -292,31 +313,6 @@ func TestNodeUnpublishVolume(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, tc.testFunc) } - - t.Run("Cleaning Service Account Token", func(t *testing.T) { - containerPluginDir := t.TempDir() - credentialProvider := mounter.NewCredentialProvider(nil, containerPluginDir, mounter.RegionFromIMDSOnce) - nodeServer := node.NewS3NodeServer("test-node-id", &dummyMounter{}, credentialProvider) - - podID := uuid.New().String() - volID := "test-vol-id" - - serviceAccountTokenPath := filepath.Join(containerPluginDir, fmt.Sprintf("%s-%s.token", podID, volID)) - _, err := os.Create(serviceAccountTokenPath) - assert.Equals(t, nil, err) - - targetPath := fmt.Sprintf("/var/lib/kubelet/pods/%s/volumes/kubernetes.io~csi/%s/mount", podID, volID) - - _, err = nodeServer.NodeUnpublishVolume(context.Background(), &csi.NodeUnpublishVolumeRequest{ - VolumeId: volID, - TargetPath: targetPath, - }) - assert.Equals(t, nil, err) - - _, err = os.Stat(serviceAccountTokenPath) - assert.Equals(t, cmpopts.AnyError, err) - assert.Equals(t, true, errors.Is(err, fs.ErrNotExist)) - }) } func TestNodeGetCapabilities(t *testing.T) { @@ -339,15 +335,16 @@ func TestNodeGetCapabilities(t *testing.T) { var _ mounter.Mounter = &dummyMounter{} -type dummyMounter struct { -} +type dummyMounter struct{} -func (d *dummyMounter) Mount(bucketName string, target string, credentials *mounter.MountCredentials, args mountpoint.Args) error { +func (d *dummyMounter) Mount(ctx context.Context, bucketName string, target string, provideCtx credentialprovider.ProvideContext, args mountpoint.Args) error { return nil } -func (d *dummyMounter) Unmount(target string) error { + +func (d *dummyMounter) Unmount(target string, ctx credentialprovider.CleanupContext) error { return nil } + func (d *dummyMounter) IsMountPoint(target string) (bool, error) { return true, nil } diff --git a/tests/sanity/sanity_test.go b/tests/sanity/sanity_test.go index 2ec9e8eb..39539cbc 100644 --- a/tests/sanity/sanity_test.go +++ b/tests/sanity/sanity_test.go @@ -73,7 +73,6 @@ var _ = BeforeSuite(func() { NodeServer: node.NewS3NodeServer( "fake_id", &mounter.FakeMounter{}, - mounter.NewCredentialProvider(nil, GinkgoT().TempDir(), mounter.RegionFromIMDSOnce), ), } go func() {