diff --git a/httpx_auth/authentication.py b/httpx_auth/authentication.py index 6540a5f..a959f01 100644 --- a/httpx_auth/authentication.py +++ b/httpx_auth/authentication.py @@ -3,7 +3,7 @@ import uuid from hashlib import sha256, sha512 from urllib.parse import parse_qs, urlsplit, urlunsplit, urlencode -from typing import Optional, Generator +from typing import Optional, Generator, Union, AsyncGenerator import httpx @@ -59,11 +59,9 @@ def _get_query_parameter(url: str, param_name: str) -> Optional[str]: return all_values[0] if all_values else None -def request_new_grant_with_post( - url: str, data, grant_name: str, client: httpx.Client +def process_new_grant_response( + response: httpx.Response, grant_name: str, ) -> (str, int): - response = client.post(url, data=data) - if response.is_error: # As described in https://tools.ietf.org/html/rfc6749#section-5.2 raise InvalidGrantRequest(response) @@ -152,6 +150,8 @@ def __init__(self, token_url: str, username: str, password: str, **kwargs): reaches the actual server. Set it to 0 to deactivate this feature and use the same token until actual expiry. :param client: httpx.Client instance that will be used to request the token. Use it to provide a custom proxying rule for instance. + :param async_client: httpx.AsyncClient instance that will be used to request the token. + Use it to provide a custom proxying rule for instance. :param kwargs: all additional authorization parameters that should be put as body parameters in the token URL. """ self.token_url = token_url @@ -175,6 +175,7 @@ def __init__(self, token_url: str, username: str, password: str, **kwargs): # Time is expressed in seconds self.timeout = int(kwargs.pop("timeout", None) or 60) self.client = kwargs.pop("client", None) + self.async_client = kwargs.pop("async_client", None) # As described in https://tools.ietf.org/html/rfc6749#section-4.3.2 self.data = { @@ -190,7 +191,7 @@ def __init__(self, token_url: str, username: str, password: str, **kwargs): all_parameters_in_url = _add_parameters(self.token_url, self.data) self.state = sha512(all_parameters_in_url.encode("unicode_escape")).hexdigest() - def auth_flow( + def sync_auth_flow( self, request: httpx.Request ) -> Generator[httpx.Request, httpx.Response, None]: token = OAuth2.token_cache.get_token( @@ -201,13 +202,25 @@ def auth_flow( request.headers[self.header_name] = self.header_value.format(token=token) yield request + async def async_auth_flow( + self, request: httpx.Request + ) -> AsyncGenerator[httpx.Request, httpx.Response]: + token = await OAuth2.token_cache.get_token_async( + self.state, + early_expiry=self.early_expiry, + on_missing_token=self.request_new_token_async, + ) + request.headers[self.header_name] = self.header_value.format(token=token) + yield request + def request_new_token(self) -> tuple: client = self.client or httpx.Client() self._configure_client(client) try: + grant_response = client.post(self.token_url, data=self.data) # As described in https://tools.ietf.org/html/rfc6749#section-4.3.3 - token, expires_in = request_new_grant_with_post( - self.token_url, self.data, self.token_field_name, client + token, expires_in = process_new_grant_response( + grant_response, self.token_field_name ) finally: # Close client only if it was created by this module @@ -216,7 +229,23 @@ def request_new_token(self) -> tuple: # Handle both Access and Bearer tokens return (self.state, token, expires_in) if expires_in else (self.state, token) - def _configure_client(self, client: httpx.Client): + async def request_new_token_async(self) -> tuple: + client = self.async_client or httpx.AsyncClient() + self._configure_client(client) + try: + grant_response = await client.post(self.token_url, data=self.data) + # As described in https://tools.ietf.org/html/rfc6749#section-4.3.3 + token, expires_in = process_new_grant_response( + grant_response, self.token_field_name + ) + finally: + # Close client only if it was created by this module + if self.async_client is None: + await client.aclose() + # Handle both Access and Bearer tokens + return (self.state, token, expires_in) if expires_in else (self.state, token) + + def _configure_client(self, client: Union[httpx.Client, httpx.AsyncClient]): client.auth = (self.username, self.password) client.timeout = self.timeout @@ -248,6 +277,8 @@ def __init__(self, token_url: str, client_id: str, client_secret: str, **kwargs) reaches the actual server. Set it to 0 to deactivate this feature and use the same token until actual expiry. :param client: httpx.Client instance that will be used to request the token. Use it to provide a custom proxying rule for instance. + :param async_client: httpx.AsyncClient instance that will be used to request the token. + Use it to provide a custom proxying rule for instance. :param kwargs: all additional authorization parameters that should be put as query parameter in the token URL. """ self.token_url = token_url @@ -272,6 +303,7 @@ def __init__(self, token_url: str, client_id: str, client_secret: str, **kwargs) self.timeout = int(kwargs.pop("timeout", None) or 60) self.client = kwargs.pop("client", None) + self.async_client = kwargs.pop("async_client", None) # As described in https://tools.ietf.org/html/rfc6749#section-4.4.2 self.data = {"grant_type": "client_credentials"} @@ -283,7 +315,7 @@ def __init__(self, token_url: str, client_id: str, client_secret: str, **kwargs) all_parameters_in_url = _add_parameters(self.token_url, self.data) self.state = sha512(all_parameters_in_url.encode("unicode_escape")).hexdigest() - def auth_flow( + def sync_auth_flow( self, request: httpx.Request ) -> Generator[httpx.Request, httpx.Response, None]: token = OAuth2.token_cache.get_token( @@ -294,13 +326,25 @@ def auth_flow( request.headers[self.header_name] = self.header_value.format(token=token) yield request + async def async_auth_flow( + self, request: httpx.Request + ) -> Generator[httpx.Request, httpx.Response, None]: + token = await OAuth2.token_cache.get_token_async( + self.state, + early_expiry=self.early_expiry, + on_missing_token=self.request_new_token_async, + ) + request.headers[self.header_name] = self.header_value.format(token=token) + yield request + def request_new_token(self) -> tuple: client = self.client or httpx.Client() self._configure_client(client) try: + grant_response = client.post(self.token_url, data=self.data) # As described in https://tools.ietf.org/html/rfc6749#section-4.4.3 - token, expires_in = request_new_grant_with_post( - self.token_url, self.data, self.token_field_name, client + token, expires_in = process_new_grant_response( + grant_response, self.token_field_name ) finally: # Close client only if it was created by this module @@ -309,7 +353,23 @@ def request_new_token(self) -> tuple: # Handle both Access and Bearer tokens return (self.state, token, expires_in) if expires_in else (self.state, token) - def _configure_client(self, client: httpx.Client): + async def request_new_token_async(self) -> tuple: + client = self.async_client or httpx.AsyncClient() + self._configure_client(client) + try: + grant_response = await client.post(self.token_url, data=self.data) + # As described in https://tools.ietf.org/html/rfc6749#section-4.4.3 + token, expires_in = process_new_grant_response( + grant_response, self.token_field_name + ) + finally: + # Close client only if it was created by this module + if self.async_client is None: + await client.aclose() + # Handle both Access and Bearer tokens + return (self.state, token, expires_in) if expires_in else (self.state, token) + + def _configure_client(self, client: Union[httpx.Client, httpx.AsyncClient]): client.auth = (self.client_id, self.client_secret) client.timeout = self.timeout @@ -358,6 +418,8 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs): :param password: User password in case basic authentication should be used to retrieve token. :param client: httpx.Client instance that will be used to request the token. Use it to provide a custom proxying rule for instance. + :param async_client: httpx.AsyncClient instance that will be used to request the token. + Use it to provide a custom proxying rule for instance. :param kwargs: all additional authorization parameters that should be put as query parameter in the authorization URL and as body parameters in the token URL. Usual parameters are: @@ -387,6 +449,7 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs): password = kwargs.pop("password", None) self.auth = (username, password) if username and password else None self.client = kwargs.pop("client", None) + self.async_client = kwargs.pop("async_client", None) # As described in https://tools.ietf.org/html/rfc6749#section-4.1.2 code_field_name = kwargs.pop("code_field_name", "code") @@ -431,7 +494,7 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs): } self.token_data.update(kwargs) - def auth_flow( + def sync_auth_flow( self, request: httpx.Request ) -> Generator[httpx.Request, httpx.Response, None]: token = OAuth2.token_cache.get_token( @@ -442,6 +505,17 @@ def auth_flow( request.headers[self.header_name] = self.header_value.format(token=token) yield request + async def async_sync_auth_flow( + self, request: httpx.Request + ) -> AsyncGenerator[httpx.Request, httpx.Response]: + token = await OAuth2.token_cache.get_token_async( + self.state, + early_expiry=self.early_expiry, + on_missing_token=self.request_new_token_async, + ) + request.headers[self.header_name] = self.header_value.format(token=token) + yield request + def request_new_token(self) -> tuple: # Request code state, code = oauth2_authentication_responses_server.request_new_grant( @@ -454,9 +528,10 @@ def request_new_token(self) -> tuple: client = self.client or httpx.Client() self._configure_client(client) try: + grant_response = client.post(self.token_url, data=self.token_data) # As described in https://tools.ietf.org/html/rfc6749#section-4.1.4 - token, expires_in = request_new_grant_with_post( - self.token_url, self.token_data, self.token_field_name, client + token, expires_in = process_new_grant_response( + grant_response, self.token_field_name ) finally: # Close client only if it was created by this module @@ -465,7 +540,31 @@ def request_new_token(self) -> tuple: # Handle both Access and Bearer tokens return (self.state, token, expires_in) if expires_in else (self.state, token) - def _configure_client(self, client: httpx.Client): + async def request_new_token_async(self) -> tuple: + # Request code + state, code = oauth2_authentication_responses_server.request_new_grant( + self.code_grant_details + ) + + # As described in https://tools.ietf.org/html/rfc6749#section-4.1.3 + self.token_data["code"] = code + + client = self.async_client or httpx.AsyncClient() + self._configure_client(client) + try: + grant_response = await client.post(self.token_url, data=self.token_data) + # As described in https://tools.ietf.org/html/rfc6749#section-4.1.4 + token, expires_in = process_new_grant_response( + grant_response, self.token_field_name + ) + finally: + # Close client only if it was created by this module + if self.async_client is None: + await client.aclose() + # Handle both Access and Bearer tokens + return (self.state, token, expires_in) if expires_in else (self.state, token) + + def _configure_client(self, client: Union[httpx.Client, httpx.AsyncClient]): client.auth = self.auth client.timeout = self.timeout @@ -512,6 +611,8 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs): :param code_field_name: Field name containing the code. code by default. :param client: httpx.Client instance that will be used to request the token. Use it to provide a custom proxying rule for instance. + :param async_client: httpx.AsyncClient instance that will be used to request the token. + Use it to provide a custom proxying rule for instance. :param kwargs: all additional authorization parameters that should be put as query parameter in the authorization URL and as body parameters in the token URL. Usual parameters are: @@ -530,6 +631,7 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs): BrowserAuth.__init__(self, kwargs) self.client = kwargs.pop("client", None) + self.async_client = kwargs.pop("async_client", None) self.header_name = kwargs.pop("header_name", None) or "Authorization" self.header_value = kwargs.pop("header_value", None) or "Bearer {token}" @@ -596,7 +698,7 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs): } self.token_data.update(kwargs) - def auth_flow( + def sync_auth_flow( self, request: httpx.Request ) -> Generator[httpx.Request, httpx.Response, None]: token = OAuth2.token_cache.get_token( @@ -607,6 +709,17 @@ def auth_flow( request.headers[self.header_name] = self.header_value.format(token=token) yield request + async def async_auth_flow( + self, request: httpx.Request + ) -> AsyncGenerator[httpx.Request, httpx.Response]: + token = await OAuth2.token_cache.get_token_async( + self.state, + early_expiry=self.early_expiry, + on_missing_token=self.request_new_token_async, + ) + request.headers[self.header_name] = self.header_value.format(token=token) + yield request + def request_new_token(self) -> tuple: # Request code state, code = oauth2_authentication_responses_server.request_new_grant( @@ -619,9 +732,10 @@ def request_new_token(self) -> tuple: client = self.client or httpx.Client() self._configure_client(client) try: + grant_response = client.post(self.token_url, data=self.token_data) # As described in https://tools.ietf.org/html/rfc6749#section-4.1.4 - token, expires_in = request_new_grant_with_post( - self.token_url, self.token_data, self.token_field_name, client + token, expires_in = process_new_grant_response( + grant_response, self.token_field_name ) finally: # Close client only if it was created by this module @@ -630,7 +744,31 @@ def request_new_token(self) -> tuple: # Handle both Access and Bearer tokens return (self.state, token, expires_in) if expires_in else (self.state, token) - def _configure_client(self, client: httpx.Client): + async def request_new_token_async(self) -> tuple: + # Request code + state, code = oauth2_authentication_responses_server.request_new_grant( + self.code_grant_details + ) + + # As described in https://tools.ietf.org/html/rfc6749#section-4.1.3 + self.token_data["code"] = code + + client = self.async_client or httpx.AsyncClient() + self._configure_client(client) + try: + grant_response = await client.post(self.token_url, data=self.token_data) + # As described in https://tools.ietf.org/html/rfc6749#section-4.1.4 + token, expires_in = process_new_grant_response( + grant_response, self.token_field_name + ) + finally: + # Close client only if it was created by this module + if self.async_client is None: + await client.aclose() + # Handle both Access and Bearer tokens + return (self.state, token, expires_in) if expires_in else (self.state, token) + + def _configure_client(self, client: Union[httpx.Client, httpx.AsyncClient]): client.timeout = self.timeout @staticmethod @@ -1207,11 +1345,18 @@ class _MultiAuth(httpx.Auth): def __init__(self, *authentication_modes): self.authentication_modes = authentication_modes - def auth_flow( + def sync_auth_flow( self, request: httpx.Request ) -> Generator[httpx.Request, httpx.Response, None]: for authentication_mode in self.authentication_modes: - next(authentication_mode.auth_flow(request)) + next(authentication_mode.sync_auth_flow(request)) + yield request + + async def async_auth_flow( + self, request: httpx.Request + ) -> AsyncGenerator[httpx.Request, httpx.Response]: + for authentication_mode in self.authentication_modes: + await authentication_mode.async_auth_flow(request).__anext__() yield request def __add__(self, other) -> "_MultiAuth": diff --git a/httpx_auth/oauth2_tokens.py b/httpx_auth/oauth2_tokens.py index 7156865..c29164e 100644 --- a/httpx_auth/oauth2_tokens.py +++ b/httpx_auth/oauth2_tokens.py @@ -1,3 +1,4 @@ +import asyncio import base64 import json import os @@ -46,7 +47,22 @@ class TokenMemoryCache: def __init__(self): self.tokens = {} self.forbid_concurrent_cache_access = threading.Lock() + self._forbid_concurrent_cache_access_async = None self.forbid_concurrent_missing_token_function_call = threading.Lock() + self._forbid_concurrent_missing_token_function_call_async = None + + @property + def forbid_concurrent_cache_access_async(self): + if self._forbid_concurrent_cache_access_async is None: + self._forbid_concurrent_cache_access_async = asyncio.Lock() + return self._forbid_concurrent_cache_access_async + + @property + def forbid_concurrent_missing_token_function_call_async(self): + if self._forbid_concurrent_missing_token_function_call_async is None: + self._forbid_concurrent_missing_token_function_call_async = asyncio.Lock() + return self._forbid_concurrent_missing_token_function_call_async + def _add_bearer_token(self, key: str, token: str): """ @@ -153,6 +169,68 @@ def get_token( ) raise AuthenticationFailed() + async def get_token_async( + self, + key: str, + *, + early_expiry: float = 30.0, + on_missing_token=None, + **on_missing_token_kwargs, + ) -> str: + """ + Return the bearer token. + :param key: key identifier of the token + :param early_expiry: As the time between the token extraction from cache and the token reception on server side + might not higher than one second, on slow networks, token might be expired when received by the actual server, + even if still valid when fetched. + This is the number of seconds to subtract to the actual token expiry. Token will be considered as + expired 30 seconds before real expiry by default. + :param on_missing_token: function to call when token is expired or missing (returning token and expiry tuple) + :param on_missing_token_kwargs: arguments of the function (key-value arguments) + :return: the token + :raise AuthenticationFailed: in case token cannot be retrieved. + """ + logger.debug(f'Retrieving token with "{key}" key.') + async with self.forbid_concurrent_cache_access_async: + self._load_tokens() + if key in self.tokens: + bearer, expiry = self.tokens[key] + if _is_expired(expiry, early_expiry): + logger.debug(f'Authentication token with "{key}" key is expired.') + del self.tokens[key] + else: + logger.debug( + f"Using already received authentication, will expire on {datetime.datetime.utcfromtimestamp(expiry)} (UTC)." + ) + return bearer + + logger.debug("Token cannot be found in cache.") + if on_missing_token is not None: + async with self.forbid_concurrent_missing_token_function_call_async: + new_token = await on_missing_token(**on_missing_token_kwargs) + if len(new_token) == 2: # Bearer token + state, token = new_token + self._add_bearer_token(state, token) + else: # Access Token + state, token, expires_in = new_token + self._add_access_token(state, token, expires_in) + if key != state: + logger.warning( + f"Using a token received on another key than expected. Expecting {key} but was {state}." + ) + async with self.forbid_concurrent_cache_access_async: + if state in self.tokens: + bearer, expiry = self.tokens[state] + logger.debug( + f"Using newly received authentication, expiring on {datetime.datetime.utcfromtimestamp(expiry)} (UTC)." + ) + return bearer + + logger.debug( + f"User was not authenticated: key {key} cannot be found in {self.tokens}." + ) + raise AuthenticationFailed() + def clear(self): """Remove tokens from the cache.""" with self.forbid_concurrent_cache_access: