-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcredentials.py
75 lines (56 loc) · 2.27 KB
/
credentials.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# This file provides functions to handle TAMS API credentials
import base64
import time
from abc import ABCMeta, abstractmethod

import aiohttp
def get_basic_auth_header(name: str, secret: str) -> dict[str, str]:
    """Build an HTTP Basic ``Authorization`` header.

    Args:
        name: user name (or OAuth2 client id).
        secret: password (or OAuth2 client secret).

    Returns:
        A one-entry dict mapping ``"Authorization"`` to ``"Basic <token>"``, where
        ``<token>`` is the base64 encoding of the UTF-8 bytes of ``"name:secret"``.
    """
    raw_pair = f"{name}:{secret}".encode("utf-8")
    # base64 output is pure ASCII, so decoding it as ASCII is lossless.
    token = str(base64.b64encode(raw_pair), "ascii")
    return {"Authorization": "Basic " + token}
class Credentials(metaclass=ABCMeta):
    """Abstract interface for TAMS API credentials.

    Concrete subclasses must implement :meth:`header`; the class itself cannot
    be instantiated.
    """

    @abstractmethod
    def header(self) -> dict[str, str]:
        """Return the HTTP ``Authorization`` header for these credentials.

        The base implementation (reachable via ``super()``) yields an empty dict.
        """
        return dict()
class BasicCredentials(Credentials):
    """HTTP Basic (username/password) credentials.

    The ``Authorization`` header is computed once at construction time.
    :meth:`header` returns a fresh copy on every call, so callers may add or
    change keys in the result without affecting subsequent requests.
    """

    def __init__(self, username: str, password: str) -> None:
        # The username/password are not stored, so the header can be built once.
        self._header = get_basic_auth_header(username, password)

    def header(self) -> dict[str, str]:
        """Return a new dict holding the Basic ``Authorization`` header."""
        # Copy: handing out the cached dict would let a caller that merges
        # request-specific headers into it corrupt every later request.
        return dict(self._header)
class RenewableCredentials(Credentials):
    """Abstract base for credentials that must be periodically renewed.

    Subclasses implement the two coroutines below in addition to
    :meth:`header`.
    """

    @abstractmethod
    async def ensure_credentials(self) -> None:
        """Intended to obtain credentials only when the held ones are not usable."""

    @abstractmethod
    async def renew_credentials(self) -> None:
        """Intended to obtain fresh credentials unconditionally."""
class OAuth2ClientCredentials(RenewableCredentials):
    """OAuth2 Client Credentials Grant credentials.

    A token is requested by POSTing ``grant_type=client_credentials`` to
    ``authorization_url``, with the client id/secret sent as an HTTP Basic
    ``Authorization`` header. The lifetime reported by the server
    (``expires_in``) is tracked, so :meth:`ensure_credentials` renews the token
    shortly before it expires, not only when no token is held yet.

    Args:
        authorization_url: token endpoint URL.
        client_id: OAuth2 client identifier.
        client_secret: OAuth2 client secret.
        expiry_margin: seconds before the reported expiry at which the token is
            treated as stale, to absorb request latency and clock drift. It is
            capped at half the token lifetime, so short-lived tokens are still
            reused rather than renewed on every call.
    """

    def __init__(
        self,
        authorization_url: str,
        client_id: str,
        client_secret: str,
        *,
        expiry_margin: float = 30.0,
    ) -> None:
        self.authorization_url = authorization_url
        self.client_id = client_id
        self.client_secret = client_secret
        self.expiry_margin = expiry_margin
        self.access_token = ""
        # Lifetime (seconds) from the last token response; 0.0 if none/unknown.
        self.expires_in = 0.0
        # time.monotonic() value after which the held token is considered stale.
        self._expires_at = 0.0

    async def ensure_credentials(self) -> None:
        """Fetch a token if none is held or the held one is (nearly) expired."""
        if not self.access_token or time.monotonic() >= self._expires_at:
            await self.renew_credentials()

    async def renew_credentials(self) -> None:
        """Request a new access token from the token endpoint.

        Raises:
            aiohttp.ClientResponseError: if the endpoint returns an error status.
            KeyError: if the response has no ``access_token`` field.
        """
        form_data = {
            "grant_type": "client_credentials"
        }
        # NOTE(review): RFC 6749 §2.3.1 says the client id/secret should be
        # form-urlencoded before Basic encoding; kept as-is so credentials sent
        # to existing servers do not change.
        headers = get_basic_auth_header(self.client_id, self.client_secret)
        # Timestamp taken before the request, so the computed deadline errs early.
        requested_at = time.monotonic()
        async with aiohttp.ClientSession() as session:
            async with session.post(self.authorization_url, data=form_data, headers=headers) as resp:
                resp.raise_for_status()
                token_response = await resp.json()

        # Parse everything before assigning, so a malformed response leaves the
        # previous token state intact.
        access_token = token_response["access_token"]
        raw_expires_in = token_response.get("expires_in")
        if raw_expires_in is None:
            # expires_in is only RECOMMENDED (RFC 6749 §5.1): with no lifetime,
            # keep the token until renew_credentials() is called explicitly.
            expires_in = 0.0
            expires_at = float("inf")
        else:
            expires_in = float(raw_expires_in)
            margin = min(self.expiry_margin, expires_in / 2)
            expires_at = requested_at + expires_in - margin

        self.access_token = access_token
        self.expires_in = expires_in
        self._expires_at = expires_at

    def header(self) -> dict[str, str]:
        """Return a Bearer ``Authorization`` header for the current token."""
        return {"Authorization": f"Bearer {self.access_token}"}