diff --git a/packages/credential-provider-node/src/credential-provider-node.integ.spec.ts b/packages/credential-provider-node/src/credential-provider-node.integ.spec.ts index 4d65fcb757cb..6a7886b175c4 100644 --- a/packages/credential-provider-node/src/credential-provider-node.integ.spec.ts +++ b/packages/credential-provider-node/src/credential-provider-node.integ.spec.ts @@ -8,6 +8,7 @@ import { AdaptiveRetryStrategy, StandardRetryStrategy } from "@smithy/util-retry import { PassThrough } from "stream"; import { defaultProvider } from "./defaultProvider"; +import { clearDefaultProviderCache } from "./memoizeGlobal"; jest.mock("fs", () => { const actual = jest.requireActual("fs"); @@ -1273,4 +1274,91 @@ describe("credential-provider-node integration test", () => { expect(async () => sts.getCallerIdentity({})).rejects.toThrow("Could not load credentials from any providers"); }); }); + + describe("Global Cache Behavior", () => { + beforeEach(() => { + clearDefaultProviderCache(); + jest.clearAllMocks(); + for (const variable in RESERVED_ENVIRONMENT_VARIABLES) { + delete process.env[variable]; + } + }); + + afterEach(() => { + clearDefaultProviderCache(); + }); + + it("should cache credentials across provider instances", async () => { + // Set up environment credentials to avoid profile warning + process.env.AWS_ACCESS_KEY_ID = "AKID"; + process.env.AWS_SECRET_ACCESS_KEY = "SECRET"; + + const provider1 = defaultProvider(); + const provider2 = defaultProvider(); + + const creds1 = await provider1(); + const creds2 = await provider2(); + + expect(creds1).toEqual(creds2); + expect(creds1).toEqual({ + accessKeyId: "AKID", + secretAccessKey: "SECRET", + $source: { + CREDENTIALS_ENV_VARS: "g", + }, + }); + }); + + it("should maintain separate caches for different profiles", async () => { + // Clear env variables to allow profile credentials + delete process.env.AWS_ACCESS_KEY_ID; + delete process.env.AWS_SECRET_ACCESS_KEY; + + Object.assign(iniProfileData, { + profile1: { + aws_access_key_id: "AKID1", + aws_secret_access_key: "SECRET1", + }, + profile2: { + aws_access_key_id: "AKID2", + aws_secret_access_key: "SECRET2", + }, + }); + + const provider1 = defaultProvider({ profile: "profile1" }); + const provider2 = defaultProvider({ profile: "profile2" }); + + const [creds1, creds2] = await Promise.all([provider1(), provider2()]); + + expect(creds1.accessKeyId).toBe("AKID1"); + expect(creds2.accessKeyId).toBe("AKID2"); + expect(creds1).not.toEqual(creds2); + }); + + it("should handle expired credentials", async () => { + process.env.AWS_ACCESS_KEY_ID = "AKID"; + process.env.AWS_SECRET_ACCESS_KEY = "SECRET"; + + const provider = defaultProvider(); + const creds = await provider(); + + // Simulate expiration + Object.defineProperty(creds, "expiration", { + value: new Date(Date.now() - 300001), // Just over 5 minutes ago + }); + + // Should force a refresh on next call + const newCreds = await provider(); + expect(newCreds).toBeDefined(); + expect(newCreds.accessKeyId).toBe("AKID"); + }); + + it("should handle provider errors", async () => { + delete process.env.AWS_ACCESS_KEY_ID; + delete process.env.AWS_SECRET_ACCESS_KEY; + + const provider = defaultProvider(); + await expect(provider()).rejects.toThrow("Could not load credentials from any providers"); + }); + }); }); diff --git a/packages/credential-provider-node/src/defaultProvider.ts b/packages/credential-provider-node/src/defaultProvider.ts index b16112434ec6..4c02a17afa99 100644 --- a/packages/credential-provider-node/src/defaultProvider.ts +++ b/packages/credential-provider-node/src/defaultProvider.ts @@ -9,6 +9,7 @@ import { chain, CredentialsProviderError, memoize } from "@smithy/property-provi import { ENV_PROFILE } from "@smithy/shared-ini-file-loader"; import { AwsCredentialIdentity, MemoizedProvider } from "@smithy/types"; +import { memoizeGlobal } from "./memoizeGlobal"; import { remoteProvider } from "./remoteProvider"; /** @@ -60,19 +61,18 @@ let multipleCredentialSourceWarningEmitted = false; * @see {@link fromContainerMetadata} The function used to source credentials from the * ECS Container Metadata Service. */ -export const defaultProvider = (init: DefaultProviderInit = {}): MemoizedProvider => - memoize( - chain( - async () => { - const profile = init.profile ?? process.env[ENV_PROFILE]; - if (profile) { - const envStaticCredentialsAreSet = process.env[ENV_KEY] && process.env[ENV_SECRET]; - if (envStaticCredentialsAreSet) { - if (!multipleCredentialSourceWarningEmitted) { - const warnFn = - init.logger?.warn && init.logger?.constructor?.name !== "NoOpLogger" ? init.logger.warn : console.warn; - warnFn( - `@aws-sdk/credential-provider-node - defaultProvider::fromEnv WARNING: +export const defaultProvider = (init: DefaultProviderInit = {}): MemoizedProvider => { + const provider = chain( + async () => { + const profile = init.profile ?? process.env[ENV_PROFILE]; + if (profile) { + const envStaticCredentialsAreSet = process.env[ENV_KEY] && process.env[ENV_SECRET]; + if (envStaticCredentialsAreSet) { + if (!multipleCredentialSourceWarningEmitted) { + const warnFn = + init.logger?.warn && init.logger?.constructor?.name !== "NoOpLogger" ? init.logger.warn : console.warn; + warnFn( + `@aws-sdk/credential-provider-node - defaultProvider::fromEnv WARNING: Multiple credential sources detected: Both AWS_PROFILE and the pair AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY static credentials are set. This SDK will proceed with the AWS_PROFILE value. @@ -81,59 +81,72 @@ export const defaultProvider = (init: DefaultProviderInit = {}): MemoizedProvide Please ensure that your environment only sets either the AWS_PROFILE or the AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY pair. ` - ); - multipleCredentialSourceWarningEmitted = true; - } + ); + multipleCredentialSourceWarningEmitted = true; } - throw new CredentialsProviderError("AWS_PROFILE is set, skipping fromEnv provider.", { - logger: init.logger, - tryNextLink: true, - }); } - init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromEnv"); - return fromEnv(init)(); - }, - async () => { - init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromSSO"); - const { ssoStartUrl, ssoAccountId, ssoRegion, ssoRoleName, ssoSession } = init; - if (!ssoStartUrl && !ssoAccountId && !ssoRegion && !ssoRoleName && !ssoSession) { - throw new CredentialsProviderError( - "Skipping SSO provider in default chain (inputs do not include SSO fields).", - { logger: init.logger } - ); - } - const { fromSSO } = await import("@aws-sdk/credential-provider-sso"); - return fromSSO(init)(); - }, - async () => { - init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromIni"); - const { fromIni } = await import("@aws-sdk/credential-provider-ini"); - return fromIni(init)(); - }, - async () => { - init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromProcess"); - const { fromProcess } = await import("@aws-sdk/credential-provider-process"); - return fromProcess(init)(); - }, - async () => { - init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromTokenFile"); - const { fromTokenFile } = await import("@aws-sdk/credential-provider-web-identity"); - return fromTokenFile(init)(); - }, - async () => { - init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::remoteProvider"); - return (await remoteProvider(init))(); - }, - async () => { - throw new CredentialsProviderError("Could not load credentials from any providers", { - tryNextLink: false, + throw new CredentialsProviderError("AWS_PROFILE is set, skipping fromEnv provider.", { logger: init.logger, + tryNextLink: true, }); } - ), + init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromEnv"); + return fromEnv(init)(); + }, + async () => { + init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromSSO"); + const { ssoStartUrl, ssoAccountId, ssoRegion, ssoRoleName, ssoSession } = init; + if (!ssoStartUrl && !ssoAccountId && !ssoRegion && !ssoRoleName && !ssoSession) { + throw new CredentialsProviderError( + "Skipping SSO provider in default chain (inputs do not include SSO fields).", + { logger: init.logger } + ); + } + const { fromSSO } = await import("@aws-sdk/credential-provider-sso"); + return fromSSO(init)(); + }, + async () => { + init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromIni"); + const { fromIni } = await import("@aws-sdk/credential-provider-ini"); + return fromIni(init)(); + }, + async () => { + init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromProcess"); + const { fromProcess } = await import("@aws-sdk/credential-provider-process"); + return fromProcess(init)(); + }, + async () => { + init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromTokenFile"); + const { fromTokenFile } = await import("@aws-sdk/credential-provider-web-identity"); + return fromTokenFile(init)(); + }, + async () => { + init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::remoteProvider"); + return (await remoteProvider(init))(); + }, + async () => { + throw new CredentialsProviderError("Could not load credentials from any providers", { + tryNextLink: false, + logger: init.logger, + }); + } + ); + + return memoizeGlobal( + async () => { + try { + return await provider(); + } catch (error) { + if (error instanceof CredentialsProviderError) { + throw error; + } + throw new CredentialsProviderError(error.message, { tryNextLink: true }); + } + }, credentialsTreatedAsExpired, credentialsWillNeedRefresh ); +}; /** * @internal diff --git a/packages/credential-provider-node/src/memoizeGlobal.ts b/packages/credential-provider-node/src/memoizeGlobal.ts new file mode 100644 index 000000000000..028e6d6b9a9a --- /dev/null +++ b/packages/credential-provider-node/src/memoizeGlobal.ts @@ -0,0 +1,44 @@ +import { memoize } from "@smithy/property-provider"; +import { AwsCredentialIdentity } from "@smithy/types"; + +const globalProviderCache: Map Promise> = new Map(); + +function hashProvider(provider: () => Promise, config?: string): string { + return config || provider.name || Math.random().toString(36).substring(7); +} + +export function memoizeGlobal( + provider: () => Promise, + isExpired: (resolved: T) => boolean, + requiresRefresh?: (resolved: T) => boolean, + cacheKey?: string +): () => Promise { + const key = hashProvider(provider, cacheKey); + const cached = globalProviderCache.get(key); + if (cached) { + return cached as () => Promise; + } + + const memoized = memoize(provider, isExpired, requiresRefresh); + const wrappedProvider = async () => { + try { + const creds = await memoized(); + if (isExpired(creds)) { + globalProviderCache.delete(key); + // Force memoize to refresh by calling provider directly + return await provider(); + } + return creds; + } catch (error) { + globalProviderCache.delete(key); + throw error; + } + }; + + globalProviderCache.set(key, wrappedProvider); + return wrappedProvider; +} + +export function clearDefaultProviderCache(): void { + globalProviderCache.clear(); +}