From 72cdaf69324cf2b21dae80fb5df1dea8e20be3e0 Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Tue, 11 Feb 2025 13:15:14 +0200 Subject: [PATCH] test: Updated CredentialProvider test infrastructure (#3502) * test: Updated CredentialProvider test infrastructure * Added linter exclusion * Updated dev dependency * Codestyle fixes * Updated async test infra * Added missing constant --- dev_requirements.txt | 2 +- tests/conftest.py | 99 +++++++++++++++++++++------------- tests/test_asyncio/conftest.py | 97 ++++++++++++++++++++------------- 3 files changed, 124 insertions(+), 74 deletions(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index be74470ec2..728536d6fb 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -16,4 +16,4 @@ uvloop vulture>=2.3.0 wheel>=0.30.0 numpy>=1.24.0 -redis-entraid==0.1.0b1 +redis-entraid==0.3.0b1 diff --git a/tests/conftest.py b/tests/conftest.py index a900cea8bf..fc732c0d72 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ import time from datetime import datetime, timezone from enum import Enum -from typing import Callable, TypeVar +from typing import Callable, TypeVar, Union from unittest import mock from unittest.mock import Mock from urllib.parse import urlparse @@ -17,6 +17,7 @@ from redis import Sentinel from redis.auth.idp import IdentityProviderInterface from redis.auth.token import JWToken +from redis.auth.token_manager import RetryPolicy, TokenManagerConfig from redis.backoff import NoBackoff from redis.cache import ( CacheConfig, @@ -29,12 +30,21 @@ from redis.credentials import CredentialProvider from redis.exceptions import RedisClusterException from redis.retry import Retry -from redis_entraid.cred_provider import EntraIdCredentialsProvider, TokenAuthConfig +from redis_entraid.cred_provider import ( + DEFAULT_DELAY_IN_MS, + DEFAULT_EXPIRATION_REFRESH_RATIO, + DEFAULT_LOWER_REFRESH_BOUND_MILLIS, + DEFAULT_MAX_ATTEMPTS, + DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, + EntraIdCredentialsProvider, +) from redis_entraid.identity_provider import ( ManagedIdentityIdType, + ManagedIdentityProviderConfig, ManagedIdentityType, - create_provider_from_managed_identity, - create_provider_from_service_principal, + ServicePrincipalIdentityProviderConfig, + _create_provider_from_managed_identity, + _create_provider_from_service_principal, ) from tests.ssl_utils import get_tls_certificates @@ -623,17 +633,33 @@ def identity_provider(request) -> IdentityProviderInterface: return mock_identity_provider() auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) + config = get_identity_provider_config(request=request) if auth_type == "MANAGED_IDENTITY": - return _get_managed_identity_provider(request) + return _create_provider_from_managed_identity(config) + + return _create_provider_from_service_principal(config) - return _get_service_principal_provider(request) +def get_identity_provider_config( + request, +) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + else: + kwargs = {} + + auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) -def _get_managed_identity_provider(request): - authority = os.getenv("AZURE_AUTHORITY") + if auth_type == AuthType.MANAGED_IDENTITY: + return _get_managed_identity_provider_config(request) + + return _get_service_principal_provider_config(request) + + +def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig: resource = os.getenv("AZURE_RESOURCE") - id_value = os.getenv("AZURE_ID_VALUE", None) + id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None) if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) @@ -641,23 +667,24 @@ def _get_managed_identity_provider(request): kwargs = {} identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED) - id_type = kwargs.pop("id_type", ManagedIdentityIdType.CLIENT_ID) + id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID) - return create_provider_from_managed_identity( + return ManagedIdentityProviderConfig( identity_type=identity_type, resource=resource, id_type=id_type, id_value=id_value, - authority=authority, - **kwargs, + kwargs=kwargs, ) -def _get_service_principal_provider(request): +def _get_service_principal_provider_config( + request, +) -> ServicePrincipalIdentityProviderConfig: client_id = os.getenv("AZURE_CLIENT_ID") client_credential = os.getenv("AZURE_CLIENT_SECRET") - authority = os.getenv("AZURE_AUTHORITY") - scopes = os.getenv("AZURE_REDIS_SCOPES", []) + tenant_id = os.getenv("AZURE_TENANT_ID") + scopes = os.getenv("AZURE_REDIS_SCOPES", None) if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) @@ -671,14 +698,14 @@ def _get_service_principal_provider(request): if isinstance(scopes, str): scopes = scopes.split(",") - return create_provider_from_service_principal( + return ServicePrincipalIdentityProviderConfig( client_id=client_id, client_credential=client_credential, scopes=scopes, timeout=timeout, token_kwargs=token_kwargs, - authority=authority, - **kwargs, + tenant_id=tenant_id, + app_kwargs=kwargs, ) @@ -690,31 +717,29 @@ def get_credential_provider(request) -> CredentialProvider: return cred_provider_class(**cred_provider_kwargs) idp = identity_provider(request) - initial_delay_in_ms = cred_provider_kwargs.get("initial_delay_in_ms", 0) - block_for_initial = cred_provider_kwargs.get("block_for_initial", False) expiration_refresh_ratio = cred_provider_kwargs.get( - "expiration_refresh_ratio", TokenAuthConfig.DEFAULT_EXPIRATION_REFRESH_RATIO + "expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO ) lower_refresh_bound_millis = cred_provider_kwargs.get( - "lower_refresh_bound_millis", TokenAuthConfig.DEFAULT_LOWER_REFRESH_BOUND_MILLIS - ) - max_attempts = cred_provider_kwargs.get( - "max_attempts", TokenAuthConfig.DEFAULT_MAX_ATTEMPTS + "lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS ) - delay_in_ms = cred_provider_kwargs.get( - "delay_in_ms", TokenAuthConfig.DEFAULT_DELAY_IN_MS + max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS) + delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS) + + token_mgr_config = TokenManagerConfig( + expiration_refresh_ratio=expiration_refresh_ratio, + lower_refresh_bound_millis=lower_refresh_bound_millis, + token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa + retry_policy=RetryPolicy( + max_attempts=max_attempts, + delay_in_ms=delay_in_ms, + ), ) - auth_config = TokenAuthConfig(idp) - auth_config.expiration_refresh_ratio = expiration_refresh_ratio - auth_config.lower_refresh_bound_millis = lower_refresh_bound_millis - auth_config.max_attempts = max_attempts - auth_config.delay_in_ms = delay_in_ms - return EntraIdCredentialsProvider( - config=auth_config, - initial_delay_in_ms=initial_delay_in_ms, - block_for_initial=block_for_initial, + identity_provider=idp, + token_manager_config=token_mgr_config, + initial_delay_in_ms=delay_in_ms, ) diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 8833426af1..fb6c51140e 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -17,14 +17,24 @@ from redis.asyncio.retry import Retry from redis.auth.idp import IdentityProviderInterface from redis.auth.token import JWToken +from redis.auth.token_manager import RetryPolicy, TokenManagerConfig from redis.backoff import NoBackoff from redis.credentials import CredentialProvider -from redis_entraid.cred_provider import EntraIdCredentialsProvider, TokenAuthConfig +from redis_entraid.cred_provider import ( + DEFAULT_DELAY_IN_MS, + DEFAULT_EXPIRATION_REFRESH_RATIO, + DEFAULT_LOWER_REFRESH_BOUND_MILLIS, + DEFAULT_MAX_ATTEMPTS, + DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, + EntraIdCredentialsProvider, +) from redis_entraid.identity_provider import ( ManagedIdentityIdType, + ManagedIdentityProviderConfig, ManagedIdentityType, - create_provider_from_managed_identity, - create_provider_from_service_principal, + ServicePrincipalIdentityProviderConfig, + _create_provider_from_managed_identity, + _create_provider_from_service_principal, ) from tests.conftest import REDIS_INFO @@ -255,17 +265,33 @@ def identity_provider(request) -> IdentityProviderInterface: return mock_identity_provider() auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) + config = get_identity_provider_config(request=request) if auth_type == "MANAGED_IDENTITY": - return _get_managed_identity_provider(request) + return _create_provider_from_managed_identity(config) + + return _create_provider_from_service_principal(config) + + +def get_identity_provider_config( + request, +) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + else: + kwargs = {} - return _get_service_principal_provider(request) + auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) + + if auth_type == AuthType.MANAGED_IDENTITY: + return _get_managed_identity_provider_config(request) + return _get_service_principal_provider_config(request) -def _get_managed_identity_provider(request): - authority = os.getenv("AZURE_AUTHORITY") + +def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig: resource = os.getenv("AZURE_RESOURCE") - id_value = os.getenv("AZURE_ID_VALUE", None) + id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None) if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) @@ -273,23 +299,24 @@ def _get_managed_identity_provider(request): kwargs = {} identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED) - id_type = kwargs.pop("id_type", ManagedIdentityIdType.CLIENT_ID) + id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID) - return create_provider_from_managed_identity( + return ManagedIdentityProviderConfig( identity_type=identity_type, resource=resource, id_type=id_type, id_value=id_value, - authority=authority, - **kwargs, + kwargs=kwargs, ) -def _get_service_principal_provider(request): +def _get_service_principal_provider_config( + request, +) -> ServicePrincipalIdentityProviderConfig: client_id = os.getenv("AZURE_CLIENT_ID") client_credential = os.getenv("AZURE_CLIENT_SECRET") - authority = os.getenv("AZURE_AUTHORITY") - scopes = os.getenv("AZURE_REDIS_SCOPES", []) + tenant_id = os.getenv("AZURE_TENANT_ID") + scopes = os.getenv("AZURE_REDIS_SCOPES", None) if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) @@ -303,14 +330,14 @@ def _get_service_principal_provider(request): if isinstance(scopes, str): scopes = scopes.split(",") - return create_provider_from_service_principal( + return ServicePrincipalIdentityProviderConfig( client_id=client_id, client_credential=client_credential, scopes=scopes, timeout=timeout, token_kwargs=token_kwargs, - authority=authority, - **kwargs, + tenant_id=tenant_id, + app_kwargs=kwargs, ) @@ -322,31 +349,29 @@ def get_credential_provider(request) -> CredentialProvider: return cred_provider_class(**cred_provider_kwargs) idp = identity_provider(request) - initial_delay_in_ms = cred_provider_kwargs.get("initial_delay_in_ms", 0) - block_for_initial = cred_provider_kwargs.get("block_for_initial", False) expiration_refresh_ratio = cred_provider_kwargs.get( - "expiration_refresh_ratio", TokenAuthConfig.DEFAULT_EXPIRATION_REFRESH_RATIO + "expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO ) lower_refresh_bound_millis = cred_provider_kwargs.get( - "lower_refresh_bound_millis", TokenAuthConfig.DEFAULT_LOWER_REFRESH_BOUND_MILLIS - ) - max_attempts = cred_provider_kwargs.get( - "max_attempts", TokenAuthConfig.DEFAULT_MAX_ATTEMPTS + "lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS ) - delay_in_ms = cred_provider_kwargs.get( - "delay_in_ms", TokenAuthConfig.DEFAULT_DELAY_IN_MS + max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS) + delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS) + + token_mgr_config = TokenManagerConfig( + expiration_refresh_ratio=expiration_refresh_ratio, + lower_refresh_bound_millis=lower_refresh_bound_millis, + token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa + retry_policy=RetryPolicy( + max_attempts=max_attempts, + delay_in_ms=delay_in_ms, + ), ) - auth_config = TokenAuthConfig(idp) - auth_config.expiration_refresh_ratio = expiration_refresh_ratio - auth_config.lower_refresh_bound_millis = lower_refresh_bound_millis - auth_config.max_attempts = max_attempts - auth_config.delay_in_ms = delay_in_ms - return EntraIdCredentialsProvider( - config=auth_config, - initial_delay_in_ms=initial_delay_in_ms, - block_for_initial=block_for_initial, + identity_provider=idp, + token_manager_config=token_mgr_config, + initial_delay_in_ms=delay_in_ms, )