Skip to content

Commit

Permalink
WIP: try to isolate redis-entraid in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Feb 21, 2025
1 parent 2b65eff commit c19d8a4
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 285 deletions.
2 changes: 1 addition & 1 deletion dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ uvloop
vulture>=2.3.0
wheel>=0.30.0
numpy>=1.24.0
redis-entraid==0.3.0b1
#redis-entraid==0.3.0b1
142 changes: 8 additions & 134 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import random
import time
from datetime import datetime, timezone
from enum import Enum
from typing import Callable, TypeVar, Union
from unittest import mock
from unittest.mock import Mock
Expand All @@ -17,7 +16,6 @@
from redis import Sentinel
from redis.auth.idp import IdentityProviderInterface
from redis.auth.token import JWToken
from redis.auth.token_manager import RetryPolicy, TokenManagerConfig
from redis.backoff import NoBackoff
from redis.cache import (
CacheConfig,
Expand All @@ -30,22 +28,6 @@
from redis.credentials import CredentialProvider
from redis.exceptions import RedisClusterException
from redis.retry import Retry
from redis_entraid.cred_provider import (
DEFAULT_DELAY_IN_MS,
DEFAULT_EXPIRATION_REFRESH_RATIO,
DEFAULT_LOWER_REFRESH_BOUND_MILLIS,
DEFAULT_MAX_ATTEMPTS,
DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS,
EntraIdCredentialsProvider,
)
from redis_entraid.identity_provider import (
ManagedIdentityIdType,
ManagedIdentityProviderConfig,
ManagedIdentityType,
ServicePrincipalIdentityProviderConfig,
_create_provider_from_managed_identity,
_create_provider_from_service_principal,
)
from tests.ssl_utils import get_tls_certificates

REDIS_INFO = {}
Expand All @@ -61,11 +43,6 @@
_TestDecorator = Callable[[_DecoratedTest], _DecoratedTest]


class AuthType(Enum):
MANAGED_IDENTITY = "managed_identity"
SERVICE_PRINCIPAL = "service_principal"


# Taken from python3.9
class BooleanOptionalAction(argparse.Action):
def __init__(
Expand Down Expand Up @@ -623,124 +600,21 @@ def mock_identity_provider() -> IdentityProviderInterface:
return mock_provider


def identity_provider(request) -> IdentityProviderInterface:
if hasattr(request, "param"):
kwargs = request.param.get("idp_kwargs", {})
else:
kwargs = {}

if request.param.get("mock_idp", None) is not None:
return mock_identity_provider()

auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
config = get_identity_provider_config(request=request)

if auth_type == "MANAGED_IDENTITY":
return _create_provider_from_managed_identity(config)

return _create_provider_from_service_principal(config)


def get_identity_provider_config(
request,
) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]:
if hasattr(request, "param"):
kwargs = request.param.get("idp_kwargs", {})
else:
kwargs = {}

auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)

if auth_type == AuthType.MANAGED_IDENTITY:
return _get_managed_identity_provider_config(request)

return _get_service_principal_provider_config(request)


def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig:
resource = os.getenv("AZURE_RESOURCE")
id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None)

if hasattr(request, "param"):
kwargs = request.param.get("idp_kwargs", {})
else:
kwargs = {}

identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED)
id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID)

return ManagedIdentityProviderConfig(
identity_type=identity_type,
resource=resource,
id_type=id_type,
id_value=id_value,
kwargs=kwargs,
)


def _get_service_principal_provider_config(
request,
) -> ServicePrincipalIdentityProviderConfig:
client_id = os.getenv("AZURE_CLIENT_ID")
client_credential = os.getenv("AZURE_CLIENT_SECRET")
tenant_id = os.getenv("AZURE_TENANT_ID")
scopes = os.getenv("AZURE_REDIS_SCOPES", None)

if hasattr(request, "param"):
kwargs = request.param.get("idp_kwargs", {})
token_kwargs = request.param.get("token_kwargs", {})
timeout = request.param.get("timeout", None)
else:
kwargs = {}
token_kwargs = {}
timeout = None

if isinstance(scopes, str):
scopes = scopes.split(",")

return ServicePrincipalIdentityProviderConfig(
client_id=client_id,
client_credential=client_credential,
scopes=scopes,
timeout=timeout,
token_kwargs=token_kwargs,
tenant_id=tenant_id,
app_kwargs=kwargs,
)


def get_credential_provider(request) -> CredentialProvider:
cred_provider_class = request.param.get("cred_provider_class")
cred_provider_kwargs = request.param.get("cred_provider_kwargs", {})

if cred_provider_class != EntraIdCredentialsProvider:
if not cred_provider_class:
pytest.skip("No credential provider class specified in the test")

# Since we can't import EntraIdCredentialsProvider in this module,
# we'll just check the class name.
if cred_provider_class.__name__ != "EntraIdCredentialsProvider":
return cred_provider_class(**cred_provider_kwargs)

