Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions .github/workflows/pull_request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@v1
- name: Setup python 3.8
uses: actions/setup-python@v1
uses: actions/checkout@v4
- name: Setup python 3.10
uses: actions/setup-python@v4
with:
python-version: 3.8
python-version: "3.10"
- name: Install pre-commit
run: pip install pre-commit
- name: Run pre-commit
Expand All @@ -17,12 +17,12 @@ jobs:
runs-on: ubuntu-22.04
strategy:
matrix:
python-version: ['3.7.x', '3.8.x', '3.9.x', '3.10.x', '3.11.x']
python-version: ['3.9.x', '3.10.x', '3.11.x']
steps:
- name: Checkout
uses: actions/checkout@v1
uses: actions/checkout@v4
- name: Setup python ${{ matrix.python-version }}
uses: actions/setup-python@v1
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@v1
- name: Setup python 3.8
uses: actions/setup-python@v1
uses: actions/checkout@v4
- name: Setup python 3.10
uses: actions/setup-python@v4
with:
python-version: 3.8
python-version: "3.10"
- name: Add wheel dependency
run: pip install wheel
- name: Generate dist
Expand Down
101 changes: 95 additions & 6 deletions kmsauth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def __init__(
else:
self.extra_context = extra_context
self.TOKENS = LRU(token_cache_size)
self.token_cache_size = token_cache_size
self.KEY_METADATA = {}
self.stats = stats
self._validate()
Expand Down Expand Up @@ -174,9 +175,15 @@ def _get_key_arn(self, key):
'KeyMetadata': {'Arn': key}
}
if key not in self.KEY_METADATA:
self.KEY_METADATA[key] = self.kms_client.describe_key(
KeyId='{0}'.format(key)
)
if self.stats:
with self.stats.timer('kms_describe_key'):
self.KEY_METADATA[key] = self.kms_client.describe_key(
KeyId='{0}'.format(key)
)
else:
self.KEY_METADATA[key] = self.kms_client.describe_key(
KeyId='{0}'.format(key)
)
return self.KEY_METADATA[key]['KeyMetadata']['Arn']

def _get_key_alias_from_cache(self, key_arn):
Expand Down Expand Up @@ -240,12 +247,22 @@ def decrypt_token(self, username, token):
'''
Decrypt a token.
'''
time_start = datetime.datetime.utcnow()
version, user_type, _from = self._parse_username(username)
if (version > self.maximum_token_version or
version < self.minimum_token_version):
raise TokenValidationError('Unacceptable token version.')
if self.stats:
self.stats.incr('token_version_{0}'.format(version))
# Checkpoint 1: After username parsing
checkpoint_1 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501
self.stats.timing('checkpoint_1_after_parse', checkpoint_1) # noqa: E501

self.stats.incr('token_version_{version}')
self.stats.incr(f'cache_key_from_{_from}')
self.stats.incr(f'cache_key_to_{self.to_auth_context}')
self.stats.incr(f'cache_key_user_type_{user_type}')

try:
token_key = '{0}{1}{2}{3}'.format(
hashlib.sha256(ensure_bytes(token)).hexdigest(),
Expand All @@ -255,22 +272,58 @@ def decrypt_token(self, username, token):
)
except Exception:
raise TokenValidationError('Authentication error.')
if token_key not in self.TOKENS:

if self.stats:
# Checkpoint 2: After cache key generation
checkpoint_2 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501
self.stats.timing('checkpoint_2_after_cache_key', checkpoint_2) # noqa: E501

cache_miss = token_key not in self.TOKENS

if self.stats:
# Checkpoint 3: After cache lookup
checkpoint_3 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501
self.stats.timing('checkpoint_3_after_cache_lookup', checkpoint_3) # noqa: E501

if cache_miss:
if self.stats:
self.stats.incr('token_cache_miss')
self.stats.gauge('token_cache_size_at_miss', len(self.TOKENS))
if len(self.TOKENS) >= self.token_cache_size:
self.stats.incr('token_cache_eviction')

# Checkpoint 3.5: After stats calls in cache miss
checkpoint_3_5 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501
self.stats.timing('checkpoint_3_5_after_cache_miss_stats', checkpoint_3_5) # noqa: E501

try:
token = base64.b64decode(token)
if self.stats:
# Checkpoint 3.7: After base64 decode
checkpoint_3_7 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501
self.stats.timing('checkpoint_3_7_after_base64_decode', checkpoint_3_7) # noqa: E501

