diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py index 83418e1f375d..ede27a0b4fad 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py @@ -3,19 +3,31 @@ # Licensed under the MIT License. See LICENSE.txt in the project root for # license information. # ------------------------------------------------------------------------- -from typing import TypeVar, Any, MutableMapping, cast, Optional +import logging +import threading +from typing import TypeVar, Any, MutableMapping, cast, Optional, Union +from weakref import WeakKeyDictionary from azure.core.pipeline import PipelineRequest from azure.core.pipeline.policies import BearerTokenCredentialPolicy from azure.core.pipeline.transport import HttpRequest as LegacyHttpRequest from azure.core.rest import HttpRequest -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, TokenCredential, SupportsTokenInfo from azure.core.exceptions import HttpResponseError from .http_constants import HttpHeaders from ._constants import _Constants as Constants HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) +logger = logging.getLogger("azure.cosmos.CosmosBearerTokenCredentialPolicy") +_credential_locks: "WeakKeyDictionary[Union[TokenCredential, SupportsTokenInfo], threading.RLock]" = WeakKeyDictionary() + +def _get_credential_lock(credential: Union[TokenCredential, SupportsTokenInfo]) -> threading.RLock: + lock = _credential_locks.get(credential) + if lock is None: + lock = threading.RLock() + _credential_locks[credential] = lock + return lock # NOTE: This class accesses protected members (_scopes, _token) of the parent class # to implement fallback and scope-switching logic not exposed by the public API. @@ -24,11 +36,25 @@ class CosmosBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): AadDefaultScope = Constants.AAD_DEFAULT_SCOPE - def __init__(self, credential, account_scope: str, override_scope: Optional[str] = None): + def __init__( + self, + credential: Union[TokenCredential, SupportsTokenInfo], + account_scope: str, + override_scope: Optional[str] = None + ) -> None: self._account_scope = account_scope self._override_scope = override_scope self._current_scope = override_scope or account_scope + self._credential_lock = _get_credential_lock(credential) super().__init__(credential, self._current_scope) + # initialize the cache by requesting a token for the current scope (thread-safe) + with self._credential_lock: + try: + self._get_token(self._current_scope) + # don't fail on filling up the cache + except Exception: #pylint: disable=broad-exception-caught + logger.warning("Failed to acquire initial token for scope '%s'. Cache was not populated.", + self._current_scope, exc_info=True) @staticmethod def _update_headers(headers: MutableMapping[str, str], token: str) -> None: @@ -61,6 +87,11 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: self._current_scope != self.AadDefaultScope and "AADSTS500011" in str(ex) ): + logger.warning( + "Received AADSTS500011 error when using scope '%s'. Falling back to default scope '%s'.", + self._current_scope, + self.AadDefaultScope + ) self._scopes = (self.AadDefaultScope,) self._current_scope = self.AadDefaultScope tried_fallback = True diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py index ea1a86b120a1..708bd5bc5301 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py @@ -3,9 +3,12 @@ # Licensed under the MIT License. See LICENSE.txt in the project root for # license information. # ------------------------------------------------------------------------- +import asyncio # pylint: disable=do-not-import-asyncio +import logging +from typing import Any, MutableMapping, TypeVar, cast, Optional, Union +from weakref import WeakKeyDictionary -from typing import Any, MutableMapping, TypeVar, cast, Optional - +from azure.core.credentials_async import AsyncTokenCredential, AsyncSupportsTokenInfo from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy from azure.core.pipeline import PipelineRequest from azure.core.pipeline.transport import HttpRequest as LegacyHttpRequest @@ -17,6 +20,16 @@ from .._constants import _Constants as Constants HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) +logger = logging.getLogger("azure.cosmos.AsyncCosmosBearerTokenCredentialPolicy") +_credential_locks: "WeakKeyDictionary[Union[AsyncTokenCredential, AsyncSupportsTokenInfo], asyncio.Lock]" = ( + WeakKeyDictionary()) + +def _get_credential_lock(credential: Union[AsyncTokenCredential, AsyncSupportsTokenInfo]) -> asyncio.Lock: + lock = _credential_locks.get(credential) + if lock is None: + lock = asyncio.Lock() + _credential_locks[credential] = lock + return lock # NOTE: This class accesses protected members (_scopes, _token) of the parent class # to implement fallback and scope-switching logic not exposed by the public API. @@ -25,12 +38,25 @@ class AsyncCosmosBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): AadDefaultScope = Constants.AAD_DEFAULT_SCOPE - def __init__(self, credential, account_scope: str, override_scope: Optional[str] = None): + def __init__(self, credential: AsyncTokenCredential, account_scope: str, override_scope: Optional[str] = None): self._account_scope = account_scope self._override_scope = override_scope self._current_scope = override_scope or account_scope + self._credential_lock = _get_credential_lock(credential) super().__init__(credential, self._current_scope) + async def setup(self) -> None: + # have to also support get_token info + self._credential_lock = _get_credential_lock(self._credential) + # initialize the cache by requesting a token for the current scope (thread-safe) + async with self._credential_lock: + try: + await self._get_token() + except Exception: #pylint: disable=broad-exception-caught + logger.warning("Failed to acquire initial token for scope '%s'. Cache was not populated.", + self._current_scope, exc_info=True) + + @staticmethod def _update_headers(headers: MutableMapping[str, str], token: str) -> None: """Updates the Authorization header with the bearer token. @@ -62,6 +88,11 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: self._current_scope != self.AadDefaultScope and "AADSTS500011" in str(ex) ): + logger.warning( + "Received AADSTS500011 error when using scope '%s'. Falling back to default scope '%s'.", + self._current_scope, + self.AadDefaultScope + ) self._scopes = (self.AadDefaultScope,) self._current_scope = self.AadDefaultScope tried_fallback = True diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 0035af9ee9df..4f8ce4f35484 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -205,11 +205,11 @@ def __init__( # pylint: disable=too-many-statements suffix = kwargs.pop('user_agent_suffix', None) self._user_agent = _utils.get_user_agent_async(suffix) - credentials_policy = None + self.credentials_policy: Optional[AsyncCosmosBearerTokenCredentialPolicy] = None if self.aad_credentials: scope_override = os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "") account_scope = base.create_scope_from_url(self.url_connection) - credentials_policy = AsyncCosmosBearerTokenCredentialPolicy( + self.credentials_policy = AsyncCosmosBearerTokenCredentialPolicy( self.aad_credentials, account_scope, scope_override @@ -221,7 +221,7 @@ def __init__( # pylint: disable=too-many-statements UserAgentPolicy(base_user_agent=self._user_agent, **kwargs), ContentDecodePolicy(), retry_policy, - credentials_policy, + self.credentials_policy, CustomHookPolicy(**kwargs), NetworkTraceLoggingPolicy(**kwargs), DistributedTracingPolicy(**kwargs), @@ -314,6 +314,9 @@ def _ReadEndpoint(self) -> str: return self._global_endpoint_manager.get_read_endpoint() async def _setup(self) -> None: + if self.credentials_policy: + await self.credentials_policy.setup() + if 'database_account' not in self._setup_kwargs: database_account, _ = await self._global_endpoint_manager._GetDatabaseAccount( **self._setup_kwargs diff --git a/sdk/cosmos/azure-cosmos/setup.py b/sdk/cosmos/azure-cosmos/setup.py index 1a96740eac2c..95ae4f832929 100644 --- a/sdk/cosmos/azure-cosmos/setup.py +++ b/sdk/cosmos/azure-cosmos/setup.py @@ -73,7 +73,7 @@ packages=find_packages(exclude=exclude_packages), python_requires=">=3.9", install_requires=[ - "azure-core>=1.30.0", + "azure-core>=1.31.0", "typing-extensions>=4.6.0" ], ) diff --git a/sdk/cosmos/azure-cosmos/tests/test_aad.py b/sdk/cosmos/azure-cosmos/tests/test_aad.py index b1d593bd96a6..dc694d5b78c7 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_aad.py +++ b/sdk/cosmos/azure-cosmos/tests/test_aad.py @@ -6,6 +6,7 @@ import os import time import unittest +from concurrent.futures import ThreadPoolExecutor, as_completed from io import StringIO import pytest @@ -34,6 +35,11 @@ def get_test_item(num): class CosmosEmulatorCredential(object): + def __init__(self): + self.token = None + # used to verify that get_token was called only once with concurrent clients + self.counter = 0 + def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken """Request an access token for the emulator. Based on Azure Core's Access Token Credential. @@ -47,6 +53,9 @@ def get_token(self, *scopes, **kwargs): :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` attribute gives a reason. """ + if self.token: + return self.token + aad_header_cosmos_emulator = "{\"typ\":\"JWT\",\"alg\":\"RS256\",\"x5t\":\"" \ "CosmosEmulatorPrimaryMaster\",\"kid\":\"CosmosEmulatorPrimaryMaster\"}" @@ -82,7 +91,10 @@ def get_token(self, *scopes, **kwargs): emulator_key_encoded_padded = str(emulator_key_encoded_bytes, "utf-8") emulator_key_encoded = _remove_padding(emulator_key_encoded_padded) - return AccessToken(first_encoded + "." + second_encoded + "." + emulator_key_encoded, int(time.time() + 7200)) + self.counter += 1 + # cache token + self.token = AccessToken(first_encoded + "." + second_encoded + "." + emulator_key_encoded, int(time.time() + 7200)) + return self.token @pytest.mark.cosmosEmulator @@ -176,10 +188,31 @@ def action(scopes_captured): scopes, _ = self._run_with_scope_capture(FailingCredential, action) try: - assert scopes == [override_scope], f"Expected only override scope, got: {scopes}" + assert scopes == [override_scope, override_scope], f"Expected only override scope, got: {scopes}" finally: del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] + def test_client_warmup(self): + """When multiple clients are created concurrently, only one token request occurs.""" + os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = "" + + credential = CosmosEmulatorCredential() + def make_client(): + client = cosmos_client.CosmosClient(self.host, credential) + return client + + clients = [] + with ThreadPoolExecutor(max_workers=100) as ex: + futures = [ex.submit(make_client) for _ in range(100)] + for f in as_completed(futures): + try: + clients.append(f.result()) + except Exception: + pass + + assert credential.counter == 1, f"Expected only one token request, got {credential.counter}" + del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] + def test_account_scope_only(self): """When account scope is provided, only that scope is used.""" account_scope = "https://localhost/.default" @@ -212,10 +245,11 @@ def test_account_scope_fallback_on_error(self): class FallbackCredential(CosmosEmulatorCredential): def __init__(self): self.call_count = 0 + super().__init__() def get_token(self, *scopes, **kwargs): self.call_count += 1 - if self.call_count == 1: + if self.call_count <= 2: raise HttpResponseError(message="AADSTS500011: Simulated error for fallback") return super().get_token(*scopes, **kwargs) diff --git a/sdk/cosmos/azure-cosmos/tests/test_aad_async.py b/sdk/cosmos/azure-cosmos/tests/test_aad_async.py index 6ce3cb4d1124..0161914b78dd 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_aad_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_aad_async.py @@ -1,6 +1,6 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. - +import asyncio import base64 import json import time @@ -34,6 +34,11 @@ def get_test_item(num): class CosmosEmulatorCredential(object): + def __init__(self): + self.token = None + # used to verify that get_token was called only once with concurrent clients + self.counter = 0 + async def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken """Request an access token for the emulator. Based on Azure Core's Access Token Credential. @@ -47,6 +52,9 @@ async def get_token(self, *scopes, **kwargs): :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` attribute gives a reason. """ + if self.token: + return self.token + await asyncio.sleep(.02) aad_header_cosmos_emulator = "{\"typ\":\"JWT\",\"alg\":\"RS256\",\"x5t\":\"" \ "CosmosEmulatorPrimaryMaster\",\"kid\":\"CosmosEmulatorPrimaryMaster\"}" @@ -82,7 +90,10 @@ async def get_token(self, *scopes, **kwargs): emulator_key_encoded_padded = str(emulator_key_encoded_bytes, "utf-8") emulator_key_encoded = _remove_padding(emulator_key_encoded_padded) - return AccessToken(first_encoded + "." + second_encoded + "." + emulator_key_encoded, int(time.time() + 7200)) + # cache token in credential + self.token = AccessToken(first_encoded + "." + second_encoded + "." + emulator_key_encoded, int(time.time() + 7200)) + self.counter += 1 + return self.token @pytest.mark.cosmosEmulator @@ -172,6 +183,26 @@ async def action(scopes_captured): except Exception: pass + async def test_client_warmup_async(self): + """When multiple clients are created concurrently, only one token request occurs.""" + os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = "" + + credential = CosmosEmulatorCredential() + + async def make_client(): + client = CosmosClient(self.host, credential) + await client.__aenter__() + await client.close() + + # use asyncio to call make_client concurrently + import asyncio + tasks = [make_client() for _ in range(100)] + await asyncio.gather(*tasks) + + + assert credential.counter == 1, f"Expected only one token request, got {credential.counter}" + del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] + async def test_override_scope_no_fallback_on_error_async(self): """When override scope is provided and auth fails, no fallback occurs.""" override_scope = "https://my.custom.scope/.default" @@ -239,16 +270,18 @@ async def test_account_scope_fallback_on_error_async(self): class FallbackCredential(CosmosEmulatorCredential): def __init__(self): self.call_count = 0 + super().__init__() async def get_token(self, *scopes, **kwargs): self.call_count += 1 - if self.call_count == 1: + if self.call_count <= 2: raise HttpResponseError(message="AADSTS500011: Simulated error for fallback") return await super().get_token(*scopes, **kwargs) async def action(scopes_captured): credential = FallbackCredential() client = CosmosClient(self.host, credential) + await client.__aenter__() try: db = client.get_database_client(self.configs.TEST_DATABASE_ID) container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)