From 8497f11a1733235f75ed292f8b6606ddc1606214 Mon Sep 17 00:00:00 2001 From: Simon Unge Date: Thu, 11 Sep 2025 22:53:17 +0000 Subject: [PATCH 1/5] Use ETS table for credentials storage instead of gen_server --- deps/rabbitmq_aws/Makefile | 2 +- deps/rabbitmq_aws/include/rabbitmq_aws.hrl | 10 + deps/rabbitmq_aws/src/rabbitmq_aws.erl | 671 ++++++------------ deps/rabbitmq_aws/src/rabbitmq_aws_app.erl | 16 +- deps/rabbitmq_aws/src/rabbitmq_aws_config.erl | 62 +- deps/rabbitmq_aws/src/rabbitmq_aws_sup.erl | 7 +- .../test/rabbitmq_aws_config_tests.erl | 22 +- .../test/rabbitmq_aws_sup_tests.erl | 31 - deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl | 518 ++++---------- 9 files changed, 475 insertions(+), 864 deletions(-) delete mode 100644 deps/rabbitmq_aws/test/rabbitmq_aws_sup_tests.erl diff --git a/deps/rabbitmq_aws/Makefile b/deps/rabbitmq_aws/Makefile index 7ba1f949b3dd..4e85ebc41134 100644 --- a/deps/rabbitmq_aws/Makefile +++ b/deps/rabbitmq_aws/Makefile @@ -6,7 +6,7 @@ define PROJECT_ENV [] endef -LOCAL_DEPS = crypto inets ssl xmerl public_key +LOCAL_DEPS = crypto inets ssl xmerl public_key gun BUILD_DEPS = rabbit_common # We do not depend on rabbit therefore can't run the broker. DEP_PLUGINS = rabbit_common/mk/rabbitmq-build.mk diff --git a/deps/rabbitmq_aws/include/rabbitmq_aws.hrl b/deps/rabbitmq_aws/include/rabbitmq_aws.hrl index 6a0cacd81131..84c1527f578c 100644 --- a/deps/rabbitmq_aws/include/rabbitmq_aws.hrl +++ b/deps/rabbitmq_aws/include/rabbitmq_aws.hrl @@ -46,6 +46,9 @@ -define(LINEAR_BACK_OFF_MILLIS, 500). -define(MAX_RETRIES, 5). +-define(AWS_CREDENTIALS_TABLE, aws_credentials). +-define(AWS_CONFIG_TABLE, aws_config). + -type access_key() :: nonempty_string(). -type secret_access_key() :: nonempty_string(). -type expiration() :: calendar:datetime() | undefined. @@ -62,6 +65,13 @@ expiration :: non_neg_integer() | undefined }). +-record(aws_credentials, { + access_key :: access_key(), + secret_key :: secret_access_key(), + security_token :: security_token(), + expiration :: expiration() +}). + -type imdsv2token() :: #imdsv2token{}. -record(state, { diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws.erl b/deps/rabbitmq_aws/src/rabbitmq_aws.erl index 0b4f7c7e9e48..1c4d599f0900 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws.erl @@ -6,8 +6,6 @@ %% ==================================================================== -module(rabbitmq_aws). --behavior(gen_server). - %% API exports -export([ get/2, get/3, get/4, @@ -19,22 +17,16 @@ has_credentials/0, parse_uri/1, set_region/1, + get_region/0, + ensure_credentials_valid/0, ensure_imdsv2_token_valid/0, api_get_request/2, status_text/1, open_connection/1, open_connection/2, - close_connection/1 -]). - -%% gen-server exports --export([ - start_link/0, - init/1, - terminate/2, - code_change/3, - handle_call/3, - handle_cast/2, - handle_info/2 + close_connection/1, + direct_request/6, + endpoint/4, + sign_headers/9 ]). %% Export all for unit tests @@ -46,14 +38,117 @@ -include_lib("kernel/include/logger.hrl"). %% Types for new concurrent API --type connection_handle() :: {gun:conn_ref(), credential_context()}. --type credential_context() :: #{ - access_key => access_key(), - secret_access_key => secret_access_key(), - security_token => security_token(), - region => region(), - service => string() -}. +-type connection_handle() :: {gun:conn_ref(), string()}. + +%%==================================================================== +%% ETS-based state management +%%==================================================================== + +-spec get_credentials() -> + {ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}. +get_credentials() -> + get_credentials(10). + +-spec get_credentials(Retries :: non_neg_integer()) -> + {ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}. +get_credentials(Retries) -> + case ets:lookup(?AWS_CREDENTIALS_TABLE, current) of + [{current, Creds}] -> + case expired_credentials(Creds#aws_credentials.expiration) of + false -> + Region = get_region(), + {ok, Creds#aws_credentials.access_key, Creds#aws_credentials.secret_key, + Creds#aws_credentials.security_token, Region}; + true -> + refresh_credentials_with_lock(Retries) + end; + [] -> + refresh_credentials_with_lock(Retries) + end. + +-spec refresh_credentials_with_lock(Retries :: non_neg_integer()) -> + {ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}. +refresh_credentials_with_lock(0) -> + {error, lock_timeout}; +refresh_credentials_with_lock(Retries) -> + LockId = {aws_credentials_refresh, node()}, + case global:set_lock(LockId, [node()], 0) of + true -> + try + % Double-check if someone else already refreshed + case ets:lookup(?AWS_CREDENTIALS_TABLE, current) of + [{current, Creds}] -> + case expired_credentials(Creds#aws_credentials.expiration) of + false -> + Region = get_region(), + {ok, Creds#aws_credentials.access_key, + Creds#aws_credentials.secret_key, + Creds#aws_credentials.security_token, Region}; + true -> + do_refresh_credentials() + end; + [] -> + do_refresh_credentials() + end + after + global:del_lock(LockId, [node()]) + end; + false -> + % Someone else is refreshing, wait and retry + timer:sleep(100), + get_credentials(Retries - 1) + end. + +-spec do_refresh_credentials() -> + {ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}. +do_refresh_credentials() -> + Region = get_region(), + case rabbitmq_aws_config:credentials() of + {ok, AccessKey, SecretAccessKey, Expiration, SecurityToken} -> + Creds = #aws_credentials{ + access_key = AccessKey, + secret_key = SecretAccessKey, + security_token = SecurityToken, + expiration = Expiration + }, + ets:insert(?AWS_CREDENTIALS_TABLE, {current, Creds}), + {ok, AccessKey, SecretAccessKey, SecurityToken, Region}; + {error, Reason} -> + {error, Reason} + end. + +-spec get_region() -> region(). +get_region() -> + case ets:lookup(?AWS_CONFIG_TABLE, region) of + [{region, Region}] -> + Region; + [] -> + % Use proper region detection + case rabbitmq_aws_config:region() of + {ok, DetectedRegion} -> + % Cache the detected region + ets:insert(?AWS_CONFIG_TABLE, {region, DetectedRegion}), + DetectedRegion; + _ -> + % Final fallback + ets:insert(?AWS_CONFIG_TABLE, {region, "us-east-1"}), + "us-east-1" + end + end. + +-spec set_region(Region :: region()) -> ok. +set_region(Region) -> + ets:insert(?AWS_CONFIG_TABLE, {region, Region}), + ok. + +-spec has_credentials() -> boolean(). +has_credentials() -> + case ets:lookup(?AWS_CREDENTIALS_TABLE, current) of + [{current, Creds}] when Creds#aws_credentials.access_key =/= undefined -> + not expired_credentials(Creds#aws_credentials.expiration); + _ -> + false + end. %%==================================================================== %% exported wrapper functions @@ -121,7 +216,10 @@ put(Service, Path, Body, Headers, Options) -> %% @doc Manually refresh the credentials from the environment, filesystem or EC2 Instance Metadata Service. %% @end refresh_credentials() -> - gen_server:call(rabbitmq_aws, refresh_credentials). + case refresh_credentials_with_lock(10) of + {ok, _, _, _, _} -> ok; + {error, _} -> error + end. %%==================================================================== %% New Concurrent API Functions @@ -135,11 +233,16 @@ open_connection(Service) -> -spec open_connection(Service :: string(), Options :: list()) -> {ok, connection_handle()} | {error, term()}. open_connection(Service, Options) -> - gen_server:call(?MODULE, {open_direct_connection, Service, Options}). + % Just get region and open connection - no credential validation needed + Region = get_region(), + Host = endpoint_host(Region, Service), + Port = 443, + GunPid = create_gun_connection(Host, Port, Options), + {ok, {GunPid, Service}}. %% Close a direct connection -spec close_connection(Handle :: connection_handle()) -> ok. -close_connection({GunPid, _CredContext}) -> +close_connection({GunPid, _Service}) -> gun:close(GunPid). -spec direct_request( @@ -150,24 +253,42 @@ close_connection({GunPid, _CredContext}) -> Headers :: headers(), Options :: list() ) -> result(). -direct_request({GunPid, CredContext}, Method, Path, Body, Headers, Options) -> - #{service := Service, region := Region} = CredContext, - % Build URI for signing - Host = endpoint_host(Region, Service), - URI = create_uri(Host, Path), - % Sign headers directly (no gen_server call) - SignedHeaders = sign_headers_with_context(CredContext, Method, URI, Headers, Body), - % Make Gun request directly - direct_gun_request(GunPid, Method, Path, SignedHeaders, Body, Options). +direct_request({GunPid, Service}, Method, Path, Body, Headers, Options) -> + case get_credentials() of + {ok, AccessKey, SecretKey, SecurityToken, Region} -> + Host = endpoint_host(Region, Service), + URI = create_uri(Host, Path), + SignedHeaders = sign_headers( + AccessKey, SecretKey, SecurityToken, Region, Service, Method, URI, Headers, Body + ), + direct_gun_request(GunPid, Method, Path, SignedHeaders, Body, Options); + {error, Reason} -> + {error, Reason} + end. --spec refresh_credentials(state()) -> ok | error. -%% @doc Manually refresh the credentials from the environment, filesystem or EC2 Instance Metadata Service. -%% @end -refresh_credentials(State) -> - ?LOG_DEBUG("Refreshing AWS credentials..."), - {_, NewState} = load_credentials(State), - ?LOG_DEBUG("AWS credentials have been refreshed"), - set_credentials(NewState). +-spec sign_headers( + AccessKey :: access_key(), + SecretKey :: secret_access_key(), + SecurityToken :: security_token(), + Region :: region(), + Service :: string(), + Method :: method(), + URI :: string(), + Headers :: headers(), + Body :: body() +) -> headers(). +sign_headers(AccessKey, SecretKey, SecurityToken, Region, Service, Method, URI, Headers, Body) -> + rabbitmq_aws_sign:headers(#request{ + access_key = AccessKey, + secret_access_key = SecretKey, + security_token = SecurityToken, + region = Region, + service = Service, + method = Method, + uri = URI, + headers = Headers, + body = Body + }). -spec request( Service :: string(), @@ -212,18 +333,12 @@ request(Service, Method, Path, Body, Headers, HTTPOptions) -> %% of services such as DynamoDB. The response will automatically be decoded %% if it is either in JSON or XML format. %% @end -request({GunPid, _CredContext} = Handle, Method, Path, Body, Headers, HTTPOptions, _) when +request({GunPid, Service}, Method, Path, Body, Headers, HTTPOptions, _) when is_pid(GunPid) -> - direct_request(Handle, Method, Path, Body, Headers, HTTPOptions); + direct_request({GunPid, Service}, Method, Path, Body, Headers, HTTPOptions); request(Service, Method, Path, Body, Headers, HTTPOptions, Endpoint) -> - gen_server:call( - rabbitmq_aws, {request, Service, Method, Headers, Path, Body, HTTPOptions, Endpoint} - ). - --spec set_credentials(state()) -> ok. -set_credentials(NewState) -> - gen_server:call(rabbitmq_aws, {set_credentials, NewState}). + perform_request_direct(Service, Method, Headers, Path, Body, HTTPOptions, Endpoint). -spec set_credentials(access_key(), secret_access_key()) -> ok. %% @doc Manually set the access credentials for requests. This should @@ -232,121 +347,61 @@ set_credentials(NewState) -> %% configuration or the AWS Instance Metadata service. %% @end set_credentials(AccessKey, SecretAccessKey) -> - gen_server:call(rabbitmq_aws, {set_credentials, AccessKey, SecretAccessKey}). - --spec set_region(Region :: string()) -> ok. -%% @doc Manually set the AWS region to perform API requests to. -%% @end -set_region(Region) -> - gen_server:call(rabbitmq_aws, {set_region, Region}). - --spec set_imdsv2_token(imdsv2token()) -> ok. -%% @doc Manually set the Imdsv2Token used to perform instance metadata service requests. -%% @end -set_imdsv2_token(Imdsv2Token) -> - gen_server:call(rabbitmq_aws, {set_imdsv2_token, Imdsv2Token}). - --spec get_imdsv2_token() -> imdsv2token() | 'undefined'. -%% @doc return the current Imdsv2Token used to perform instance metadata service requests. -%% @end -get_imdsv2_token() -> - {ok, Imdsv2Token} = gen_server:call(rabbitmq_aws, get_imdsv2_token), - Imdsv2Token. - -%%==================================================================== -%% gen_server functions -%%==================================================================== - -start_link() -> - gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). - --spec init(list()) -> {ok, state()}. -init([]) -> - {ok, _} = application:ensure_all_started(gun), - {ok, #state{}}. - -terminate(_, _State) -> + Creds = #aws_credentials{ + access_key = AccessKey, + secret_key = SecretAccessKey, + security_token = undefined, + expiration = undefined + }, + ets:insert(?AWS_CREDENTIALS_TABLE, {current, Creds}), ok. -code_change(_, _, State) -> - {ok, State}. - -handle_call(Msg, _From, State) -> - handle_msg(Msg, State). - -handle_cast(_Request, State) -> - {noreply, State}. - -handle_info(_Info, State) -> - {noreply, State}. +-spec ensure_credentials_valid() -> ok. +%% @doc Invoked before each AWS service API request to check if the current credentials are available and that they have not expired. +%% If the credentials are available and are still current, then move on and perform the request. +%% If the credentials are not available or have expired, then refresh them before performing the request. +%% @end +ensure_credentials_valid() -> + ?LOG_DEBUG("Making sure AWS credentials are available and still valid"), + case has_credentials() of + true -> + ok; + false -> + refresh_credentials(), + ok + end. -%%==================================================================== -%% Internal functions -%%==================================================================== -handle_msg({request, Service, Method, Headers, Path, Body, Options, Host}, State) -> - {Response, NewState} = perform_request( - State, Service, Method, Headers, Path, Body, Options, Host - ), - {reply, Response, NewState}; -handle_msg({open_direct_connection, Service, Options}, State) -> - case ensure_credentials_valid_internal(State) of - {ok, ValidState} -> - case create_direct_connection(ValidState, Service, Options) of - {ok, Handle} -> - {reply, {ok, Handle}, ValidState}; - {error, Reason} -> - {reply, {error, Reason}, ValidState} - end; +-spec perform_request_direct( + Service :: string(), + Method :: method(), + Headers :: headers(), + Path :: path(), + Body :: body(), + Options :: http_options(), + Host :: string() | undefined +) -> result(). +perform_request_direct(Service, Method, Headers, Path, Body, Options, Host) -> + case get_credentials() of + {ok, AccessKey, SecretKey, SecurityToken, Region} -> + URI = endpoint(Region, Host, Service, Path), + SignedHeaders = sign_headers( + AccessKey, SecretKey, SecurityToken, Region, Service, Method, URI, Headers, Body + ), + gun_request(Method, URI, SignedHeaders, Body, Options); {error, Reason} -> - {reply, {error, Reason}, State} - end; -handle_msg(get_state, State) -> - {reply, {ok, State}, State}; -handle_msg(refresh_credentials, State) -> - {Reply, NewState} = load_credentials(State), - {reply, Reply, NewState}; -handle_msg({set_credentials, AccessKey, SecretAccessKey}, State) -> - {reply, ok, State#state{ - access_key = AccessKey, - secret_access_key = SecretAccessKey, - security_token = undefined, - expiration = undefined, - error = undefined - }}; -handle_msg({set_credentials, NewState}, State) -> - {reply, ok, State#state{ - access_key = NewState#state.access_key, - secret_access_key = NewState#state.secret_access_key, - security_token = NewState#state.security_token, - expiration = NewState#state.expiration, - error = NewState#state.error - }}; -handle_msg({set_region, Region}, State) -> - {reply, ok, State#state{region = Region}}; -handle_msg({set_imdsv2_token, Imdsv2Token}, State) -> - {reply, ok, State#state{imdsv2_token = Imdsv2Token}}; -handle_msg(has_credentials, State) -> - {reply, has_credentials(State), State}; -handle_msg(get_imdsv2_token, State) -> - {reply, {ok, State#state.imdsv2_token}, State}; -handle_msg(_Request, State) -> - {noreply, State}. + {error, {credentials, Reason}} + end. -spec endpoint( - State :: state(), - Host :: string(), + Region :: region(), + Host :: string() | undefined, Service :: string(), Path :: string() ) -> string(). -%% @doc Return the endpoint URL, either by constructing it with the service -%% information passed in, or by using the passed in Host value. -%% @ednd -endpoint(#state{region = Region}, undefined, Service, Path) -> +endpoint(Region, undefined, Service, Path) -> lists:flatten(["https://", endpoint_host(Region, Service), Path]); endpoint(_, Host, _, Path) -> lists:flatten(["https://", Host, Path]). - --spec endpoint_host(Region :: region(), Service :: string()) -> host(). %% @doc Construct the endpoint hostname for the request based upon the service %% and region. %% @end @@ -392,18 +447,6 @@ get_content_type(Headers) -> end, parse_content_type(Value). --spec has_credentials() -> boolean(). -has_credentials() -> - gen_server:call(rabbitmq_aws, has_credentials). - --spec has_credentials(state()) -> boolean(). -%% @doc check to see if there are credentials made available in the current state -%% returning false if not or if they have expired. -%% @end -has_credentials(#state{error = Error}) when Error /= undefined -> false; -has_credentials(#state{access_key = Key}) when Key /= undefined -> true; -has_credentials(_) -> false. - -spec expired_credentials(Expiration :: calendar:datetime()) -> boolean(). %% @doc Indicates if the date that is passed in has expired. %% end @@ -414,40 +457,6 @@ expired_credentials(Expiration) -> Expires = calendar:datetime_to_gregorian_seconds(Expiration), Now >= Expires. --spec load_credentials(State :: state()) -> {ok, state()} | {error, state()}. -%% @doc Load the credentials using the following order of configuration precedence: -%% - Environment variables -%% - Credentials file -%% - EC2 Instance Metadata Service -%% @end -load_credentials(#state{region = Region}) -> - case rabbitmq_aws_config:credentials() of - {ok, AccessKey, SecretAccessKey, Expiration, SecurityToken} -> - {ok, #state{ - region = Region, - error = undefined, - access_key = AccessKey, - secret_access_key = SecretAccessKey, - expiration = Expiration, - security_token = SecurityToken, - imdsv2_token = undefined - }}; - {error, Reason} -> - ?LOG_ERROR( - "Could not load AWS credentials from environment variables, AWS_CONFIG_FILE, AWS_SHARED_CREDENTIALS_FILE or EC2 metadata endpoint: ~tp. Will depend on config settings to be set~n", - [Reason] - ), - {error, #state{ - region = Region, - error = Reason, - access_key = undefined, - secret_access_key = undefined, - expiration = undefined, - security_token = undefined, - imdsv2_token = undefined - }} - end. - -spec local_time() -> calendar:datetime(). %% @doc Return the current local time. %% @end @@ -480,180 +489,37 @@ parse_content_type(ContentType) -> [Type, Subtype] = string:tokens(lists:nth(1, Parts), "/"), {Type, Subtype}. --spec perform_request( - State :: state(), - Service :: string(), - Method :: method(), - Headers :: headers(), - Path :: path(), - Body :: body(), - Options :: http_options(), - Host :: string() | undefined -) -> - {Result :: result(), NewState :: state()}. -%% @doc Make the API request and return the formatted response. -%% @end -perform_request(State, Service, Method, Headers, Path, Body, Options, Host) -> - perform_request_has_creds( - has_credentials(State), - State, - Service, - Method, - Headers, - Path, - Body, - Options, - Host - ). - --spec perform_request_has_creds( - HasCreds :: boolean(), - State :: state(), - Service :: string(), - Method :: method(), - Headers :: headers(), - Path :: path(), - Body :: body(), - Options :: http_options(), - Host :: string() | undefined -) -> - {Result :: result(), NewState :: state()}. -%% @doc Invoked after checking to see if there are credentials. If there are, -%% validate they have not or will not expire, performing the request if not, -%% otherwise return an error result. -%% @end -perform_request_has_creds(true, State, Service, Method, Headers, Path, Body, Options, Host) -> - perform_request_creds_expired( - expired_credentials(State#state.expiration), - State, - Service, - Method, - Headers, - Path, - Body, - Options, - Host - ); -perform_request_has_creds(false, State, _, _, _, _, _, _, _) -> - perform_request_creds_error(State). - --spec perform_request_creds_expired( - CredsExp :: boolean(), - State :: state(), - Service :: string(), - Method :: method(), - Headers :: headers(), - Path :: path(), - Body :: body(), - Options :: http_options(), - Host :: string() | undefined -) -> - {Result :: result(), NewState :: state()}. -%% @doc Invoked after checking to see if the current credentials have expired. -%% If they haven't, perform the request, otherwise try and refresh the -%% credentials before performing the request. -%% @end -perform_request_creds_expired(false, State, Service, Method, Headers, Path, Body, Options, Host) -> - perform_request_with_creds(State, Service, Method, Headers, Path, Body, Options, Host); -perform_request_creds_expired(true, State, _, _, _, _, _, _, _) -> - perform_request_creds_error(State#state{error = "Credentials expired!"}). - --spec perform_request_with_creds( - State :: state(), - Service :: string(), - Method :: method(), - Headers :: headers(), - Path :: path(), - Body :: body(), - Options :: http_options(), - Host :: string() | undefined -) -> - {Result :: result(), NewState :: state()}. -%% @doc Once it is validated that there are credentials to try and that they have not -%% expired, perform the request and return the response. -%% @end -perform_request_with_creds(State, Service, Method, Headers, Path, Body, Options, Host) -> - URI = endpoint(State, Host, Service, Path), - SignedHeaders = sign_headers(State, Service, Method, URI, Headers, Body), - perform_request_with_creds(State, Method, URI, SignedHeaders, Body, Options). - --spec perform_request_with_creds( - State :: state(), - Method :: method(), - URI :: string(), - Headers :: headers(), - Body :: body(), - Options :: http_options() -) -> - {Result :: result(), NewState :: state()}. -%% @doc Once it is validated that there are credentials to try and that they have not -%% expired, perform the request and return the response. -%% @end -perform_request_with_creds(State, Method, URI, Headers, "", Options0) -> - Response = gun_request(Method, URI, Headers, <<>>, Options0), - {Response, State}; -perform_request_with_creds(State, Method, URI, Headers, Body, Options0) -> - Response = gun_request(Method, URI, Headers, Body, Options0), - {Response, State}. - --spec perform_request_creds_error(State :: state()) -> - {result_error(), NewState :: state()}. -%% @doc Return the error response when there are not any credentials to use with -%% the request. -%% @end -perform_request_creds_error(State) -> - {{error, {credentials, State#state.error}}, State}. - --spec sign_headers( - State :: state(), - Service :: string(), - Method :: method(), - URI :: string(), - Headers :: headers(), - Body :: body() -) -> headers(). -%% @doc Build the signed headers for the API request. -%% @end -sign_headers( - #state{ - access_key = AccessKey, - secret_access_key = SecretKey, - security_token = SecurityToken, - region = Region - }, - Service, - Method, - URI, - Headers, - Body -) -> - rabbitmq_aws_sign:headers(#request{ - access_key = AccessKey, - secret_access_key = SecretKey, - security_token = SecurityToken, - region = Region, - service = Service, - method = Method, - uri = URI, - headers = Headers, - body = Body - }). - -spec expired_imdsv2_token('undefined' | imdsv2token()) -> boolean(). %% @doc Determine whether or not an Imdsv2Token has expired. %% @end expired_imdsv2_token(undefined) -> ?LOG_DEBUG("EC2 IMDSv2 token has not yet been obtained"), true; -expired_imdsv2_token({_, _, undefined}) -> +expired_imdsv2_token(#imdsv2token{expiration = undefined}) -> ?LOG_DEBUG("EC2 IMDSv2 token is not available"), true; -expired_imdsv2_token({_, _, Expiration}) -> +expired_imdsv2_token(#imdsv2token{expiration = Expiration}) -> Now = calendar:datetime_to_gregorian_seconds(local_time()), HasExpired = Now >= Expiration, ?LOG_DEBUG("EC2 IMDSv2 token has expired: ~tp", [HasExpired]), HasExpired. +-spec get_imdsv2_token() -> imdsv2token() | 'undefined'. +%% @doc return the current Imdsv2Token used to perform instance metadata service requests. +%% @end +get_imdsv2_token() -> + case ets:lookup(?AWS_CONFIG_TABLE, imdsv2_token) of + [{imdsv2_token, Token}] -> Token; + [] -> undefined + end. + +-spec set_imdsv2_token(imdsv2token()) -> ok. +%% @doc Manually set the Imdsv2Token used to perform instance metadata service requests. +%% @end +set_imdsv2_token(Imdsv2Token) -> + ets:insert(?AWS_CONFIG_TABLE, {imdsv2_token, Imdsv2Token}), + ok. + -spec ensure_imdsv2_token_valid() -> security_token(). ensure_imdsv2_token_valid() -> Imdsv2Token = get_imdsv2_token(), @@ -671,24 +537,6 @@ ensure_imdsv2_token_valid() -> Imdsv2Token#imdsv2token.token end. --spec ensure_credentials_valid() -> ok. -%% @doc Invoked before each AWS service API request to check if the current credentials are available and that they have not expired. -%% If the credentials are available and are still current, then move on and perform the request. -%% If the credentials are not available or have expired, then refresh them before performing the request. -%% @end -ensure_credentials_valid() -> - ?LOG_DEBUG("Making sure AWS credentials are available and still valid"), - {ok, State} = gen_server:call(rabbitmq_aws, get_state), - case has_credentials(State) of - true -> - case expired_credentials(State#state.expiration) of - true -> refresh_credentials(State); - _ -> ok - end; - _ -> - refresh_credentials(State) - end. - -spec api_get_request(string(), path()) -> {'ok', list()} | {'error', term()}. %% @doc Invoke an API call to an AWS service. %% @end @@ -825,56 +673,6 @@ status_text(416) -> "Range Not Satisfiable"; status_text(500) -> "Internal Server Error"; status_text(Code) -> integer_to_list(Code). -%%==================================================================== -%% New Concurrent API Helper Functions -%%==================================================================== - -%% Create a direct connection handle --spec create_direct_connection(State :: state(), Service :: string(), Options :: list()) -> - {ok, connection_handle()} | {error, term()}. -create_direct_connection(State, Service, Options) -> - Region = State#state.region, - Host = endpoint_host(Region, Service), - Port = 443, - GunPid = create_gun_connection(Host, Port, Options), - CredContext = #{ - access_key => State#state.access_key, - secret_access_key => State#state.secret_access_key, - security_token => State#state.security_token, - region => Region, - service => Service - }, - {ok, {GunPid, CredContext}}. - -%% Sign headers using credential context (no gen_server state needed) --spec sign_headers_with_context( - CredContext :: credential_context(), - Method :: method(), - URI :: string(), - Headers :: headers(), - Body :: body() -) -> headers(). -sign_headers_with_context(CredContext, Method, URI, Headers, Body) -> - #{ - access_key := AccessKey, - secret_access_key := SecretKey, - security_token := SecurityToken, - region := Region, - service := Service - } = CredContext, - rabbitmq_aws_sign:headers(#request{ - access_key = AccessKey, - secret_access_key = SecretKey, - security_token = SecurityToken, - region = Region, - service = Service, - method = Method, - uri = URI, - headers = Headers, - body = Body - }). - -%% Direct Gun request (extracted from existing gun_request function) -spec direct_gun_request( GunPid :: gun:conn_ref(), Method :: method(), @@ -884,7 +682,7 @@ sign_headers_with_context(CredContext, Method, URI, Headers, Body) -> Options :: list() ) -> result(). direct_gun_request(GunPid, Method, {_, Path}, Headers, Body, Options) -> - direct_gun_request(GunPid, Method, [$/|Path], Headers, Body, Options); + direct_gun_request(GunPid, Method, [$/ | Path], Headers, Body, Options); direct_gun_request(GunPid, Method, Path, Headers, Body, Options) -> HeadersBin = lists:map( fun({Key, Value}) -> @@ -910,16 +708,3 @@ direct_gun_request(GunPid, Method, Path, Headers, Body, Options) -> {error, Error} end, format_response(Response). - -%% Internal credential validation (extracted from existing logic) --spec ensure_credentials_valid_internal(State :: state()) -> {ok, state()} | {error, term()}. -ensure_credentials_valid_internal(State) -> - case has_credentials(State) of - true -> - case expired_credentials(State#state.expiration) of - false -> {ok, State}; - true -> load_credentials(State) - end; - false -> - load_credentials(State) - end. diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws_app.erl b/deps/rabbitmq_aws/src/rabbitmq_aws_app.erl index 543c8f56282d..89a0648ab892 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws_app.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws_app.erl @@ -1,21 +1,11 @@ -%% ==================================================================== -%% @author Gavin M. Roy -%% @copyright 2016, Gavin M. Roy -%% @doc rabbitmq_aws application startup -%% @end -%% ==================================================================== -module(rabbitmq_aws_app). - -behaviour(application). -%% Application callbacks -export([start/2, stop/1]). -%% =================================================================== -%% Application callbacks -%% =================================================================== - -start(_StartType, _StartArgs) -> +start(_Type, _Args) -> + ets:new(aws_credentials, [named_table, public, {read_concurrency, true}]), + ets:new(aws_config, [named_table, public, {read_concurrency, true}]), rabbitmq_aws_sup:start_link(). stop(_State) -> diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws_config.erl b/deps/rabbitmq_aws/src/rabbitmq_aws_config.erl index 4ba821249a99..39e3d7685137 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws_config.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws_config.erl @@ -535,14 +535,37 @@ lookup_credentials_from_proplist(_, undefined) -> lookup_credentials_from_proplist(AccessKey, SecretKey) -> {ok, AccessKey, SecretKey, undefined, undefined}. +-spec with_metadata_connection(fun((gun:conn_ref()) -> Result)) -> Result. +%% @doc Execute a function with a shared metadata service connection +%% @end +with_metadata_connection(Fun) -> + {Host, Port, _} = rabbitmq_aws:parse_uri(instance_metadata_url("")), + Opts = #{transport => tcp, protocols => [http]}, + case gun:open(Host, Port, Opts) of + {ok, ConnPid} -> + case gun:await_up(ConnPid, 5000) of + {ok, _Protocol} -> + Result = Fun(ConnPid), + gun:close(ConnPid), + Result; + {error, Reason} -> + gun:close(ConnPid), + {error, Reason} + end; + {error, Reason} -> + {error, Reason} + end. + -spec lookup_credentials_from_instance_metadata() -> security_credentials(). %% @spec lookup_credentials_from_instance_metadata() -> Result. %% @doc Attempt to lookup the values from the EC2 instance metadata service. %% @end lookup_credentials_from_instance_metadata() -> - Role = maybe_get_role_from_instance_metadata(), - maybe_get_credentials_from_instance_metadata(Role). + with_metadata_connection(fun(ConnPid) -> + Role = maybe_get_role_from_instance_metadata_with_conn(ConnPid), + maybe_get_credentials_from_instance_metadata_with_conn(ConnPid, Role) + end). -spec lookup_region( Profile :: string(), @@ -595,19 +618,21 @@ maybe_convert_number(Value) -> F end. --spec maybe_get_credentials_from_instance_metadata( +-spec maybe_get_credentials_from_instance_metadata_with_conn( + ConnPid :: gun:conn_ref(), {ok, Role :: string()} | {error, undefined} ) -> {'ok', security_credentials()} | {'error', term()}. %% @doc Try to query the EC2 local instance metadata service to get temporary -%% authentication credentials. +%% authentication credentials using an existing connection. %% @end -maybe_get_credentials_from_instance_metadata({error, undefined}) -> +maybe_get_credentials_from_instance_metadata_with_conn(_, {error, undefined}) -> {error, undefined}; -maybe_get_credentials_from_instance_metadata({ok, Role}) -> +maybe_get_credentials_from_instance_metadata_with_conn(ConnPid, {ok, Role}) -> URL = instance_credentials_url(Role), - parse_credentials_response(perform_http_get_instance_metadata(URL)). + {_, _, Path} = rabbitmq_aws:parse_uri(URL), + parse_credentials_response(perform_http_get_with_conn(ConnPid, Path)). -spec maybe_get_region_from_instance_metadata() -> {ok, Region :: string()} | {error, Reason :: atom()}. @@ -617,12 +642,29 @@ maybe_get_region_from_instance_metadata() -> URL = instance_availability_zone_url(), parse_az_response(perform_http_get_instance_metadata(URL)). +-spec perform_http_get_with_conn(gun:conn_ref(), string()) -> httpc_result(). +%% @doc Make HTTP GET request using existing Gun connection +%% @end +perform_http_get_with_conn(ConnPid, Path) -> + Headers = instance_metadata_request_headers(), + StreamRef = gun:get(ConnPid, Path, Headers), + case gun:await(ConnPid, StreamRef, ?DEFAULT_HTTP_TIMEOUT) of + {response, fin, Status, RespHeaders} -> + {ok, {{http_version, Status, rabbitmq_aws:status_text(Status)}, RespHeaders, <<>>}}; + {response, nofin, Status, RespHeaders} -> + {ok, Body} = gun:await_body(ConnPid, StreamRef, ?DEFAULT_HTTP_TIMEOUT), + {ok, {{http_version, Status, rabbitmq_aws:status_text(Status)}, RespHeaders, Body}}; + {error, Reason} -> + {error, Reason} + end. + %% @doc Try to query the EC2 local instance metadata service to get the role -%% assigned to the instance. +%% assigned to the instance using an existing connection. %% @end -maybe_get_role_from_instance_metadata() -> +maybe_get_role_from_instance_metadata_with_conn(ConnPid) -> URL = instance_role_url(), - parse_body_response(perform_http_get_instance_metadata(URL)). + {_, _, Path} = rabbitmq_aws:parse_uri(URL), + parse_body_response(perform_http_get_with_conn(ConnPid, Path)). -spec parse_az_response(httpc_result()) -> {ok, Region :: string()} | {error, Reason :: atom()}. diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws_sup.erl b/deps/rabbitmq_aws/src/rabbitmq_aws_sup.erl index 7c4900f7abb6..87297ef5a8a5 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws_sup.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws_sup.erl @@ -1,7 +1,7 @@ %% ==================================================================== %% @author Gavin M. Roy %% @copyright 2016, Gavin M. Roy -%% @doc rabbitmq_aws supervisor for the gen_server process +%% @doc rabbitmq_aws supervisor for ETS table owner process %% @end %% ==================================================================== -module(rabbitmq_aws_sup). @@ -13,10 +13,9 @@ init/1 ]). --define(CHILD(I, Type), {I, {I, start_link, []}, permanent, 5, Type, [I]}). - start_link() -> supervisor:start_link({local, ?MODULE}, ?MODULE, []). init([]) -> - {ok, {{one_for_one, 5, 10}, [?CHILD(rabbitmq_aws, worker)]}}. + % No children needed - just return empty supervisor + {ok, {{one_for_one, 5, 10}, []}}. diff --git a/deps/rabbitmq_aws/test/rabbitmq_aws_config_tests.erl b/deps/rabbitmq_aws/test/rabbitmq_aws_config_tests.erl index fd6c30376c37..7cd4c858c118 100644 --- a/deps/rabbitmq_aws/test/rabbitmq_aws_config_tests.erl +++ b/deps/rabbitmq_aws/test/rabbitmq_aws_config_tests.erl @@ -2,6 +2,14 @@ -include_lib("eunit/include/eunit.hrl"). +%% Helper function to mock gun for IMDSv2 failure scenarios +mock_gun_imdsv2_failure() -> + meck:expect(gun, open, fun(_, _, _) -> {ok, fake_conn} end), + meck:expect(gun, await_up, fun(_, _) -> {ok, http} end), + meck:expect(gun, get, fun(_, _, _) -> fake_stream end), + meck:expect(gun, await, fun(_, _, _) -> {response, fin, 404, []} end), + meck:expect(gun, close, fun(_) -> ok end). + -include("rabbitmq_aws.hrl"). config_file_test_() -> @@ -145,6 +153,8 @@ credentials_test_() -> {"with missing environment variable", fun() -> os:putenv("AWS_ACCESS_KEY_ID", "Sésame"), meck:sequence(rabbitmq_aws, ensure_imdsv2_token_valid, 0, "secret_imdsv2_token"), + mock_gun_imdsv2_failure(), + ?assertEqual( {error, undefined}, rabbitmq_aws_config:credentials() @@ -167,6 +177,8 @@ credentials_test_() -> {"from config file with bad profile", fun() -> setup_test_config_env_var(), meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), + mock_gun_imdsv2_failure(), + ?assertEqual( {error, undefined}, rabbitmq_aws_config:credentials("bad-profile-name") @@ -190,6 +202,8 @@ credentials_test_() -> {"from credentials file with bad profile", fun() -> setup_test_credentials_env_var(), meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), + mock_gun_imdsv2_failure(), + ?assertEqual( {error, undefined}, rabbitmq_aws_config:credentials("bad-profile-name") @@ -198,6 +212,8 @@ credentials_test_() -> {"from credentials file with only the key in profile", fun() -> setup_test_credentials_env_var(), meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), + mock_gun_imdsv2_failure(), + ?assertEqual( {error, undefined}, rabbitmq_aws_config:credentials("only-key") @@ -206,6 +222,8 @@ credentials_test_() -> {"from credentials file with only the value in profile", fun() -> setup_test_credentials_env_var(), meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), + mock_gun_imdsv2_failure(), + ?assertEqual( {error, undefined}, rabbitmq_aws_config:credentials("only-value") @@ -214,6 +232,8 @@ credentials_test_() -> {"from credentials file with missing keys in profile", fun() -> setup_test_credentials_env_var(), meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), + mock_gun_imdsv2_failure(), + ?assertEqual( {error, undefined}, rabbitmq_aws_config:credentials("bad-entry") @@ -252,7 +272,7 @@ credentials_test_() -> end}, {"with instance metadata service role error", fun() -> meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), - meck:expect(gun, open, fun(_, _, _) -> {error, timeout} end), + mock_gun_imdsv2_failure(), ?assertEqual({error, undefined}, rabbitmq_aws_config:credentials()) end}, {"with instance metadata service role http error", fun() -> diff --git a/deps/rabbitmq_aws/test/rabbitmq_aws_sup_tests.erl b/deps/rabbitmq_aws/test/rabbitmq_aws_sup_tests.erl deleted file mode 100644 index fdb54facb75a..000000000000 --- a/deps/rabbitmq_aws/test/rabbitmq_aws_sup_tests.erl +++ /dev/null @@ -1,31 +0,0 @@ --module(rabbitmq_aws_sup_tests). - --include_lib("eunit/include/eunit.hrl"). - -start_link_test_() -> - {foreach, - fun() -> - meck:new(supervisor, [passthrough, unstick]) - end, - fun(_) -> - meck:unload(supervisor) - end, - [ - {"supervisor start_link", fun() -> - meck:expect(supervisor, start_link, fun(_, _, _) -> {ok, test_result} end), - ?assertEqual( - {ok, test_result}, - rabbitmq_aws_sup:start_link() - ), - meck:validate(supervisor) - end} - ]}. - -init_test() -> - ?assertEqual( - {ok, - {{one_for_one, 5, 10}, [ - {rabbitmq_aws, {rabbitmq_aws, start_link, []}, permanent, 5, worker, [rabbitmq_aws]} - ]}}, - rabbitmq_aws_sup:init([]) - ). diff --git a/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl b/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl index 66c23e0f65cc..e25db57e30d8 100644 --- a/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl +++ b/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl @@ -4,13 +4,40 @@ -include("rabbitmq_aws.hrl"). +%% Test helper functions +setup() -> + application:ensure_all_started(rabbitmq_aws), + ok. + +teardown(_) -> + application:stop(rabbitmq_aws), + ok. + +% Helper to populate test credentials +set_test_credentials(AccessKey, SecretKey) -> + set_test_credentials(AccessKey, SecretKey, undefined, undefined). + +set_test_credentials(AccessKey, SecretKey, SecurityToken, Expiration) -> + Creds = #aws_credentials{ + access_key = AccessKey, + secret_key = SecretKey, + security_token = SecurityToken, + expiration = Expiration + }, + ets:insert(?AWS_CREDENTIALS_TABLE, {current, Creds}). + +set_test_region(Region) -> + ets:insert(?AWS_CONFIG_TABLE, {region, Region}). + init_test_() -> {foreach, fun() -> os:putenv("AWS_DEFAULT_REGION", "us-west-3"), - meck:new(rabbitmq_aws_config, [passthrough]) + meck:new(rabbitmq_aws_config, [passthrough]), + setup() end, fun(_) -> + teardown(ok), os:unsetenv("AWS_DEFAULT_REGION"), meck:unload(rabbitmq_aws_config) end, @@ -18,45 +45,26 @@ init_test_() -> {"ok", fun() -> os:putenv("AWS_ACCESS_KEY_ID", "Sésame"), os:putenv("AWS_SECRET_ACCESS_KEY", "ouvre-toi"), - {ok, Pid} = rabbitmq_aws:start_link(), - rabbitmq_aws:set_region("us-west-3"), - rabbitmq_aws:refresh_credentials(), - {ok, State} = gen_server:call(Pid, get_state), - ok = gen_server:stop(Pid), + ?assertEqual(ok, rabbitmq_aws:refresh_credentials()), + % Verify credentials were actually stored + ?assertEqual(true, rabbitmq_aws:has_credentials()), + {ok, AccessKey, SecretKey, SecurityToken, Region} = rabbitmq_aws:get_credentials(), + ?assertEqual("Sésame", AccessKey), + ?assertEqual("ouvre-toi", SecretKey), + ?assertEqual(undefined, SecurityToken), + ?assertEqual("us-west-3", Region), os:unsetenv("AWS_ACCESS_KEY_ID"), - os:unsetenv("AWS_SECRET_ACCESS_KEY"), - Expectation = - {state, "Sésame", "ouvre-toi", undefined, undefined, "us-west-3", undefined, - undefined}, - ?assertEqual(Expectation, State) + os:unsetenv("AWS_SECRET_ACCESS_KEY") end}, {"error", fun() -> meck:expect(rabbitmq_aws_config, credentials, fun() -> {error, test_result} end), - {ok, Pid} = rabbitmq_aws:start_link(), - rabbitmq_aws:set_region("us-west-3"), - rabbitmq_aws:refresh_credentials(), - {ok, State} = gen_server:call(Pid, get_state), - ok = gen_server:stop(Pid), - Expectation = - {state, undefined, undefined, undefined, undefined, "us-west-3", undefined, - test_result}, - ?assertEqual(Expectation, State), + ?assertEqual(error, rabbitmq_aws:refresh_credentials()), + % Verify no credentials were stored + ?assertEqual(false, rabbitmq_aws:has_credentials()), meck:validate(rabbitmq_aws_config) end} ]}. -terminate_test() -> - ?assertEqual( - ok, - rabbitmq_aws:terminate( - foo, - {state, undefined, undefined, undefined, undefined, "us-west-3", undefined, test_result} - ) - ). - -code_change_test() -> - ?assertEqual({ok, {state, denial}}, rabbitmq_aws:code_change(foo, bar, {state, denial})). - endpoint_test_() -> [ {"specified", fun() -> @@ -66,7 +74,7 @@ endpoint_test_() -> Host = "localhost:32767", Expectation = "https://localhost:32767/", ?assertEqual( - Expectation, rabbitmq_aws:endpoint(#state{region = Region}, Host, Service, Path) + Expectation, rabbitmq_aws:endpoint(Region, Host, Service, Path) ) end}, {"unspecified", fun() -> @@ -76,7 +84,7 @@ endpoint_test_() -> Host = undefined, Expectation = "https://dynamodb.us-east-3.amazonaws.com/", ?assertEqual( - Expectation, rabbitmq_aws:endpoint(#state{region = Region}, Host, Service, Path) + Expectation, rabbitmq_aws:endpoint(Region, Host, Service, Path) ) end} ]. @@ -160,163 +168,6 @@ format_response_test_() -> end} ]. -gen_server_call_test_() -> - { - foreach, - fun() -> - % We explicitely set a few defaults, in case the caller has - % something in ~/.aws. - os:putenv("AWS_DEFAULT_REGION", "us-west-3"), - os:putenv("AWS_ACCESS_KEY_ID", "Sésame"), - os:putenv("AWS_SECRET_ACCESS_KEY", "ouvre-toi"), - meck:new(gun, []), - [gun] - end, - fun(Mods) -> - meck:unload(Mods), - os:unsetenv("AWS_DEFAULT_REGION"), - os:unsetenv("AWS_ACCESS_KEY_ID"), - os:unsetenv("AWS_SECRET_ACCESS_KEY") - end, - [ - { - "request", - fun() -> - State = #state{ - access_key = "AKIDEXAMPLE", - secret_access_key = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", - region = "us-east-1" - }, - Service = "ec2", - Method = get, - Headers = [], - Path = "/?Action=DescribeTags&Version=2015-10-01", - Body = "", - Options = [], - Host = undefined, - meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), - meck:expect(gun, close, fun(_) -> ok end), - meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), - meck:expect( - gun, - get, - fun(_Pid, _Path, _Headers) -> nofin end - ), - %% {ok, {{"HTTP/1.0", 200, "OK"}, [{"content-type", "application/json"}], "{\"pass\": true}"}} - %% end), - meck:expect( - gun, - await, - fun(_Pid, _, _) -> - {response, nofin, 200, [{<<"content-type">>, <<"application/json">>}]} - end - ), - meck:expect( - gun, - await_body, - fun(_Pid, _, _) -> {ok, <<"{\"pass\": true}">>} end - ), - - %% {ok, {{"HTTP/1.0", 200, "OK"}, [{"content-type", "application/json"}], "{\"pass\": true}"}} - %% end), - Expectation = - {reply, - {ok, - {[{<<"content-type">>, <<"application/json">>}], [{"pass", true}]}}, - State}, - Result = rabbitmq_aws:handle_call( - {request, Service, Method, Headers, Path, Body, Options, Host}, eunit, State - ), - ?assertEqual(Expectation, Result), - meck:validate(gun) - end - }, - { - "get_state", - fun() -> - State = #state{ - access_key = "AKIDEXAMPLE", - secret_access_key = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", - region = "us-east-1" - }, - ?assertEqual( - {reply, {ok, State}, State}, - rabbitmq_aws:handle_call(get_state, eunit, State) - ) - end - }, - { - "refresh_credentials", - fun() -> - State = #state{ - access_key = "AKIDEXAMPLE", - secret_access_key = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", - region = "us-east-1" - }, - State2 = #state{ - access_key = "AKIDEXAMPLE2", - secret_access_key = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY2", - region = "us-east-1", - security_token = - "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKwRcOIfrRh3c/L2", - expiration = calendar:local_time() - }, - meck:new(rabbitmq_aws_config, [passthrough]), - meck:expect( - rabbitmq_aws_config, - credentials, - fun() -> - {ok, State2#state.access_key, State2#state.secret_access_key, - State2#state.expiration, State2#state.security_token} - end - ), - ?assertEqual( - {reply, ok, State2}, - rabbitmq_aws:handle_call(refresh_credentials, eunit, State) - ), - meck:validate(rabbitmq_aws_config), - meck:unload(rabbitmq_aws_config) - end - }, - { - "set_credentials", - fun() -> - State = #state{ - access_key = "AKIDEXAMPLE", - secret_access_key = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", - region = "us-west-3" - }, - ?assertEqual( - {reply, ok, State}, - rabbitmq_aws:handle_call( - {set_credentials, State#state.access_key, - State#state.secret_access_key}, - eunit, - #state{region = "us-west-3"} - ) - ) - end - }, - { - "set_region", - fun() -> - State = #state{ - access_key = "Sésame", - secret_access_key = "ouvre-toi", - region = "us-east-5" - }, - ?assertEqual( - {reply, ok, State}, - rabbitmq_aws:handle_call({set_region, "us-east-5"}, eunit, #state{ - access_key = "Sésame", - secret_access_key = "ouvre-toi" - }) - ) - end - } - ] - }. - get_content_type_test_() -> [ {"from headers caps", fun() -> @@ -332,14 +183,21 @@ get_content_type_test_() -> ]. has_credentials_test_() -> - [ - {"true", fun() -> - ?assertEqual(true, rabbitmq_aws:has_credentials(#state{access_key = "TESTVALUE1"})) - end}, - {"false", fun() -> - ?assertEqual(false, rabbitmq_aws:has_credentials(#state{error = "ERROR"})) - end} - ]. + { + foreach, + fun setup/0, + fun teardown/1, + [ + {"true", fun() -> + set_test_credentials("TESTVALUE1", "SECRET"), + ?assertEqual(true, rabbitmq_aws:has_credentials()) + end}, + {"false", fun() -> + % No credentials set + ?assertEqual(false, rabbitmq_aws:has_credentials()) + end} + ] + }. local_time_test_() -> { @@ -406,33 +264,30 @@ perform_request_test_() -> { foreach, fun() -> + setup(), meck:new(gun, []), - meck:new(rabbitmq_aws_config, []), - [gun, rabbitmq_aws_config] + [gun] + end, + fun(Mods) -> + teardown(ok), + meck:unload(Mods) end, - fun meck:unload/1, [ { - "has_credentials true", + "Successfull run", fun() -> - State = #state{ - access_key = "AKIDEXAMPLE", - secret_access_key = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", - region = "us-east-1" - }, + set_test_credentials("AKIDEXAMPLE", "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"), + set_test_region("us-east-1"), Service = "ec2", Method = get, Headers = [], Path = "/?Action=DescribeTags&Version=2015-10-01", Body = "", Options = [], - Host = undefined, - ExpectURI = - "https://ec2.us-east-1.amazonaws.com/?Action=DescribeTags&Version=2015-10-01", + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), meck:expect(gun, close, fun(_) -> ok end), meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), - meck:expect( gun, get, @@ -451,71 +306,12 @@ perform_request_test_() -> fun(_Pid, _, _) -> {ok, <<"{\"pass\": true}">>} end ), - Expectation = { + Expectation = {ok, {[{<<"content-type">>, <<"application/json">>}], [{"pass", true}]}}, - State - }, - Result = rabbitmq_aws:perform_request( - State, Service, Method, Headers, Path, Body, Options, Host - ), + Result = rabbitmq_aws:request(Service, Method, Path, Body, Headers, Options), ?assertEqual(Expectation, Result), meck:validate(gun) end - }, - { - "has_credentials false", - fun() -> - State = #state{region = "us-east-1"}, - Service = "ec2", - Method = get, - Headers = [], - Path = "/?Action=DescribeTags&Version=2015-10-01", - Body = "", - Options = [], - Host = undefined, - Expectation = {{error, {credentials, State#state.error}}, State}, - Result = rabbitmq_aws:perform_request( - State, Service, Method, Headers, Path, Body, Options, Host - ), - ?assertEqual(Expectation, Result) - end - }, - { - "has expired credentials", - fun() -> - State = #state{ - access_key = "AKIDEXAMPLE", - secret_access_key = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", - region = "us-east-1", - security_token = - "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKwRcOIfrRh3c/L", - expiration = {{1973, 1, 1}, {10, 20, 30}} - }, - Service = "ec2", - Method = get, - Headers = [], - Path = "/?Action=DescribeTags&Version=2015-10-01", - Body = "", - Options = [], - Host = undefined, - meck:expect(rabbitmq_aws_config, credentials, fun() -> {error, unit_test} end), - Expectation = {{error, {credentials, "Credentials expired!"}}, State#state{ - error = "Credentials expired!" - }}, - Result = rabbitmq_aws:perform_request( - State, Service, Method, Headers, Path, Body, Options, Host - ), - ?assertEqual(Expectation, Result), - meck:validate(rabbitmq_aws_config) - end - }, - { - "creds_error", - fun() -> - State = #state{error = unit_test}, - Expectation = {{error, {credentials, State#state.error}}, State}, - ?assertEqual(Expectation, rabbitmq_aws:perform_request_creds_error(State)) - end } ] }. @@ -532,13 +328,11 @@ sign_headers_test_() -> {"with security token", fun() -> Value = {{2016, 5, 1}, {12, 0, 0}}, meck:expect(calendar, local_time_to_universal_time_dst, fun(_) -> [Value] end), - State = #state{ - access_key = "AKIDEXAMPLE", - secret_access_key = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", - security_token = - "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKwRcOIfrRh3c/L", - region = "us-east-1" - }, + AccessKey = "AKIDEXAMPLE", + SecretKey = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + SecurityToken = + "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKwRcOIfrRh3c/L", + Region = "us-east-1", Service = "ec2", Method = get, Headers = [], @@ -557,7 +351,17 @@ sign_headers_test_() -> ], ?assertEqual( Expectation, - rabbitmq_aws:sign_headers(State, Service, Method, URI, Headers, Body) + rabbitmq_aws:sign_headers( + AccessKey, + SecretKey, + SecurityToken, + Region, + Service, + Method, + URI, + Headers, + Body + ) ), meck:validate(calendar) end} @@ -568,19 +372,22 @@ api_get_request_test_() -> { foreach, fun() -> + setup(), meck:new(gun, []), meck:new(rabbitmq_aws_config, []), [gun, rabbitmq_aws_config] end, - fun meck:unload/1, + fun(Mods) -> + teardown(ok), + meck:unload(Mods) + end, [ {"AWS service API request succeeded", fun() -> - State = #state{ - access_key = "ExpiredKey", - secret_access_key = "ExpiredAccessKey", - region = "us-east-1", - expiration = {{3016, 4, 1}, {12, 0, 0}} - }, + set_test_credentials("ExpiredKey", "ExpiredAccessKey", undefined, { + {3016, 4, 1}, {12, 0, 0} + }), + set_test_region("us-east-1"), + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), meck:expect(gun, close, fun(_) -> ok end), meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), @@ -602,29 +409,24 @@ api_get_request_test_() -> fun(_Pid, _, _) -> {ok, <<"{\"data\": \"value\"}">>} end ), - {ok, Pid} = rabbitmq_aws:start_link(), - rabbitmq_aws:set_region("us-east-1"), - rabbitmq_aws:set_credentials(State), Result = rabbitmq_aws:api_get_request("AWS", "API"), - ok = gen_server:stop(Pid), ?assertEqual({ok, [{"data", "value"}]}, Result), meck:validate(gun) end}, {"AWS service API request failed - credentials", fun() -> + set_test_region("us-east-1"), + % No credentials set - should fail meck:expect(rabbitmq_aws_config, credentials, 0, {error, undefined}), - {ok, Pid} = rabbitmq_aws:start_link(), - rabbitmq_aws:set_region("us-east-1"), + Result = rabbitmq_aws:api_get_request("AWS", "API"), - ok = gen_server:stop(Pid), ?assertEqual({error, credentials}, Result) end}, {"AWS service API request failed - API error with persistent failure", fun() -> - State = #state{ - access_key = "ExpiredKey", - secret_access_key = "ExpiredAccessKey", - region = "us-east-1", - expiration = {{3016, 4, 1}, {12, 0, 0}} - }, + set_test_credentials("ExpiredKey", "ExpiredAccessKey", undefined, { + {3016, 4, 1}, {12, 0, 0} + }), + set_test_region("us-east-1"), + meck:expect(gun, open, fun(_, _, _) -> {ok, spawn(fun() -> ok end)} end), meck:expect(gun, close, fun(_) -> ok end), meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), @@ -639,21 +441,16 @@ api_get_request_test_() -> fun(_Pid, _, _) -> {error, "network error"} end ), - {ok, Pid} = rabbitmq_aws:start_link(), - rabbitmq_aws:set_region("us-east-1"), - rabbitmq_aws:set_credentials(State), Result = rabbitmq_aws:api_get_request_with_retries("AWS", "API", 3, 1), - ok = gen_server:stop(Pid), ?assertEqual({error, "AWS service is unavailable"}, Result), meck:validate(gun) end}, {"AWS service API request succeeded after a transient error", fun() -> - State = #state{ - access_key = "ExpiredKey", - secret_access_key = "ExpiredAccessKey", - region = "us-east-1", - expiration = {{3016, 4, 1}, {12, 0, 0}} - }, + set_test_credentials("ExpiredKey", "ExpiredAccessKey", undefined, { + {3016, 4, 1}, {12, 0, 0} + }), + set_test_region("us-east-1"), + meck:expect(gun, open, fun(_, _, _) -> {ok, spawn(fun() -> ok end)} end), meck:expect(gun, close, fun(_) -> ok end), meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), @@ -685,11 +482,7 @@ api_get_request_test_() -> {ok, <<"{\"data\": \"value\"}">>} ]) ), - {ok, Pid} = rabbitmq_aws:start_link(), - rabbitmq_aws:set_region("us-east-1"), - rabbitmq_aws:set_credentials(State), Result = rabbitmq_aws:api_get_request_with_retries("AWS", "API", 3, 1), - ok = gen_server:stop(Pid), ?assertEqual({ok, [{"data", "value"}]}, Result), meck:validate(gun) end} @@ -700,82 +493,85 @@ ensure_credentials_valid_test_() -> { foreach, fun() -> + setup(), meck:new(rabbitmq_aws_config, []), [rabbitmq_aws_config] end, - fun meck:unload/1, + fun(Mods) -> + teardown(ok), + meck:unload(Mods) + end, [ {"expired credentials are refreshed", fun() -> - State = #state{ - access_key = "ExpiredKey", - secret_access_key = "ExpiredAccessKey", - region = "us-east-1", - expiration = {{2016, 4, 1}, {12, 0, 0}} - }, - State2 = #state{ - access_key = "NewKey", - secret_access_key = "NewAccessKey", - region = "us-east-1", - expiration = {{3016, 4, 1}, {12, 0, 0}} - }, + % Set expired credentials in ETS + set_test_credentials("ExpiredKey", "ExpiredAccessKey", undefined, { + {2016, 4, 1}, {12, 0, 0} + }), + set_test_region("us-east-1"), + % Mock config to return new credentials when refresh is called meck:expect( rabbitmq_aws_config, credentials, fun() -> - {ok, State2#state.access_key, State2#state.secret_access_key, - State2#state.expiration, State2#state.security_token} + {ok, "NewKey", "NewAccessKey", {{3016, 4, 1}, {12, 0, 0}}, undefined} end ), - {ok, Pid} = rabbitmq_aws:start_link(), - rabbitmq_aws:set_region("us-east-1"), - rabbitmq_aws:set_credentials(State), + Result = rabbitmq_aws:ensure_credentials_valid(), - Credentials = gen_server:call(Pid, get_state), - ok = gen_server:stop(Pid), + + % Check that credentials were refreshed in ETS + {ok, AccessKey, SecretKey, SecurityToken, Region} = rabbitmq_aws:get_credentials(), + ?assertEqual(ok, Result), - ?assertEqual(Credentials, {ok, State2}), + ?assertEqual("NewKey", AccessKey), + ?assertEqual("NewAccessKey", SecretKey), + ?assertEqual(undefined, SecurityToken), + ?assertEqual("us-east-1", Region), meck:validate(rabbitmq_aws_config) end}, {"valid credentials are returned", fun() -> - State = #state{ - access_key = "GoodKey", - secret_access_key = "GoodAccessKey", - region = "us-east-1", - expiration = {{3016, 4, 1}, {12, 0, 0}} - }, - {ok, Pid} = rabbitmq_aws:start_link(), - rabbitmq_aws:set_region("us-east-1"), - rabbitmq_aws:set_credentials(State), + % Set valid (non-expired) credentials in ETS + set_test_credentials("GoodKey", "GoodAccessKey", undefined, { + {3016, 4, 1}, {12, 0, 0} + }), + set_test_region("us-east-1"), + Result = rabbitmq_aws:ensure_credentials_valid(), - Credentials = gen_server:call(Pid, get_state), - ok = gen_server:stop(Pid), + + % Check that credentials remain unchanged in ETS + {ok, AccessKey, SecretKey, SecurityToken, Region} = rabbitmq_aws:get_credentials(), + ?assertEqual(ok, Result), - ?assertEqual(Credentials, {ok, State}), + ?assertEqual("GoodKey", AccessKey), + ?assertEqual("GoodAccessKey", SecretKey), + ?assertEqual(undefined, SecurityToken), + ?assertEqual("us-east-1", Region), meck:validate(rabbitmq_aws_config) end}, {"load credentials if missing", fun() -> - State = #state{ - access_key = "GoodKey", - secret_access_key = "GoodAccessKey", - region = "us-east-1", - expiration = {{3016, 4, 1}, {12, 0, 0}} - }, + % Don't set any credentials in ETS - should trigger refresh + set_test_region("us-east-1"), + + % Mock config to return credentials when refresh is called meck:expect( rabbitmq_aws_config, credentials, fun() -> - {ok, State#state.access_key, State#state.secret_access_key, - State#state.expiration, State#state.security_token} + {ok, "GoodKey", "GoodAccessKey", {{3016, 4, 1}, {12, 0, 0}}, undefined} end ), - {ok, Pid} = rabbitmq_aws:start_link(), - rabbitmq_aws:set_region("us-east-1"), + Result = rabbitmq_aws:ensure_credentials_valid(), - Credentials = gen_server:call(Pid, get_state), - ok = gen_server:stop(Pid), + + % Check that credentials were loaded into ETS + {ok, AccessKey, SecretKey, SecurityToken, Region} = rabbitmq_aws:get_credentials(), + ?assertEqual(ok, Result), - ?assertEqual(Credentials, {ok, State}), + ?assertEqual("GoodKey", AccessKey), + ?assertEqual("GoodAccessKey", SecretKey), + ?assertEqual(undefined, SecurityToken), + ?assertEqual("us-east-1", Region), meck:validate(rabbitmq_aws_config) end} ] From e669ff9262fafdf11dddb8c770f711f090adea24 Mon Sep 17 00:00:00 2001 From: Simon Unge Date: Wed, 17 Sep 2025 21:35:08 +0000 Subject: [PATCH 2/5] All tests in aws_tests work --- deps/rabbitmq_aws/src/rabbitmq_aws.erl | 180 ++++++++++-------- .../test/rabbitmq_aws_sign_tests.erl | 2 +- deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl | 3 +- 3 files changed, 101 insertions(+), 84 deletions(-) diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws.erl b/deps/rabbitmq_aws/src/rabbitmq_aws.erl index c598a40a937d..58c163745ce6 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws.erl @@ -26,7 +26,7 @@ close_connection/1, direct_request/6, endpoint/4, - sign_headers/9 + sign_headers/10 ]). %% Export all for unit tests @@ -37,83 +37,7 @@ -include("rabbitmq_aws.hrl"). -include_lib("kernel/include/logger.hrl"). --type connection_handle() :: {gun:conn_ref(), string()}. -%%==================================================================== -%% ETS-based state management -%%==================================================================== - --spec get_credentials() -> - {ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}. -get_credentials() -> - get_credentials(10). - --spec get_credentials(Retries :: non_neg_integer()) -> - {ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}. -get_credentials(Retries) -> - case ets:lookup(?AWS_CREDENTIALS_TABLE, current) of - [{current, Creds}] -> - case expired_credentials(Creds#aws_credentials.expiration) of - false -> - Region = get_region(), - {ok, Creds#aws_credentials.access_key, Creds#aws_credentials.secret_key, - Creds#aws_credentials.security_token, Region}; - true -> - refresh_credentials_with_lock(Retries) - end; - [] -> - refresh_credentials_with_lock(Retries) - end. - --spec refresh_credentials_with_lock(Retries :: non_neg_integer()) -> - {ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}. -refresh_credentials_with_lock(0) -> - {error, lock_timeout}; -refresh_credentials_with_lock(Retries) -> - LockId = {aws_credentials_refresh, node()}, - case global:set_lock(LockId, [node()], 0) of - true -> - try - % Double-check if someone else already refreshed - case ets:lookup(?AWS_CREDENTIALS_TABLE, current) of - [{current, Creds}] -> - case expired_credentials(Creds#aws_credentials.expiration) of - false -> - Region = get_region(), - {ok, Creds#aws_credentials.access_key, - Creds#aws_credentials.secret_key, - Creds#aws_credentials.security_token, Region}; - true -> - do_refresh_credentials() - end; - [] -> - do_refresh_credentials() - end - after - global:del_lock(LockId, [node()]) - end; - false -> - % Someone else is refreshing, wait and retry - timer:sleep(100), - get_credentials(Retries - 1) - end. - --spec do_refresh_credentials() -> - {ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}. -do_refresh_credentials() -> - Region = get_region(), - case rabbitmq_aws_config:credentials() of - {ok, AccessKey, SecretAccessKey, Expiration, SecurityToken} -> - Creds = #aws_credentials{ - access_key = AccessKey, - secret_key = SecretAccessKey, - security_token = SecurityToken, - expiration = Expiration - }, - ets:insert(?AWS_CREDENTIALS_TABLE, {current, Creds}), - {ok, AccessKey, SecretAccessKey, SecurityToken, Region}; - {error, Reason} -> - {error, Reason} - end. +-type connection_handle() :: {pid(), string()}. -spec get_region() -> region(). get_region() -> @@ -258,7 +182,16 @@ direct_request({GunPid, Service}, Method, Path, Body, Headers, Options) -> URI = create_uri(Host, Path), BodyHash = proplists:get_value(payload_hash, Options), SignedHeaders = sign_headers( - AccessKey, SecretKey, SecurityToken, Region, Service, Method, URI, Headers, Body, BodyHash + AccessKey, + SecretKey, + SecurityToken, + Region, + Service, + Method, + URI, + Headers, + Body, + BodyHash ), direct_gun_request(GunPid, Method, Path, SignedHeaders, Body, Options); {error, Reason} -> @@ -277,7 +210,9 @@ direct_request({GunPid, Service}, Method, Path, Body, Headers, Options) -> Body :: body(), BodyHash :: iodata() ) -> headers(). -sign_headers(AccessKey, SecretKey, SecurityToken, Region, Service, Method, URI, Headers, Body, BodyHash) -> +sign_headers( + AccessKey, SecretKey, SecurityToken, Region, Service, Method, URI, Headers, Body, BodyHash +) -> rabbitmq_aws_sign:headers( #request{ access_key = AccessKey, @@ -388,7 +323,16 @@ perform_request_direct(Service, Method, Headers, Path, Body, Options, Host) -> {ok, AccessKey, SecretKey, SecurityToken, Region} -> URI = endpoint(Region, Host, Service, Path), SignedHeaders = sign_headers( - AccessKey, SecretKey, SecurityToken, Region, Service, Method, URI, Headers, Body + AccessKey, + SecretKey, + SecurityToken, + Region, + Service, + Method, + URI, + Headers, + Body, + undefined ), gun_request(Method, URI, SignedHeaders, Body, Options); {error, Reason} -> @@ -422,6 +366,79 @@ endpoint_tld("cn-northwest-1") -> endpoint_tld(_Other) -> "amazonaws.com". +-spec get_credentials() -> + {ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}. +get_credentials() -> + get_credentials(10). + +-spec get_credentials(Retries :: non_neg_integer()) -> + {ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}. +get_credentials(Retries) -> + case ets:lookup(?AWS_CREDENTIALS_TABLE, current) of + [{current, Creds}] -> + case expired_credentials(Creds#aws_credentials.expiration) of + false -> + Region = get_region(), + {ok, Creds#aws_credentials.access_key, Creds#aws_credentials.secret_key, + Creds#aws_credentials.security_token, Region}; + true -> + refresh_credentials_with_lock(Retries) + end; + [] -> + refresh_credentials_with_lock(Retries) + end. + +-spec refresh_credentials_with_lock(Retries :: non_neg_integer()) -> + {ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}. +refresh_credentials_with_lock(0) -> + {error, lock_timeout}; +refresh_credentials_with_lock(Retries) -> + LockId = {aws_credentials_refresh, node()}, + case global:set_lock(LockId, [node()], 0) of + true -> + try + % Double-check if someone else already refreshed + case ets:lookup(?AWS_CREDENTIALS_TABLE, current) of + [{current, Creds}] -> + case expired_credentials(Creds#aws_credentials.expiration) of + false -> + Region = get_region(), + {ok, Creds#aws_credentials.access_key, + Creds#aws_credentials.secret_key, + Creds#aws_credentials.security_token, Region}; + true -> + do_refresh_credentials() + end; + [] -> + do_refresh_credentials() + end + after + global:del_lock(LockId, [node()]) + end; + false -> + % Someone else is refreshing, wait and retry + timer:sleep(100), + get_credentials(Retries - 1) + end. + +-spec do_refresh_credentials() -> + {ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}. +do_refresh_credentials() -> + Region = get_region(), + case rabbitmq_aws_config:credentials() of + {ok, AccessKey, SecretAccessKey, Expiration, SecurityToken} -> + Creds = #aws_credentials{ + access_key = AccessKey, + secret_key = SecretAccessKey, + security_token = SecurityToken, + expiration = Expiration + }, + ets:insert(?AWS_CREDENTIALS_TABLE, {current, Creds}), + {ok, AccessKey, SecretAccessKey, SecurityToken, Region}; + {error, Reason} -> + {error, Reason} + end. + -spec format_response(Response :: httpc_result()) -> result(). %% @doc Format the httpc response result, returning the request result data %% structure. The response body will attempt to be decoded by invoking the @@ -676,7 +693,6 @@ status_text(416) -> "Range Not Satisfiable"; status_text(500) -> "Internal Server Error"; status_text(Code) -> integer_to_list(Code). - -spec direct_gun_request( GunPid :: pid(), Method :: method(), diff --git a/deps/rabbitmq_aws/test/rabbitmq_aws_sign_tests.erl b/deps/rabbitmq_aws/test/rabbitmq_aws_sign_tests.erl index fbdd0a877344..9b22a90ec8b3 100644 --- a/deps/rabbitmq_aws/test/rabbitmq_aws_sign_tests.erl +++ b/deps/rabbitmq_aws/test/rabbitmq_aws_sign_tests.erl @@ -221,7 +221,7 @@ request_hash_test_() -> {"Host", "iam.amazonaws.com"}, {"Date", "20150830T123600Z"} ], - Payload = "", + Payload = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", Expectation = "49b454e0f20fe17f437eaa570846fc5d687efc1752c8b5a1eeee5597a7eb92a5", ?assertEqual( Expectation, diff --git a/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl b/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl index e25db57e30d8..39f42091f6ed 100644 --- a/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl +++ b/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl @@ -360,7 +360,8 @@ sign_headers_test_() -> Method, URI, Headers, - Body + Body, + undefined ) ), meck:validate(calendar) From 514aae1a3ad2e42c612e34b9278a792595d3b9ac Mon Sep 17 00:00:00 2001 From: Simon Unge Date: Mon, 22 Sep 2025 22:48:39 +0000 Subject: [PATCH 3/5] gun dependency --- deps/rabbitmq_aws/Makefile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deps/rabbitmq_aws/Makefile b/deps/rabbitmq_aws/Makefile index 4e85ebc41134..aa51b1b2f6e5 100644 --- a/deps/rabbitmq_aws/Makefile +++ b/deps/rabbitmq_aws/Makefile @@ -6,7 +6,8 @@ define PROJECT_ENV [] endef -LOCAL_DEPS = crypto inets ssl xmerl public_key gun +LOCAL_DEPS = crypto inets ssl xmerl public_key +DEPS = gun BUILD_DEPS = rabbit_common # We do not depend on rabbit therefore can't run the broker. DEP_PLUGINS = rabbit_common/mk/rabbitmq-build.mk From a976b64f2bd7c602c3b52b52462dc27121bb54f2 Mon Sep 17 00:00:00 2001 From: Simon Unge Date: Tue, 23 Sep 2025 23:46:06 +0000 Subject: [PATCH 4/5] Dialyze take 1 --- deps/rabbitmq_aws/src/rabbitmq_aws.erl | 68 +++++-------------- deps/rabbitmq_aws/src/rabbitmq_aws_app.erl | 2 - deps/rabbitmq_aws/src/rabbitmq_aws_config.erl | 8 +-- deps/rabbitmq_aws/src/rabbitmq_aws_sup.erl | 2 + .../test/rabbitmq_aws_app_tests.erl | 25 ------- deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl | 2 +- 6 files changed, 25 insertions(+), 82 deletions(-) delete mode 100644 deps/rabbitmq_aws/test/rabbitmq_aws_app_tests.erl diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws.erl b/deps/rabbitmq_aws/src/rabbitmq_aws.erl index 58c163745ce6..e56fbbab4c18 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws.erl @@ -45,17 +45,9 @@ get_region() -> [{region, Region}] -> Region; [] -> - % Use proper region detection - case rabbitmq_aws_config:region() of - {ok, DetectedRegion} -> - % Cache the detected region - ets:insert(?AWS_CONFIG_TABLE, {region, DetectedRegion}), - DetectedRegion; - _ -> - % Final fallback - ets:insert(?AWS_CONFIG_TABLE, {region, "us-east-1"}), - "us-east-1" - end + {ok, DetectedRegion} = rabbitmq_aws_config:region(), + ets:insert(?AWS_CONFIG_TABLE, {region, DetectedRegion}), + DetectedRegion end. -spec set_region(Region :: region()) -> ok. @@ -65,9 +57,9 @@ set_region(Region) -> -spec has_credentials() -> boolean(). has_credentials() -> - case ets:lookup(?AWS_CREDENTIALS_TABLE, current) of - [{current, Creds}] when Creds#aws_credentials.access_key =/= undefined -> - not expired_credentials(Creds#aws_credentials.expiration); + case ets:lookup(?AWS_CREDENTIALS_TABLE, aws_credentials) of + [#aws_credentials{access_key = Key, expiration = Expiration}] when Key =/= undefined -> + not expired_credentials(Expiration); _ -> false end. @@ -76,26 +68,9 @@ has_credentials() -> %% exported wrapper functions %%==================================================================== --spec get( - ServiceOrHandle :: string() | connection_handle(), - Path :: path() -) -> result(). -%% @doc Perform a HTTP GET request to the AWS API for the specified service. The -%% response will automatically be decoded if it is either in JSON, or XML -%% format. -%% @end get(ServiceOrHandle, Path) -> get(ServiceOrHandle, Path, []). --spec get( - ServiceOrHandle :: string() | connection_handle(), - Path :: path(), - Headers :: headers() -) -> result(). -%% @doc Perform a HTTP GET request to the AWS API for the specified service. The -%% response will automatically be decoded if it is either in JSON or XML -%% format. -%% @end get(ServiceOrHandle, Path, Headers) -> get(ServiceOrHandle, Path, Headers, []). @@ -208,7 +183,7 @@ direct_request({GunPid, Service}, Method, Path, Body, Headers, Options) -> URI :: string(), Headers :: headers(), Body :: body(), - BodyHash :: iodata() + BodyHash :: iodata() | undefined ) -> headers(). sign_headers( AccessKey, SecretKey, SecurityToken, Region, Service, Method, URI, Headers, Body, BodyHash @@ -264,7 +239,7 @@ request(Service, Method, Path, Body, Headers, HTTPOptions) -> Body :: body(), Headers :: headers(), HTTPOptions :: http_options(), - Endpoint :: host() + Endpoint :: host() | undefined ) -> result(). %% @doc Perform a HTTP request to the AWS API for the specified service, overriding %% the endpoint URL to use when invoking the API. This is useful for local testing @@ -291,7 +266,7 @@ set_credentials(AccessKey, SecretAccessKey) -> security_token = undefined, expiration = undefined }, - ets:insert(?AWS_CREDENTIALS_TABLE, {current, Creds}), + ets:insert(?AWS_CREDENTIALS_TABLE, Creds), ok. -spec ensure_credentials_valid() -> ok. @@ -366,21 +341,16 @@ endpoint_tld("cn-northwest-1") -> endpoint_tld(_Other) -> "amazonaws.com". --spec get_credentials() -> - {ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}. get_credentials() -> get_credentials(10). --spec get_credentials(Retries :: non_neg_integer()) -> - {ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}. get_credentials(Retries) -> - case ets:lookup(?AWS_CREDENTIALS_TABLE, current) of - [{current, Creds}] -> - case expired_credentials(Creds#aws_credentials.expiration) of + case ets:lookup(?AWS_CREDENTIALS_TABLE, aws_credentials) of + [#aws_credentials{access_key = Key, expiration = Expiration, secret_key = SecretKey, security_token = SecurityToken}] -> + case expired_credentials(Expiration) of false -> Region = get_region(), - {ok, Creds#aws_credentials.access_key, Creds#aws_credentials.secret_key, - Creds#aws_credentials.security_token, Region}; + {ok, Key, SecretKey, SecurityToken, Region}; true -> refresh_credentials_with_lock(Retries) end; @@ -398,14 +368,12 @@ refresh_credentials_with_lock(Retries) -> true -> try % Double-check if someone else already refreshed - case ets:lookup(?AWS_CREDENTIALS_TABLE, current) of - [{current, Creds}] -> - case expired_credentials(Creds#aws_credentials.expiration) of + case ets:lookup(?AWS_CREDENTIALS_TABLE, aws_credentials) of + [#aws_credentials{access_key = Key, expiration = Expiration, secret_key = SecretKey, security_token = SecurityToken}] -> + case expired_credentials(Expiration) of false -> Region = get_region(), - {ok, Creds#aws_credentials.access_key, - Creds#aws_credentials.secret_key, - Creds#aws_credentials.security_token, Region}; + {ok, Key, SecretKey, SecurityToken, Region}; true -> do_refresh_credentials() end; @@ -433,7 +401,7 @@ do_refresh_credentials() -> security_token = SecurityToken, expiration = Expiration }, - ets:insert(?AWS_CREDENTIALS_TABLE, {current, Creds}), + ets:insert(?AWS_CREDENTIALS_TABLE, Creds), {ok, AccessKey, SecretAccessKey, SecurityToken, Region}; {error, Reason} -> {error, Reason} diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws_app.erl b/deps/rabbitmq_aws/src/rabbitmq_aws_app.erl index 89a0648ab892..ef5497c94772 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws_app.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws_app.erl @@ -4,8 +4,6 @@ -export([start/2, stop/1]). start(_Type, _Args) -> - ets:new(aws_credentials, [named_table, public, {read_concurrency, true}]), - ets:new(aws_config, [named_table, public, {read_concurrency, true}]), rabbitmq_aws_sup:start_link(). stop(_State) -> diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws_config.erl b/deps/rabbitmq_aws/src/rabbitmq_aws_config.erl index 39e3d7685137..1e5da7313d83 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws_config.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws_config.erl @@ -535,7 +535,7 @@ lookup_credentials_from_proplist(_, undefined) -> lookup_credentials_from_proplist(AccessKey, SecretKey) -> {ok, AccessKey, SecretKey, undefined, undefined}. --spec with_metadata_connection(fun((gun:conn_ref()) -> Result)) -> Result. +-spec with_metadata_connection(fun((pid()) -> Result)) -> Result. %% @doc Execute a function with a shared metadata service connection %% @end with_metadata_connection(Fun) -> @@ -619,7 +619,7 @@ maybe_convert_number(Value) -> end. -spec maybe_get_credentials_from_instance_metadata_with_conn( - ConnPid :: gun:conn_ref(), + ConnPid :: pid(), {ok, Role :: string()} | {error, undefined} ) -> @@ -627,7 +627,7 @@ maybe_convert_number(Value) -> %% @doc Try to query the EC2 local instance metadata service to get temporary %% authentication credentials using an existing connection. %% @end -maybe_get_credentials_from_instance_metadata_with_conn(_, {error, undefined}) -> +maybe_get_credentials_from_instance_metadata_with_conn(_, {error, _}) -> {error, undefined}; maybe_get_credentials_from_instance_metadata_with_conn(ConnPid, {ok, Role}) -> URL = instance_credentials_url(Role), @@ -642,7 +642,7 @@ maybe_get_region_from_instance_metadata() -> URL = instance_availability_zone_url(), parse_az_response(perform_http_get_instance_metadata(URL)). --spec perform_http_get_with_conn(gun:conn_ref(), string()) -> httpc_result(). +-spec perform_http_get_with_conn(pid(), string()) -> httpc_result(). %% @doc Make HTTP GET request using existing Gun connection %% @end perform_http_get_with_conn(ConnPid, Path) -> diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws_sup.erl b/deps/rabbitmq_aws/src/rabbitmq_aws_sup.erl index 87297ef5a8a5..73d79d50bf8b 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws_sup.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws_sup.erl @@ -18,4 +18,6 @@ start_link() -> init([]) -> % No children needed - just return empty supervisor + ets:new(aws_credentials, [named_table, public, {read_concurrency, true}]), + ets:new(aws_config, [named_table, public, {read_concurrency, true}]), {ok, {{one_for_one, 5, 10}, []}}. diff --git a/deps/rabbitmq_aws/test/rabbitmq_aws_app_tests.erl b/deps/rabbitmq_aws/test/rabbitmq_aws_app_tests.erl deleted file mode 100644 index ccb95aa52738..000000000000 --- a/deps/rabbitmq_aws/test/rabbitmq_aws_app_tests.erl +++ /dev/null @@ -1,25 +0,0 @@ --module(rabbitmq_aws_app_tests). - --include_lib("eunit/include/eunit.hrl"). - -start_test_() -> - {foreach, - fun() -> - meck:new(rabbitmq_aws_sup, [passthrough]) - end, - fun(_) -> - meck:unload(rabbitmq_aws_sup) - end, - [ - {"supervisor initialized", fun() -> - meck:expect(rabbitmq_aws_sup, start_link, fun() -> {ok, test_result} end), - ?assertEqual( - {ok, test_result}, - rabbitmq_aws_app:start(temporary, []) - ), - meck:validate(rabbitmq_aws_sup) - end} - ]}. - -stop_test() -> - ?assertEqual(ok, rabbitmq_aws_app:stop({})). diff --git a/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl b/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl index 39f42091f6ed..08a2da4512f3 100644 --- a/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl +++ b/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl @@ -24,7 +24,7 @@ set_test_credentials(AccessKey, SecretKey, SecurityToken, Expiration) -> security_token = SecurityToken, expiration = Expiration }, - ets:insert(?AWS_CREDENTIALS_TABLE, {current, Creds}). + ets:insert(?AWS_CREDENTIALS_TABLE, Creds). set_test_region(Region) -> ets:insert(?AWS_CONFIG_TABLE, {region, Region}). From ea8fd41bd4a450a0649ffd64ee6d7d633343f315 Mon Sep 17 00:00:00 2001 From: Simon Unge Date: Wed, 24 Sep 2025 17:31:21 +0000 Subject: [PATCH 5/5] Dialyze take 2 --- deps/rabbitmq_aws/src/rabbitmq_aws.erl | 4 +--- deps/rabbitmq_aws/src/rabbitmq_aws_config.erl | 15 ++++----------- deps/rabbitmq_aws/src/rabbitmq_aws_sup.erl | 4 ++-- deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl | 8 -------- 4 files changed, 7 insertions(+), 24 deletions(-) diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws.erl b/deps/rabbitmq_aws/src/rabbitmq_aws.erl index e56fbbab4c18..2e242078e4cc 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws.erl @@ -311,7 +311,7 @@ perform_request_direct(Service, Method, Headers, Path, Body, Options, Host) -> ), gun_request(Method, URI, SignedHeaders, Body, Options); {error, Reason} -> - {error, {credentials, Reason}} + {error, Reason} end. -spec endpoint( @@ -545,8 +545,6 @@ api_get_request_with_retries(Service, Path, Retries, WaitTimeBetweenRetries) -> {ok, {_Headers, Payload}} -> ?LOG_DEBUG("AWS request: ~ts~nResponse: ~tp", [Path, Payload]), {ok, Payload}; - {error, {credentials, _}} -> - {error, credentials}; {error, Message, Response} -> ?LOG_WARNING("Error occurred: ~ts", [Message]), case Response of diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws_config.erl b/deps/rabbitmq_aws/src/rabbitmq_aws_config.erl index 1e5da7313d83..1d860e91efca 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws_config.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws_config.erl @@ -642,7 +642,7 @@ maybe_get_region_from_instance_metadata() -> URL = instance_availability_zone_url(), parse_az_response(perform_http_get_instance_metadata(URL)). --spec perform_http_get_with_conn(pid(), string()) -> httpc_result(). +-spec perform_http_get_with_conn(pid(), string()) -> {ok, {any(), any(), any()}} | {error, term()}. %% @doc Make HTTP GET request using existing Gun connection %% @end perform_http_get_with_conn(ConnPid, Path) -> @@ -666,8 +666,8 @@ maybe_get_role_from_instance_metadata_with_conn(ConnPid) -> {_, _, Path} = rabbitmq_aws:parse_uri(URL), parse_body_response(perform_http_get_with_conn(ConnPid, Path)). --spec parse_az_response(httpc_result()) -> - {ok, Region :: string()} | {error, Reason :: atom()}. +%% -spec parse_az_response(httpc_result()) -> +%% {ok, Region :: string()} | {error, Reason :: atom()}. %% @doc Parse the response from the Availability Zone query to the %% Instance Metadata service, returning the Region if successful. %% end. @@ -675,13 +675,9 @@ parse_az_response({error, _}) -> {error, undefined}; parse_az_response({ok, {{_, 200, _}, _, Body}}) when is_binary(Body) -> {ok, region_from_availability_zone(binary_to_list(Body))}; -parse_az_response({ok, {{_, 200, _}, _, Body}}) -> - {ok, region_from_availability_zone(Body)}; parse_az_response({ok, {{_, _, _}, _, _}}) -> {error, undefined}. --spec parse_body_response(httpc_result()) -> - {ok, Value :: string()} | {error, Reason :: atom()}. %% @doc Parse the return response from the Instance Metadata Service where the %% body value is the string to process. %% end. @@ -689,7 +685,6 @@ parse_body_response({error, _}) -> {error, undefined}; parse_body_response({ok, {{_, 200, _}, _, Body}}) when is_binary(Body) -> {ok, binary_to_list(Body)}; -parse_body_response({ok, {{_, 200, _}, _, Body}}) when is_list(Body) -> {ok, Body}; parse_body_response({ok, {{_, 401, _}, _, _}}) -> ?LOG_ERROR( get_instruction_on_instance_metadata_error( @@ -704,7 +699,7 @@ parse_body_response({ok, {{_, 403, _}, _, _}}) -> ) ), {error, undefined}; -parse_body_response({ok, {{_, _, _}, _, _}}) -> +parse_body_response(_) -> {error, undefined}. -spec parse_credentials_response(httpc_result()) -> security_credentials(). @@ -721,7 +716,6 @@ parse_credentials_response({ok, {{_, 200, _}, _, Body}}) -> parse_iso8601_timestamp(proplists:get_value("Expiration", Parsed)), proplists:get_value("Token", Parsed)}. --spec perform_http_get_instance_metadata(string()) -> httpc_result(). %% @doc Wrap httpc:get/4 to simplify Instance Metadata service v2 requests %% @end perform_http_get_instance_metadata(URL) -> @@ -813,7 +807,6 @@ read_file(Path) -> Error end. --spec region_from_availability_zone(Value :: string()) -> string(). %% @doc Strip the availability zone suffix from the region. %% @end region_from_availability_zone(Value) -> diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws_sup.erl b/deps/rabbitmq_aws/src/rabbitmq_aws_sup.erl index 73d79d50bf8b..151423914c7a 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws_sup.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws_sup.erl @@ -18,6 +18,6 @@ start_link() -> init([]) -> % No children needed - just return empty supervisor - ets:new(aws_credentials, [named_table, public, {read_concurrency, true}]), - ets:new(aws_config, [named_table, public, {read_concurrency, true}]), + _ = ets:new(aws_credentials, [named_table, public, {read_concurrency, true}]), + _ = ets:new(aws_config, [named_table, public, {read_concurrency, true}]), {ok, {{one_for_one, 5, 10}, []}}. diff --git a/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl b/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl index 08a2da4512f3..62f49abe3462 100644 --- a/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl +++ b/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl @@ -414,14 +414,6 @@ api_get_request_test_() -> ?assertEqual({ok, [{"data", "value"}]}, Result), meck:validate(gun) end}, - {"AWS service API request failed - credentials", fun() -> - set_test_region("us-east-1"), - % No credentials set - should fail - meck:expect(rabbitmq_aws_config, credentials, 0, {error, undefined}), - - Result = rabbitmq_aws:api_get_request("AWS", "API"), - ?assertEqual({error, credentials}, Result) - end}, {"AWS service API request failed - API error with persistent failure", fun() -> set_test_credentials("ExpiredKey", "ExpiredAccessKey", undefined, { {3016, 4, 1}, {12, 0, 0}