@@ -33,6 +33,7 @@ import (
3333 "github.com/aws/aws-sdk-go-v2/service/ec2"
3434 "github.com/aws/aws-sdk-go-v2/service/ec2/types"
3535 "github.com/aws/aws-sdk-go-v2/service/sagemaker"
36+ "github.com/aws/aws-sdk-go-v2/service/sts"
3637 "github.com/aws/smithy-go"
3738 "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/batcher"
3839 dm "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud/devicemanager"
8889)
8990
9091const (
91- cacheForgetDelay = 1 * time .Hour
92- volInitCacheForgetDelay = 6 * time .Hour
92+ cacheForgetDelay = 1 * time .Hour
93+ volInitCacheForgetDelay = 6 * time .Hour
94+ getCallerIdentityRetryDelay = 30 * time .Second
9395)
9496
9597// VolumeStatusInitializingState is const reported by EC2 DescribeVolumeStatus which AWS SDK does not have type for.
@@ -320,6 +322,7 @@ type batcherManager struct {
320322}
321323
322324type cloud struct {
325+ awsConfig aws.Config
323326 region string
324327 ec2 EC2API
325328 sm SageMakerAPI
@@ -331,18 +334,14 @@ type cloud struct {
331334 latestClientTokens expiringcache.ExpiringCache [string , int ]
332335 volumeInitializations expiringcache.ExpiringCache [string , volumeInitialization ]
333336 accountID string
337+ accountIDOnce sync.Once
334338}
335339
336340var _ Cloud = & cloud {}
337341
338342// NewCloud returns a new instance of AWS cloud
339343// It panics if session is invalid.
340- func NewCloud (region string , accountID string , awsSdkDebugLog bool , userAgentExtra string , batching bool , deprecatedMetrics bool ) (Cloud , error ) {
341- c := newEC2Cloud (region , accountID , awsSdkDebugLog , userAgentExtra , batching , deprecatedMetrics )
342- return c , nil
343- }
344-
345- func newEC2Cloud (region string , accountID string , awsSdkDebugLog bool , userAgentExtra string , batchingEnabled bool , deprecatedMetrics bool ) Cloud {
344+ func NewCloud (region string , awsSdkDebugLog bool , userAgentExtra string , batchingEnabled bool , deprecatedMetrics bool ) Cloud {
346345 cfg , err := config .LoadDefaultConfig (context .Background (), config .WithRegion (region ))
347346 if err != nil {
348347 panic (err )
@@ -386,11 +385,12 @@ func newEC2Cloud(region string, accountID string, awsSdkDebugLog bool, userAgent
386385
387386 var bm * batcherManager
388387 if batchingEnabled {
389- klog .V (4 ).InfoS ("newEC2Cloud : batching enabled" )
388+ klog .V (4 ).InfoS ("NewCloud : batching enabled" )
390389 bm = newBatcherManager (svc )
391390 }
392391
393392 return & cloud {
393+ awsConfig : cfg ,
394394 region : region ,
395395 dm : dm .NewDeviceManager (),
396396 ec2 : svc ,
@@ -400,7 +400,6 @@ func newEC2Cloud(region string, accountID string, awsSdkDebugLog bool, userAgent
400400 vwp : vwp ,
401401 likelyBadDeviceNames : expiringcache.New [string , sync.Map ](cacheForgetDelay ),
402402 latestClientTokens : expiringcache.New [string , int ](cacheForgetDelay ),
403- accountID : accountID ,
404403 volumeInitializations : expiringcache.New [string , volumeInitialization ](volInitCacheForgetDelay ),
405404 }
406405}
@@ -997,7 +996,11 @@ func (c *cloud) attachDiskHyperPod(ctx context.Context, volumeID, nodeID string)
997996 klog .V (2 ).InfoS ("AttachDisk: HyperPod node detected" , "volumeID" , volumeID , "nodeID" , nodeID )
998997
999998 instanceID := getInstanceIDFromHyperPodNode (nodeID )
1000- clusterArn := c .buildHyperPodClusterArn (nodeID )
999+ accountID , err := c .getAccountID (ctx )
1000+ if err != nil {
1001+ return "" , fmt .Errorf ("failed to get account ID: %w" , err )
1002+ }
1003+ clusterArn := buildHyperPodClusterArn (nodeID , c .region , accountID )
10011004
10021005 klog .V (5 ).InfoS ("HyperPod attachment details" ,
10031006 "volumeID" , volumeID ,
@@ -1025,7 +1028,7 @@ func (c *cloud) attachDiskHyperPod(ctx context.Context, volumeID, nodeID string)
10251028
10261029 // Wait for attachment completion
10271030 deviceName := aws .ToString (resp .DeviceName )
1028- _ , err : = c .WaitForAttachmentState (
1031+ _ , err = c .WaitForAttachmentState (
10291032 ctx ,
10301033 types .VolumeAttachmentStateAttached ,
10311034 volumeID ,
@@ -1099,7 +1102,11 @@ func (c *cloud) detachDiskHyperPod(ctx context.Context, volumeID, nodeID string)
10991102 klog .V (2 ).InfoS ("DetachDisk: HyperPod node detected" , "volumeID" , volumeID , "nodeID" , nodeID )
11001103
11011104 instanceID := getInstanceIDFromHyperPodNode (nodeID )
1102- clusterArn := c .buildHyperPodClusterArn (nodeID )
1105+ accountID , err := c .getAccountID (ctx )
1106+ if err != nil {
1107+ return fmt .Errorf ("failed to get account ID: %w" , err )
1108+ }
1109+ clusterArn := buildHyperPodClusterArn (nodeID , c .region , accountID )
11031110
11041111 klog .V (4 ).InfoS ("HyperPod detachment details" ,
11051112 "volumeID" , volumeID ,
@@ -1114,7 +1121,7 @@ func (c *cloud) detachDiskHyperPod(ctx context.Context, volumeID, nodeID string)
11141121 }
11151122 klog .V (4 ).InfoS ("Calling DetachClusterNodeVolumeInput" , "input" , input )
11161123
1117- _ , err : = c .sm .DetachClusterNodeVolume (ctx , input )
1124+ _ , err = c .sm .DetachClusterNodeVolume (ctx , input )
11181125 if err != nil {
11191126 if isAWSHyperPodErrorIncorrectState (err ) ||
11201127 isAWSHyperPodErrorInvalidAttachmentNotFound (err ) ||
@@ -1450,9 +1457,9 @@ func getInstanceIDFromHyperPodNode(nodeID string) string {
14501457}
14511458
14521459// Only for hyperpod node, buildHyperPodClusterArn: arn:aws:sagemaker:region:account:cluster/clusterID.
1453- func ( c * cloud ) buildHyperPodClusterArn (nodeID string ) string {
1460+ func buildHyperPodClusterArn (nodeID string , region string , accountID string ) string {
14541461 parts := strings .Split (nodeID , "-" )
1455- return fmt .Sprintf ("arn:aws:sagemaker:%s:%s:cluster/%s" , c . region , c . accountID , parts [1 ])
1462+ return fmt .Sprintf ("arn:aws:sagemaker:%s:%s:cluster/%s" , region , accountID , parts [1 ])
14561463}
14571464
14581465// For hyperpod node, AssociatedResource is in arn:aws:sagemaker:region:account:cluster/clusterID-instanceId format.
@@ -1916,6 +1923,54 @@ func (c *cloud) waitForVolume(ctx context.Context, volumeID string) error {
19161923 return err
19171924}
19181925
1926+ // getAccountID returns the account ID of the AWS Account for the IAM credentials in use.
1927+ //
1928+ // In the first call (or any calls made before the first call succeeds), getAccountID
1929+ // will attempt to determine the Account ID via sts:GetCallerIdentity.
1930+ // This attempt will retry indefinitely, however getAccountID will return when ctx is cancelled,
1931+ // leaving the account ID thread to run in the background.
1932+ //
1933+ // In subsequent calls (after the first success), getAccountID will use a cached value.
1934+ func (c * cloud ) getAccountID (ctx context.Context ) (string , error ) {
1935+ accountIDRetrieved := make (chan struct {}, 1 )
1936+
1937+ // Start background thread if it isn't already.
1938+ // Intentionally runs in the background until account ID is retrieved, so we don't pass the context.
1939+ //nolint:contextcheck
1940+ go func () {
1941+ c .accountIDOnce .Do (func () {
1942+ for c .accountID == "" {
1943+ cfg , err := config .LoadDefaultConfig (context .Background (), config .WithRegion (c .region ))
1944+ if err != nil {
1945+ klog .ErrorS (err , "Failed to create AWS config for account ID retrieval" )
1946+ }
1947+
1948+ stsClient := sts .NewFromConfig (cfg )
1949+ resp , err := stsClient .GetCallerIdentity (context .Background (), & sts.GetCallerIdentityInput {})
1950+ if err != nil {
1951+ klog .ErrorS (err , "Failed to get AWS account ID, required for HyperPod operations, will retry" )
1952+ time .Sleep (getCallerIdentityRetryDelay )
1953+ } else {
1954+ c .accountID = * resp .Account
1955+ klog .V (5 ).InfoS ("Retrieved AWS account ID for HyperPod operations" , "accountID" , c .accountID )
1956+ }
1957+ }
1958+ })
1959+
1960+ // Once.Do blocks until the function exits, even if we aren't the first caller.
1961+ // So the account ID must be available now.
1962+ accountIDRetrieved <- struct {}{}
1963+ }()
1964+
1965+ select {
1966+ case <- ctx .Done ():
1967+ return "" , ctx .Err ()
1968+
1969+ case <- accountIDRetrieved :
1970+ return c .accountID , nil
1971+ }
1972+ }
1973+
19191974// isAWSError returns a boolean indicating whether the error is AWS-related
19201975// and has the given code. More information on AWS error codes at:
19211976// https://docs.aws.amazon.com/AWSEC2/latest/APIReference/errors-overview.html
0 commit comments