Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,31 @@
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
from typing import TypeVar, Any, MutableMapping, cast, Optional
import logging
import threading
from typing import TypeVar, Any, MutableMapping, cast, Optional, Union
from weakref import WeakKeyDictionary

from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
from azure.core.pipeline.transport import HttpRequest as LegacyHttpRequest
from azure.core.rest import HttpRequest
from azure.core.credentials import AccessToken
from azure.core.credentials import AccessToken, TokenCredential, SupportsTokenInfo
from azure.core.exceptions import HttpResponseError

from .http_constants import HttpHeaders
from ._constants import _Constants as Constants

HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)
logger = logging.getLogger("azure.cosmos.CosmosBearerTokenCredentialPolicy")
_credential_locks: "WeakKeyDictionary[Union[TokenCredential, SupportsTokenInfo], threading.RLock]" = WeakKeyDictionary()

def _get_credential_lock(credential: Union[TokenCredential, SupportsTokenInfo]) -> threading.RLock:
lock = _credential_locks.get(credential)
if lock is None:
lock = threading.RLock()
_credential_locks[credential] = lock
return lock

# NOTE: This class accesses protected members (_scopes, _token) of the parent class
# to implement fallback and scope-switching logic not exposed by the public API.
Expand All @@ -24,11 +36,25 @@
class CosmosBearerTokenCredentialPolicy(BearerTokenCredentialPolicy):
AadDefaultScope = Constants.AAD_DEFAULT_SCOPE

def __init__(self, credential, account_scope: str, override_scope: Optional[str] = None):
def __init__(
self,
credential: Union[TokenCredential, SupportsTokenInfo],
account_scope: str,
override_scope: Optional[str] = None
) -> None:
self._account_scope = account_scope
self._override_scope = override_scope
self._current_scope = override_scope or account_scope
self._credential_lock = _get_credential_lock(credential)
super().__init__(credential, self._current_scope)
# initialize the cache by requesting a token for the current scope (thread-safe)
with self._credential_lock:
try:
self._get_token(self._current_scope)
# don't fail on filling up the cache
except Exception: #pylint: disable=broad-exception-caught
logger.warning("Failed to acquire initial token for scope '%s'. Cache was not populated.",
self._current_scope, exc_info=True)

@staticmethod
def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
Expand Down Expand Up @@ -61,6 +87,11 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
self._current_scope != self.AadDefaultScope and
"AADSTS500011" in str(ex)
):
logger.warning(
"Received AADSTS500011 error when using scope '%s'. Falling back to default scope '%s'.",
self._current_scope,
self.AadDefaultScope
)
self._scopes = (self.AadDefaultScope,)
self._current_scope = self.AadDefaultScope
tried_fallback = True
Expand Down
37 changes: 34 additions & 3 deletions sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import asyncio # pylint: disable=do-not-import-asyncio
import logging
from typing import Any, MutableMapping, TypeVar, cast, Optional, Union
from weakref import WeakKeyDictionary

from typing import Any, MutableMapping, TypeVar, cast, Optional

from azure.core.credentials_async import AsyncTokenCredential, AsyncSupportsTokenInfo
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.transport import HttpRequest as LegacyHttpRequest
Expand All @@ -17,6 +20,16 @@
from .._constants import _Constants as Constants

HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)
logger = logging.getLogger("azure.cosmos.AsyncCosmosBearerTokenCredentialPolicy")
_credential_locks: "WeakKeyDictionary[Union[AsyncTokenCredential, AsyncSupportsTokenInfo], asyncio.Lock]" = (
WeakKeyDictionary())

def _get_credential_lock(credential: Union[AsyncTokenCredential, AsyncSupportsTokenInfo]) -> asyncio.Lock:
lock = _credential_locks.get(credential)
if lock is None:
lock = asyncio.Lock()
_credential_locks[credential] = lock
return lock

# NOTE: This class accesses protected members (_scopes, _token) of the parent class
# to implement fallback and scope-switching logic not exposed by the public API.
Expand All @@ -25,12 +38,25 @@
class AsyncCosmosBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy):
AadDefaultScope = Constants.AAD_DEFAULT_SCOPE

def __init__(self, credential, account_scope: str, override_scope: Optional[str] = None):
def __init__(self, credential: AsyncTokenCredential, account_scope: str, override_scope: Optional[str] = None):
self._account_scope = account_scope
self._override_scope = override_scope
self._current_scope = override_scope or account_scope
self._credential_lock = _get_credential_lock(credential)
super().__init__(credential, self._current_scope)