# Ensure normal context fields override whatever is in
# extra_context.
context = copy.deepcopy(self.extra_context)
context['to'] = self.to_auth_context
context['from'] = _from
if version > 1:
context['user_type'] = user_type

if self.stats:
# Checkpoint 3.9: After context setup
checkpoint_3_9 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501
self.stats.timing('checkpoint_3_9_after_context_setup', checkpoint_3_9) # noqa: E501
if self.stats:
with self.stats.timer('kms_decrypt_token'):
data = self.kms_client.decrypt(
CiphertextBlob=token,
EncryptionContext=context
)
# Checkpoint 4: After KMS decrypt
checkpoint_4 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501
self.stats.timing('checkpoint_4_after_kms_decrypt', checkpoint_4) # noqa: E501
else:
data = self.kms_client.decrypt(
CiphertextBlob=token,
Expand All @@ -294,9 +347,17 @@ def decrypt_token(self, username, token):
raise TokenValidationError(
'Authentication error. Unsupported user_type.'
)
if self.stats:
# Checkpoint 5: After key validation
checkpoint_5 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501
self.stats.timing('checkpoint_5_after_key_validation', checkpoint_5) # noqa: E501
plaintext = data['Plaintext']
payload = json.loads(plaintext)
key_alias = self._get_key_alias_from_cache(key_arn)
if self.stats:
# Checkpoint 6: After JSON processing
checkpoint_6 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501
self.stats.timing('checkpoint_6_after_json_processing', checkpoint_6) # noqa: E501
ret = {'payload': payload, 'key_alias': key_alias}
except TokenValidationError:
raise
Expand All @@ -313,8 +374,21 @@ def decrypt_token(self, username, token):
'Authentication error. General error.'
)
else:
if self.stats:
self.stats.incr('token_cache_hit')
ret = self.TOKENS[token_key]
if self.stats:
# Checkpoint 7: After cache hit
checkpoint_7 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501
self.stats.timing('checkpoint_7_after_cache_hit', checkpoint_7) # noqa: E501

now = datetime.datetime.utcnow()
if self.stats:
# Total time from start to this point (before time validation)
pre_time_validation_duration = (now - time_start).total_seconds() * 1000 # noqa: E501
self.stats.timing('pre_time_validation_duration', pre_time_validation_duration) # noqa: E501
# Original total validation duration metric
self.stats.timing('decrypt_token_validation_duration', (now - time_start).total_seconds() * 1000) # noqa: E501
try:
not_before = datetime.datetime.strptime(
ret['payload']['not_before'],
Expand All @@ -326,14 +400,14 @@ def decrypt_token(self, username, token):
)
except Exception:
logging.exception(
'Failed to get not_before and not_after from token payload.'
'Failed to get not_before and not_after from token payload.' # noqa: E501
)
raise TokenValidationError(
'Authentication error. Missing validity.'
)
delta = (not_after - not_before).seconds / 60
if delta > self.auth_token_max_lifetime:
logging.warning('Token used which exceeds max token lifetime.')
logging.warning('Token used which exceeds max token lifetime.') # noqa: E501
raise TokenValidationError(
'Authentication error. Token lifetime exceeded.'
)
Expand All @@ -342,7 +416,22 @@ def decrypt_token(self, username, token):
raise TokenValidationError(
'Authentication error. Invalid time validity for token.'
)
if self.stats:
# Checkpoint 8: After time validation
checkpoint_8 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501
self.stats.timing('checkpoint_8_after_time_validation', checkpoint_8) # noqa: E501

cache_set_start = datetime.datetime.utcnow()
self.TOKENS[token_key] = ret
if self.stats:
cache_set_duration = (datetime.datetime.utcnow() - cache_set_start).total_seconds() * 1000 # noqa: E501
self.stats.timing('cache_set_duration', cache_set_duration) # noqa: E501

duration = (datetime.datetime.utcnow() - now).total_seconds() * 1000 # noqa: E501
if self.stats:
self.stats.timing('decrypt_token_duration_post_validation', duration) # noqa: E501
self.stats.incr('token_cache_set')
self.stats.gauge('token_cache_size_at_set', len(self.TOKENS)) # noqa: E501
return self.TOKENS[token_key]


Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from setuptools import setup, find_packages

VERSION = "0.6.3"
VERSION = "0.6.4.dev7"

requirements = [
# Boto3 is the Amazon Web Services (AWS) Software Development Kit (SDK)
Expand Down
Loading