Skip to content

Commit

Permalink
Add helper to proxy analytics requests
Browse files Browse the repository at this point in the history
* Handles OIDC token creation and usage transparently
  • Loading branch information
chrismeyersfsu committed Jan 30, 2025
1 parent ad706d6 commit c1d6666
Show file tree
Hide file tree
Showing 2 changed files with 300 additions and 0 deletions.
113 changes: 113 additions & 0 deletions awx/main/tests/unit/utils/test_analytics_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import pytest
import requests
from unittest import mock

from awx.main.utils.analytics_proxy import OIDCClient, TokenType, TokenError


MOCK_TOKEN_RESPONSE = {
'access_token': 'bob-access-token',
'expires_in': 500,
'refresh_expires_in': 900,
'token_type': 'Bearer',
'not-before-policy': 6,
'scope': 'fake-scope1, fake-scope2',
}


@pytest.fixture
def oidc_client():
'''
oidc client instantiation fixture.
'''
return OIDCClient(
'fake-client-id',
'fake-client-secret',
'https://my-token-url.com/get/a/token/',
['api.console'],
)


@pytest.fixture
def token():
'''
Create Token class out of example OIDC token response.
'''
return OIDCClient._json_response_to_token(MOCK_TOKEN_RESPONSE)


def test_generate_access_token(oidc_client):
with mock.patch(
'awx.main.utils.analytics_proxy.requests.post',
return_value=mock.Mock(json=lambda: MOCK_TOKEN_RESPONSE, raise_for_status=mock.Mock(return_value=None)), # No exception raised
):
oidc_client._generate_access_token()

assert oidc_client.token
assert oidc_client.token.access_token == 'bob-access-token'
assert oidc_client.token.expires_in == 500
assert oidc_client.token.refresh_expires_in == 900
assert oidc_client.token.token_type == TokenType.BEARER
assert oidc_client.token.not_before_policy == 6
assert oidc_client.token.scope == 'fake-scope1, fake-scope2'


def test_token_generation_error(oidc_client):
'''
Check that TokenError is raised for failure in token generation process
'''
exception_404 = requests.HTTPError('404 Client Error: Not Found for url')
with mock.patch(
'awx.main.utils.analytics_proxy.requests.post',
return_value=mock.Mock(status_code=404, json=mock.Mock(return_value={'error': 'Not Found'}), raise_for_status=mock.Mock(side_effect=exception_404)),
):
with pytest.raises(TokenError) as exc_info:
oidc_client._generate_access_token()

assert exc_info.value.__cause__ == exception_404


def test_make_request(oidc_client, token):
'''
Check that make_request makes an http request with a generated token.
'''

def fake_generate_access_token():
oidc_client.token = token

with (
mock.patch.object(oidc_client, '_generate_access_token', side_effect=fake_generate_access_token),
mock.patch('awx.main.utils.analytics_proxy.requests.request') as mock_request,
):
oidc_client.make_request('GET', 'https://does_not_exist.com')

mock_request.assert_called_with(
'GET',
'https://does_not_exist.com',
headers={
'Authorization': f'Bearer {token.access_token}',
'Accept': 'application/json',
},
)


def test_make_request_existing_token(oidc_client, token):
'''
Check that make_request does not try and generate a token.
'''
oidc_client.token = token

with (
mock.patch.object(oidc_client, '_generate_access_token', side_effect=RuntimeError('expected not to be called')),
mock.patch('awx.main.utils.analytics_proxy.requests.request') as mock_request,
):
oidc_client.make_request('GET', 'https://does_not_exist.com')

mock_request.assert_called_with(
'GET',
'https://does_not_exist.com',
headers={
'Authorization': f'Bearer {token.access_token}',
'Accept': 'application/json',
},
)
187 changes: 187 additions & 0 deletions awx/main/utils/analytics_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
'''
Proxy requests Analytics requests
'''

import time

from dataclasses import dataclass
from enum import Enum

# from jose import jwt
from typing import Optional, Any

import requests


