-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add helper to proxy analytics requests
* Handles OIDC token creation and usage transparently
- Loading branch information
1 parent
ad706d6
commit c1d6666
Showing
2 changed files
with
300 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
import pytest | ||
import requests | ||
from unittest import mock | ||
|
||
from awx.main.utils.analytics_proxy import OIDCClient, TokenType, TokenError | ||
|
||
|
||
MOCK_TOKEN_RESPONSE = { | ||
'access_token': 'bob-access-token', | ||
'expires_in': 500, | ||
'refresh_expires_in': 900, | ||
'token_type': 'Bearer', | ||
'not-before-policy': 6, | ||
'scope': 'fake-scope1, fake-scope2', | ||
} | ||
|
||
|
||
@pytest.fixture | ||
def oidc_client(): | ||
''' | ||
oidc client instantiation fixture. | ||
''' | ||
return OIDCClient( | ||
'fake-client-id', | ||
'fake-client-secret', | ||
'https://my-token-url.com/get/a/token/', | ||
['api.console'], | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def token(): | ||
''' | ||
Create Token class out of example OIDC token response. | ||
''' | ||
return OIDCClient._json_response_to_token(MOCK_TOKEN_RESPONSE) | ||
|
||
|
||
def test_generate_access_token(oidc_client): | ||
with mock.patch( | ||
'awx.main.utils.analytics_proxy.requests.post', | ||
return_value=mock.Mock(json=lambda: MOCK_TOKEN_RESPONSE, raise_for_status=mock.Mock(return_value=None)), # No exception raised | ||
): | ||
oidc_client._generate_access_token() | ||
|
||
assert oidc_client.token | ||
assert oidc_client.token.access_token == 'bob-access-token' | ||
assert oidc_client.token.expires_in == 500 | ||
assert oidc_client.token.refresh_expires_in == 900 | ||
assert oidc_client.token.token_type == TokenType.BEARER | ||
assert oidc_client.token.not_before_policy == 6 | ||
assert oidc_client.token.scope == 'fake-scope1, fake-scope2' | ||
|
||
|
||
def test_token_generation_error(oidc_client): | ||
''' | ||
Check that TokenError is raised for failure in token generation process | ||
''' | ||
exception_404 = requests.HTTPError('404 Client Error: Not Found for url') | ||
with mock.patch( | ||
'awx.main.utils.analytics_proxy.requests.post', | ||
return_value=mock.Mock(status_code=404, json=mock.Mock(return_value={'error': 'Not Found'}), raise_for_status=mock.Mock(side_effect=exception_404)), | ||
): | ||
with pytest.raises(TokenError) as exc_info: | ||
oidc_client._generate_access_token() | ||
|
||
assert exc_info.value.__cause__ == exception_404 | ||
|
||
|
||
def test_make_request(oidc_client, token): | ||
''' | ||
Check that make_request makes an http request with a generated token. | ||
''' | ||
|
||
def fake_generate_access_token(): | ||
oidc_client.token = token | ||
|
||
with ( | ||
mock.patch.object(oidc_client, '_generate_access_token', side_effect=fake_generate_access_token), | ||
mock.patch('awx.main.utils.analytics_proxy.requests.request') as mock_request, | ||
): | ||
oidc_client.make_request('GET', 'https://does_not_exist.com') | ||
|
||
mock_request.assert_called_with( | ||
'GET', | ||
'https://does_not_exist.com', | ||
headers={ | ||
'Authorization': f'Bearer {token.access_token}', | ||
'Accept': 'application/json', | ||
}, | ||
) | ||
|
||
|
||
def test_make_request_existing_token(oidc_client, token): | ||
''' | ||
Check that make_request does not try and generate a token. | ||
''' | ||
oidc_client.token = token | ||
|
||
with ( | ||
mock.patch.object(oidc_client, '_generate_access_token', side_effect=RuntimeError('expected not to be called')), | ||
mock.patch('awx.main.utils.analytics_proxy.requests.request') as mock_request, | ||
): | ||
oidc_client.make_request('GET', 'https://does_not_exist.com') | ||
|
||
mock_request.assert_called_with( | ||
'GET', | ||
'https://does_not_exist.com', | ||
headers={ | ||
'Authorization': f'Bearer {token.access_token}', | ||
'Accept': 'application/json', | ||
}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
''' | ||
Proxy requests Analytics requests | ||
''' | ||
|
||
import time | ||
|
||
from dataclasses import dataclass | ||
from enum import Enum | ||
|
||
# from jose import jwt | ||
from typing import Optional, Any | ||
|
||
import requests | ||
|
||
|
||
class TokenError(requests.RequestException): | ||
''' | ||
Raised when token generation request fails. | ||
Useful for differentiating request failure for make_request() vs. | ||
other requests issued to get a token i.e.: | ||
try: | ||
client = OIDCClient(...) | ||
client.make_request(...) | ||
except TokenGenerationError as e: | ||
print(f"Token generation failed due to {e.__cause__}") | ||
except requests.RequestException: | ||
print("API request failed) | ||
''' | ||
|
||
def __init__(self, message="Token generation request failed", response=None): | ||
super().__init__(message) | ||
self.response = response # Store the response for debugging | ||
|
||
|
||
def _now(reason: str): | ||
''' | ||
Wrapper for time. Helps with testing. | ||
''' | ||
return int(time.time()) | ||
|
||
|
||
class TokenType(Enum): | ||
''' | ||
Access token type as returned by the remote API. | ||
''' | ||
|
||
BEARER = 'Bearer' | ||
|
||
|
||
@dataclass | ||
class Token: | ||
''' | ||
Token data generated by OIDC response. | ||
''' | ||
|
||
access_token: str | ||
expires_in: int | ||
refresh_expires_in: int | ||
token_type: TokenType | ||
not_before_policy: int # not-before-policy | ||
scope: str | ||
|
||
def __init__( | ||
self, | ||
access_token: str, | ||
expires_in: int, | ||
refresh_expires_in: int, | ||
token_type: TokenType, | ||
not_before_policy: int, | ||
scope: str, | ||
): | ||
self.access_token = access_token | ||
self.expires_in = expires_in | ||
self.refresh_expires_in = refresh_expires_in | ||
self.token_type = token_type | ||
self.not_before_policy = not_before_policy | ||
self.scope = scope | ||
|
||
self._now = _now(reason='token-creation') | ||
|
||
@property | ||
def expires_at(self) -> int: | ||
''' | ||
Unix timestamp in seconds of when the token expires. | ||
''' | ||
return self._now + self.expires_in | ||
|
||
def is_expired(self) -> bool: | ||
''' | ||
Check if the token is expired. | ||
''' | ||
return _now(reason='token-expiration-check') >= self.expires_at | ||
|
||
|
||
class OIDCClient: | ||
''' | ||
Wraps requests library make_request() and manages OIDC access token. | ||
''' | ||
|
||
def __init__( | ||
self, | ||
client_id: str, | ||
client_secret: str, | ||
token_url: str, | ||
scopes: list[str], | ||
base_url: str = '', | ||
) -> None: | ||
self.client_id: str = client_id | ||
self.client_secret: str = client_secret | ||
self.token_url: str = token_url | ||
self.scopes = scopes | ||
self.base_url: str = base_url | ||
self.token: Optional[Token] = None | ||
|
||
@classmethod | ||
def _json_response_to_token(cls, json_response: Any) -> Token: | ||
return Token( | ||
access_token=json_response['access_token'], | ||
expires_in=json_response['expires_in'], | ||
refresh_expires_in=json_response['refresh_expires_in'], | ||
token_type=TokenType(json_response['token_type']), | ||
not_before_policy=json_response['not-before-policy'], | ||
scope=json_response['scope'], | ||
) | ||
|
||
def _generate_access_token(self) -> None: | ||
''' | ||
Fetches the initial access token using client credentials. | ||
''' | ||
response = requests.post( | ||
self.token_url, | ||
data={ | ||
'grant_type': 'client_credentials', | ||
'client_id': self.client_id, | ||
'client_secret': self.client_secret, | ||
'scope': self.scopes, | ||
}, | ||
headers={'Content-Type': 'application/x-www-form-urlencoded'}, | ||
) | ||
try: | ||
response.raise_for_status() | ||
except requests.RequestException as e: | ||
raise TokenError() from e | ||
self.token = OIDCClient._json_response_to_token(response.json()) | ||
|
||
def _add_headers(self, headers: dict[str, str]) -> None: | ||
''' | ||
Add token header | ||
''' | ||
headers.update( | ||
{ | ||
'Authorization': f'Bearer {self.token.access_token}', | ||
'Accept': 'application/json', | ||
} | ||
) | ||
|
||
def _make_request(self, method: str, url: str, headers: dict[str, str], **kwargs: Any) -> requests.Response: | ||
''' | ||
Actually make an API call. | ||
''' | ||
self._add_headers(headers) | ||
return requests.request(method, url, headers=headers, **kwargs) | ||
|
||
def make_request(self, method: str, endpoint: str, **kwargs: Any) -> requests.Response: | ||
''' | ||
Makes an authenticated request and refreshes the token if expired. | ||
''' | ||
has_generated_token = False | ||
|
||
def generate_access_token(): | ||
self._generate_access_token() | ||
return True | ||
|
||
if not self.token or self.token.is_expired(): | ||
has_generated_token = generate_access_token() | ||
|
||
url = f'{self.base_url}{endpoint}' | ||
headers = kwargs.pop('headers', {}) | ||
|
||
response = self._make_request(method, url, headers, **kwargs) | ||
if not has_generated_token and response.status_code == 401: | ||
generate_access_token() | ||
response = self._make_request(method, url, headers, **kwargs) | ||
|
||
return response |