async def setup(self) -> None:
# have to also support get_token info
self._credential_lock = _get_credential_lock(self._credential)
# initialize the cache by requesting a token for the current scope (thread-safe)
async with self._credential_lock:
try:
await self._get_token()
except Exception: #pylint: disable=broad-exception-caught
logger.warning("Failed to acquire initial token for scope '%s'. Cache was not populated.",
self._current_scope, exc_info=True)


@staticmethod
def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
"""Updates the Authorization header with the bearer token.
Expand Down Expand Up @@ -62,6 +88,11 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
self._current_scope != self.AadDefaultScope and
"AADSTS500011" in str(ex)
):
logger.warning(
"Received AADSTS500011 error when using scope '%s'. Falling back to default scope '%s'.",
self._current_scope,
self.AadDefaultScope
)
self._scopes = (self.AadDefaultScope,)
self._current_scope = self.AadDefaultScope
tried_fallback = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,11 @@ def __init__( # pylint: disable=too-many-statements
suffix = kwargs.pop('user_agent_suffix', None)
self._user_agent = _utils.get_user_agent_async(suffix)

credentials_policy = None
self.credentials_policy: Optional[AsyncCosmosBearerTokenCredentialPolicy] = None
if self.aad_credentials:
scope_override = os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "")
account_scope = base.create_scope_from_url(self.url_connection)
credentials_policy = AsyncCosmosBearerTokenCredentialPolicy(
self.credentials_policy = AsyncCosmosBearerTokenCredentialPolicy(
self.aad_credentials,
account_scope,
scope_override
Expand All @@ -221,7 +221,7 @@ def __init__( # pylint: disable=too-many-statements
UserAgentPolicy(base_user_agent=self._user_agent, **kwargs),
ContentDecodePolicy(),
retry_policy,
credentials_policy,
self.credentials_policy,
CustomHookPolicy(**kwargs),
NetworkTraceLoggingPolicy(**kwargs),
DistributedTracingPolicy(**kwargs),
Expand Down Expand Up @@ -314,6 +314,9 @@ def _ReadEndpoint(self) -> str:
return self._global_endpoint_manager.get_read_endpoint()

async def _setup(self) -> None:
if self.credentials_policy:
await self.credentials_policy.setup()

if 'database_account' not in self._setup_kwargs:
database_account, _ = await self._global_endpoint_manager._GetDatabaseAccount(
**self._setup_kwargs
Expand Down
2 changes: 1 addition & 1 deletion sdk/cosmos/azure-cosmos/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
packages=find_packages(exclude=exclude_packages),
python_requires=">=3.9",
install_requires=[
"azure-core>=1.30.0",
"azure-core>=1.31.0",
"typing-extensions>=4.6.0"
],
)
40 changes: 37 additions & 3 deletions sdk/cosmos/azure-cosmos/tests/test_aad.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import time
import unittest
from concurrent.futures import ThreadPoolExecutor, as_completed
from io import StringIO

import pytest
Expand Down Expand Up @@ -34,6 +35,11 @@ def get_test_item(num):


class CosmosEmulatorCredential(object):
def __init__(self):
self.token = None
# used to verify that get_token was called only once with concurrent clients
self.counter = 0

def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
"""Request an access token for the emulator. Based on Azure Core's Access Token Credential.
Expand All @@ -47,6 +53,9 @@ def get_token(self, *scopes, **kwargs):
:raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
attribute gives a reason.
"""
if self.token:
return self.token

aad_header_cosmos_emulator = "{\"typ\":\"JWT\",\"alg\":\"RS256\",\"x5t\":\"" \
"CosmosEmulatorPrimaryMaster\",\"kid\":\"CosmosEmulatorPrimaryMaster\"}"

Expand Down Expand Up @@ -82,7 +91,10 @@ def get_token(self, *scopes, **kwargs):
emulator_key_encoded_padded = str(emulator_key_encoded_bytes, "utf-8")
emulator_key_encoded = _remove_padding(emulator_key_encoded_padded)

return AccessToken(first_encoded + "." + second_encoded + "." + emulator_key_encoded, int(time.time() + 7200))
self.counter += 1
# cache token
self.token = AccessToken(first_encoded + "." + second_encoded + "." + emulator_key_encoded, int(time.time() + 7200))
return self.token


@pytest.mark.cosmosEmulator
Expand Down Expand Up @@ -176,10 +188,31 @@ def action(scopes_captured):

scopes, _ = self._run_with_scope_capture(FailingCredential, action)
try:
assert scopes == [override_scope], f"Expected only override scope, got: {scopes}"
assert scopes == [override_scope, override_scope], f"Expected only override scope, got: {scopes}"
finally:
del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"]

def test_client_warmup(self):
"""When multiple clients are created concurrently, only one token request occurs."""
os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = ""

credential = CosmosEmulatorCredential()
def make_client():
client = cosmos_client.CosmosClient(self.host, credential)
return client

clients = []
with ThreadPoolExecutor(max_workers=100) as ex:
futures = [ex.submit(make_client) for _ in range(100)]
for f in as_completed(futures):
try:
clients.append(f.result())
except Exception:
pass

assert credential.counter == 1, f"Expected only one token request, got {credential.counter}"
del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"]

def test_account_scope_only(self):
"""When account scope is provided, only that scope is used."""
account_scope = "https://localhost/.default"
Expand Down Expand Up @@ -212,10 +245,11 @@ def test_account_scope_fallback_on_error(self):
class FallbackCredential(CosmosEmulatorCredential):
def __init__(self):
self.call_count = 0
super().__init__()

def get_token(self, *scopes, **kwargs):
self.call_count += 1
if self.call_count == 1:
if self.call_count <= 2:
raise HttpResponseError(message="AADSTS500011: Simulated error for fallback")
return super().get_token(*scopes, **kwargs)

Expand Down
39 changes: 36 additions & 3 deletions sdk/cosmos/azure-cosmos/tests/test_aad_async.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# The MIT License (MIT)
# Copyright (c) Microsoft Corporation. All rights reserved.

import asyncio
import base64
import json
import time
Expand Down Expand Up @@ -34,6 +34,11 @@ def get_test_item(num):


class CosmosEmulatorCredential(object):
def __init__(self):
self.token = None
# used to verify that get_token was called only once with concurrent clients
self.counter = 0

async def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
"""Request an access token for the emulator. Based on Azure Core's Access Token Credential.
Expand All @@ -47,6 +52,9 @@ async def get_token(self, *scopes, **kwargs):
:raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
attribute gives a reason.
"""
if self.token:
return self.token
await asyncio.sleep(.02)
aad_header_cosmos_emulator = "{\"typ\":\"JWT\",\"alg\":\"RS256\",\"x5t\":\"" \
"CosmosEmulatorPrimaryMaster\",\"kid\":\"CosmosEmulatorPrimaryMaster\"}"

Expand Down Expand Up @@ -82,7 +90,10 @@ async def get_token(self, *scopes, **kwargs):
emulator_key_encoded_padded = str(emulator_key_encoded_bytes, "utf-8")
emulator_key_encoded = _remove_padding(emulator_key_encoded_padded)

return AccessToken(first_encoded + "." + second_encoded + "." + emulator_key_encoded, int(time.time() + 7200))
# cache token in credential
self.token = AccessToken(first_encoded + "." + second_encoded + "." + emulator_key_encoded, int(time.time() + 7200))
self.counter += 1
return self.token


@pytest.mark.cosmosEmulator
Expand Down Expand Up @@ -172,6 +183,26 @@ async def action(scopes_captured):
except Exception:
pass

async def test_client_warmup_async(self):
"""When multiple clients are created concurrently, only one token request occurs."""
os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = ""

credential = CosmosEmulatorCredential()

async def make_client():
client = CosmosClient(self.host, credential)
await client.__aenter__()
await client.close()

# use asyncio to call make_client concurrently
import asyncio
tasks = [make_client() for _ in range(100)]
await asyncio.gather(*tasks)


assert credential.counter == 1, f"Expected only one token request, got {credential.counter}"
del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"]

async def test_override_scope_no_fallback_on_error_async(self):
"""When override scope is provided and auth fails, no fallback occurs."""
override_scope = "https://my.custom.scope/.default"
Expand Down Expand Up @@ -239,16 +270,18 @@ async def test_account_scope_fallback_on_error_async(self):
class FallbackCredential(CosmosEmulatorCredential):
def __init__(self):
self.call_count = 0
super().__init__()

async def get_token(self, *scopes, **kwargs):
self.call_count += 1
if self.call_count == 1:
if self.call_count <= 2:
raise HttpResponseError(message="AADSTS500011: Simulated error for fallback")
return await super().get_token(*scopes, **kwargs)

async def action(scopes_captured):
credential = FallbackCredential()
client = CosmosClient(self.host, credential)
await client.__aenter__()
try:
db = client.get_database_client(self.configs.TEST_DATABASE_ID)
container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)
Expand Down
Loading