idp = identity_provider(request)
expiration_refresh_ratio = cred_provider_kwargs.get(
"expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO
)
lower_refresh_bound_millis = cred_provider_kwargs.get(
"lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS
)
max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS)
delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS)

token_mgr_config = TokenManagerConfig(
expiration_refresh_ratio=expiration_refresh_ratio,
lower_refresh_bound_millis=lower_refresh_bound_millis,
token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa
retry_policy=RetryPolicy(
max_attempts=max_attempts,
delay_in_ms=delay_in_ms,
),
)
from tests.entraid_utils import get_entra_id_credentials_provider
return get_entra_id_credentials_provider(request, cred_provider_kwargs)

return EntraIdCredentialsProvider(
identity_provider=idp,
token_manager_config=token_mgr_config,
initial_delay_in_ms=delay_in_ms,
)


@pytest.fixture()
Expand Down
143 changes: 143 additions & 0 deletions tests/entraid_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import os
from enum import Enum
from typing import Union

from redis_entraid.cred_provider import (
DEFAULT_DELAY_IN_MS,
DEFAULT_EXPIRATION_REFRESH_RATIO,
DEFAULT_LOWER_REFRESH_BOUND_MILLIS,
DEFAULT_MAX_ATTEMPTS,
DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS,
EntraIdCredentialsProvider,
)
from redis_entraid.identity_provider import (
ManagedIdentityIdType,
ManagedIdentityProviderConfig,
ManagedIdentityType,
ServicePrincipalIdentityProviderConfig,
_create_provider_from_managed_identity,
_create_provider_from_service_principal,
)

from redis.auth.idp import IdentityProviderInterface
from redis.auth.token_manager import TokenManagerConfig, RetryPolicy
from tests.conftest import mock_identity_provider


class AuthType(Enum):
MANAGED_IDENTITY = "managed_identity"
SERVICE_PRINCIPAL = "service_principal"




def identity_provider(request) -> IdentityProviderInterface:
if hasattr(request, "param"):
kwargs = request.param.get("idp_kwargs", {})
else:
kwargs = {}

if request.param.get("mock_idp", None) is not None:
return mock_identity_provider()

auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
config = get_identity_provider_config(request=request)

if auth_type == "MANAGED_IDENTITY":
return _create_provider_from_managed_identity(config)

return _create_provider_from_service_principal(config)


def get_identity_provider_config(
request,
) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]:
if hasattr(request, "param"):
kwargs = request.param.get("idp_kwargs", {})
else:
kwargs = {}

auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)

if auth_type == AuthType.MANAGED_IDENTITY:
return _get_managed_identity_provider_config(request)

return _get_service_principal_provider_config(request)


def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig:
resource = os.getenv("AZURE_RESOURCE")
id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None)

if hasattr(request, "param"):
kwargs = request.param.get("idp_kwargs", {})
else:
kwargs = {}

identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED)
id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID)

return ManagedIdentityProviderConfig(
identity_type=identity_type,
resource=resource,
id_type=id_type,
id_value=id_value,
kwargs=kwargs,
)


def _get_service_principal_provider_config(
request,
) -> ServicePrincipalIdentityProviderConfig:
client_id = os.getenv("AZURE_CLIENT_ID")
client_credential = os.getenv("AZURE_CLIENT_SECRET")
tenant_id = os.getenv("AZURE_TENANT_ID")
scopes = os.getenv("AZURE_REDIS_SCOPES", None)

if hasattr(request, "param"):
kwargs = request.param.get("idp_kwargs", {})
token_kwargs = request.param.get("token_kwargs", {})
timeout = request.param.get("timeout", None)
else:
kwargs = {}
token_kwargs = {}
timeout = None

if isinstance(scopes, str):
scopes = scopes.split(",")

return ServicePrincipalIdentityProviderConfig(
client_id=client_id,
client_credential=client_credential,
scopes=scopes,
timeout=timeout,
token_kwargs=token_kwargs,
tenant_id=tenant_id,
app_kwargs=kwargs,
)


def get_entra_id_credentials_provider(request, cred_provider_kwargs):
idp = identity_provider(request)
expiration_refresh_ratio = cred_provider_kwargs.get(
"expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO
)
lower_refresh_bound_millis = cred_provider_kwargs.get(
"lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS
)
max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS)
delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS)
token_mgr_config = TokenManagerConfig(
expiration_refresh_ratio=expiration_refresh_ratio,
lower_refresh_bound_millis=lower_refresh_bound_millis,
token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa
retry_policy=RetryPolicy(
max_attempts=max_attempts,
delay_in_ms=delay_in_ms,
),
)
return EntraIdCredentialsProvider(
identity_provider=idp,
token_manager_config=token_mgr_config,
initial_delay_in_ms=delay_in_ms,
)
Loading

0 comments on commit c19d8a4

Please sign in to comment.