|
4 | 4 | import random |
5 | 5 | import time |
6 | 6 | from datetime import datetime, timezone |
7 | | -from enum import Enum |
8 | | -from typing import Callable, TypeVar, Union |
| 7 | +from typing import Callable, TypeVar |
9 | 8 | from unittest import mock |
10 | 9 | from unittest.mock import Mock |
11 | 10 | from urllib.parse import urlparse |
|
17 | 16 | from redis import Sentinel |
18 | 17 | from redis.auth.idp import IdentityProviderInterface |
19 | 18 | from redis.auth.token import JWToken |
20 | | -from redis.auth.token_manager import RetryPolicy, TokenManagerConfig |
21 | 19 | from redis.backoff import NoBackoff |
22 | 20 | from redis.cache import ( |
23 | 21 | CacheConfig, |
|
30 | 28 | from redis.credentials import CredentialProvider |
31 | 29 | from redis.exceptions import RedisClusterException |
32 | 30 | from redis.retry import Retry |
33 | | -from redis_entraid.cred_provider import ( |
34 | | - DEFAULT_DELAY_IN_MS, |
35 | | - DEFAULT_EXPIRATION_REFRESH_RATIO, |
36 | | - DEFAULT_LOWER_REFRESH_BOUND_MILLIS, |
37 | | - DEFAULT_MAX_ATTEMPTS, |
38 | | - DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, |
39 | | - EntraIdCredentialsProvider, |
40 | | -) |
41 | | -from redis_entraid.identity_provider import ( |
42 | | - ManagedIdentityIdType, |
43 | | - ManagedIdentityProviderConfig, |
44 | | - ManagedIdentityType, |
45 | | - ServicePrincipalIdentityProviderConfig, |
46 | | - _create_provider_from_managed_identity, |
47 | | - _create_provider_from_service_principal, |
48 | | -) |
49 | 31 | from tests.ssl_utils import get_tls_certificates |
50 | 32 |
|
51 | 33 | REDIS_INFO = {} |
|
61 | 43 | _TestDecorator = Callable[[_DecoratedTest], _DecoratedTest] |
62 | 44 |
|
63 | 45 |
|
64 | | -class AuthType(Enum): |
65 | | - MANAGED_IDENTITY = "managed_identity" |
66 | | - SERVICE_PRINCIPAL = "service_principal" |
67 | | - |
68 | | - |
69 | 46 | # Taken from python3.9 |
70 | 47 | class BooleanOptionalAction(argparse.Action): |
71 | 48 | def __init__( |
@@ -623,124 +600,21 @@ def mock_identity_provider() -> IdentityProviderInterface: |
623 | 600 | return mock_provider |
624 | 601 |
|
625 | 602 |
|
626 | | -def identity_provider(request) -> IdentityProviderInterface: |
627 | | - if hasattr(request, "param"): |
628 | | - kwargs = request.param.get("idp_kwargs", {}) |
629 | | - else: |
630 | | - kwargs = {} |
631 | | - |
632 | | - if request.param.get("mock_idp", None) is not None: |
633 | | - return mock_identity_provider() |
634 | | - |
635 | | - auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) |
636 | | - config = get_identity_provider_config(request=request) |
637 | | - |
638 | | - if auth_type == "MANAGED_IDENTITY": |
639 | | - return _create_provider_from_managed_identity(config) |
640 | | - |
641 | | - return _create_provider_from_service_principal(config) |
642 | | - |
643 | | - |
644 | | -def get_identity_provider_config( |
645 | | - request, |
646 | | -) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: |
647 | | - if hasattr(request, "param"): |
648 | | - kwargs = request.param.get("idp_kwargs", {}) |
649 | | - else: |
650 | | - kwargs = {} |
651 | | - |
652 | | - auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) |
653 | | - |
654 | | - if auth_type == AuthType.MANAGED_IDENTITY: |
655 | | - return _get_managed_identity_provider_config(request) |
656 | | - |
657 | | - return _get_service_principal_provider_config(request) |
658 | | - |
659 | | - |
660 | | -def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig: |
661 | | - resource = os.getenv("AZURE_RESOURCE") |
662 | | - id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None) |
663 | | - |
664 | | - if hasattr(request, "param"): |
665 | | - kwargs = request.param.get("idp_kwargs", {}) |
666 | | - else: |
667 | | - kwargs = {} |
668 | | - |
669 | | - identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED) |
670 | | - id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID) |
671 | | - |
672 | | - return ManagedIdentityProviderConfig( |
673 | | - identity_type=identity_type, |
674 | | - resource=resource, |
675 | | - id_type=id_type, |
676 | | - id_value=id_value, |
677 | | - kwargs=kwargs, |
678 | | - ) |
679 | | - |
680 | | - |
681 | | -def _get_service_principal_provider_config( |
682 | | - request, |
683 | | -) -> ServicePrincipalIdentityProviderConfig: |
684 | | - client_id = os.getenv("AZURE_CLIENT_ID") |
685 | | - client_credential = os.getenv("AZURE_CLIENT_SECRET") |
686 | | - tenant_id = os.getenv("AZURE_TENANT_ID") |
687 | | - scopes = os.getenv("AZURE_REDIS_SCOPES", None) |
688 | | - |
689 | | - if hasattr(request, "param"): |
690 | | - kwargs = request.param.get("idp_kwargs", {}) |
691 | | - token_kwargs = request.param.get("token_kwargs", {}) |
692 | | - timeout = request.param.get("timeout", None) |
693 | | - else: |
694 | | - kwargs = {} |
695 | | - token_kwargs = {} |
696 | | - timeout = None |
697 | | - |
698 | | - if isinstance(scopes, str): |
699 | | - scopes = scopes.split(",") |
700 | | - |
701 | | - return ServicePrincipalIdentityProviderConfig( |
702 | | - client_id=client_id, |
703 | | - client_credential=client_credential, |
704 | | - scopes=scopes, |
705 | | - timeout=timeout, |
706 | | - token_kwargs=token_kwargs, |
707 | | - tenant_id=tenant_id, |
708 | | - app_kwargs=kwargs, |
709 | | - ) |
710 | | - |
711 | | - |
712 | 603 | def get_credential_provider(request) -> CredentialProvider: |
713 | 604 | cred_provider_class = request.param.get("cred_provider_class") |
714 | 605 | cred_provider_kwargs = request.param.get("cred_provider_kwargs", {}) |
715 | 606 |
|
716 | | - if cred_provider_class != EntraIdCredentialsProvider: |
| 607 | + if not cred_provider_class: |
| 608 | + pytest.skip("No credential provider class specified in the test") |
| 609 | + |
| 610 | + # Since we can't import EntraIdCredentialsProvider in this module, |
| 611 | + # we'll just check the class name. |
| 612 | + if cred_provider_class.__name__ != "EntraIdCredentialsProvider": |
717 | 613 | return cred_provider_class(**cred_provider_kwargs) |
718 | 614 |
|
719 | | - idp = identity_provider(request) |
720 | | - expiration_refresh_ratio = cred_provider_kwargs.get( |
721 | | - "expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO |
722 | | - ) |
723 | | - lower_refresh_bound_millis = cred_provider_kwargs.get( |
724 | | - "lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS |
725 | | - ) |
726 | | - max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS) |
727 | | - delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS) |
728 | | - |
729 | | - token_mgr_config = TokenManagerConfig( |
730 | | - expiration_refresh_ratio=expiration_refresh_ratio, |
731 | | - lower_refresh_bound_millis=lower_refresh_bound_millis, |
732 | | - token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa |
733 | | - retry_policy=RetryPolicy( |
734 | | - max_attempts=max_attempts, |
735 | | - delay_in_ms=delay_in_ms, |
736 | | - ), |
737 | | - ) |
| 615 | + from tests.entraid_utils import get_entra_id_credentials_provider |
738 | 616 |
|
739 | | - return EntraIdCredentialsProvider( |
740 | | - identity_provider=idp, |
741 | | - token_manager_config=token_mgr_config, |
742 | | - initial_delay_in_ms=delay_in_ms, |
743 | | - ) |
| 617 | + return get_entra_id_credentials_provider(request, cred_provider_kwargs) |
744 | 618 |
|
745 | 619 |
|
746 | 620 | @pytest.fixture() |
|
0 commit comments