class TokenError(requests.RequestException):
'''
Raised when token generation request fails.
Useful for differentiating request failure for make_request() vs.
other requests issued to get a token i.e.:
try:
client = OIDCClient(...)
client.make_request(...)
except TokenGenerationError as e:
print(f"Token generation failed due to {e.__cause__}")
except requests.RequestException:
print("API request failed)
'''

def __init__(self, message="Token generation request failed", response=None):
super().__init__(message)
self.response = response # Store the response for debugging


def _now(reason: str):
'''
Wrapper for time. Helps with testing.
'''
return int(time.time())


class TokenType(Enum):
'''
Access token type as returned by the remote API.
'''

BEARER = 'Bearer'


@dataclass
class Token:
'''
Token data generated by OIDC response.
'''

access_token: str
expires_in: int
refresh_expires_in: int
token_type: TokenType
not_before_policy: int # not-before-policy
scope: str

def __init__(
self,
access_token: str,
expires_in: int,
refresh_expires_in: int,
token_type: TokenType,
not_before_policy: int,
scope: str,
):
self.access_token = access_token
self.expires_in = expires_in
self.refresh_expires_in = refresh_expires_in
self.token_type = token_type
self.not_before_policy = not_before_policy
self.scope = scope

self._now = _now(reason='token-creation')

@property
def expires_at(self) -> int:
'''
Unix timestamp in seconds of when the token expires.
'''
return self._now + self.expires_in

def is_expired(self) -> bool:
'''
Check if the token is expired.
'''
return _now(reason='token-expiration-check') >= self.expires_at


class OIDCClient:
'''
Wraps requests library make_request() and manages OIDC access token.
'''

def __init__(
self,
client_id: str,
client_secret: str,
token_url: str,
scopes: list[str],
base_url: str = '',
) -> None:
self.client_id: str = client_id
self.client_secret: str = client_secret
self.token_url: str = token_url
self.scopes = scopes
self.base_url: str = base_url
self.token: Optional[Token] = None

@classmethod
def _json_response_to_token(cls, json_response: Any) -> Token:
return Token(
access_token=json_response['access_token'],
expires_in=json_response['expires_in'],
refresh_expires_in=json_response['refresh_expires_in'],
token_type=TokenType(json_response['token_type']),
not_before_policy=json_response['not-before-policy'],
scope=json_response['scope'],
)

def _generate_access_token(self) -> None:
'''
Fetches the initial access token using client credentials.
'''
response = requests.post(
self.token_url,
data={
'grant_type': 'client_credentials',
'client_id': self.client_id,
'client_secret': self.client_secret,
'scope': self.scopes,
},
headers={'Content-Type': 'application/x-www-form-urlencoded'},
)
try:
response.raise_for_status()
except requests.RequestException as e:
raise TokenError() from e
self.token = OIDCClient._json_response_to_token(response.json())

def _add_headers(self, headers: dict[str, str]) -> None:
'''
Add token header
'''
headers.update(
{
'Authorization': f'Bearer {self.token.access_token}',
'Accept': 'application/json',
}
)

def _make_request(self, method: str, url: str, headers: dict[str, str], **kwargs: Any) -> requests.Response:
'''
Actually make an API call.
'''
self._add_headers(headers)
return requests.request(method, url, headers=headers, **kwargs)

def make_request(self, method: str, endpoint: str, **kwargs: Any) -> requests.Response:
'''
Makes an authenticated request and refreshes the token if expired.
'''
has_generated_token = False

def generate_access_token():
self._generate_access_token()
return True

if not self.token or self.token.is_expired():
has_generated_token = generate_access_token()

url = f'{self.base_url}{endpoint}'
headers = kwargs.pop('headers', {})

response = self._make_request(method, url, headers, **kwargs)
if not has_generated_token and response.status_code == 401:
generate_access_token()
response = self._make_request(method, url, headers, **kwargs)

Check warning on line 185 in awx/main/utils/analytics_proxy.py

View check run for this annotation

Codecov / codecov/patch

awx/main/utils/analytics_proxy.py#L184-L185

Added lines #L184 - L185 were not covered by tests

return response

0 comments on commit c1d6666

Please sign in to comment.