diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index 94f73b1..d217124 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -4,11 +4,11 @@ jobs: runs-on: ubuntu-22.04 steps: - name: Checkout - uses: actions/checkout@v1 - - name: Setup python 3.8 - uses: actions/setup-python@v1 + uses: actions/checkout@v4 + - name: Setup python 3.10 + uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: "3.10" - name: Install pre-commit run: pip install pre-commit - name: Run pre-commit @@ -17,12 +17,12 @@ jobs: runs-on: ubuntu-22.04 strategy: matrix: - python-version: ['3.7.x', '3.8.x', '3.9.x', '3.10.x', '3.11.x'] + python-version: ['3.9.x', '3.10.x', '3.11.x'] steps: - name: Checkout - uses: actions/checkout@v1 + uses: actions/checkout@v4 - name: Setup python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 079342a..fbdb5a7 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -11,11 +11,11 @@ jobs: runs-on: ubuntu-22.04 steps: - name: Checkout - uses: actions/checkout@v1 - - name: Setup python 3.8 - uses: actions/setup-python@v1 + uses: actions/checkout@v4 + - name: Setup python 3.10 + uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: "3.10" - name: Add wheel dependency run: pip install wheel - name: Generate dist diff --git a/kmsauth/__init__.py b/kmsauth/__init__.py index 451b976..db929bc 100644 --- a/kmsauth/__init__.py +++ b/kmsauth/__init__.py @@ -129,6 +129,7 @@ def __init__( else: self.extra_context = extra_context self.TOKENS = LRU(token_cache_size) + self.token_cache_size = token_cache_size self.KEY_METADATA = {} self.stats = stats self._validate() @@ -174,9 +175,15 @@ def _get_key_arn(self, key): 'KeyMetadata': {'Arn': key} } if key not in self.KEY_METADATA: - self.KEY_METADATA[key] = self.kms_client.describe_key( - KeyId='{0}'.format(key) - ) + if self.stats: + with self.stats.timer('kms_describe_key'): + self.KEY_METADATA[key] = self.kms_client.describe_key( + KeyId='{0}'.format(key) + ) + else: + self.KEY_METADATA[key] = self.kms_client.describe_key( + KeyId='{0}'.format(key) + ) return self.KEY_METADATA[key]['KeyMetadata']['Arn'] def _get_key_alias_from_cache(self, key_arn): @@ -240,12 +247,22 @@ def decrypt_token(self, username, token): ''' Decrypt a token. ''' + time_start = datetime.datetime.utcnow() version, user_type, _from = self._parse_username(username) if (version > self.maximum_token_version or version < self.minimum_token_version): raise TokenValidationError('Unacceptable token version.') if self.stats: self.stats.incr('token_version_{0}'.format(version)) + # Checkpoint 1: After username parsing + checkpoint_1 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501 + self.stats.timing('checkpoint_1_after_parse', checkpoint_1) # noqa: E501 + + self.stats.incr('token_version_{version}') + self.stats.incr(f'cache_key_from_{_from}') + self.stats.incr(f'cache_key_to_{self.to_auth_context}') + self.stats.incr(f'cache_key_user_type_{user_type}') + try: token_key = '{0}{1}{2}{3}'.format( hashlib.sha256(ensure_bytes(token)).hexdigest(), @@ -255,9 +272,37 @@ def decrypt_token(self, username, token): ) except Exception: raise TokenValidationError('Authentication error.') - if token_key not in self.TOKENS: + + if self.stats: + # Checkpoint 2: After cache key generation + checkpoint_2 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501 + self.stats.timing('checkpoint_2_after_cache_key', checkpoint_2) # noqa: E501 + + cache_miss = token_key not in self.TOKENS + + if self.stats: + # Checkpoint 3: After cache lookup + checkpoint_3 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501 + self.stats.timing('checkpoint_3_after_cache_lookup', checkpoint_3) # noqa: E501 + + if cache_miss: + if self.stats: + self.stats.incr('token_cache_miss') + self.stats.gauge('token_cache_size_at_miss', len(self.TOKENS)) + if len(self.TOKENS) >= self.token_cache_size: + self.stats.incr('token_cache_eviction') + + # Checkpoint 3.5: After stats calls in cache miss + checkpoint_3_5 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501 + self.stats.timing('checkpoint_3_5_after_cache_miss_stats', checkpoint_3_5) # noqa: E501 + try: token = base64.b64decode(token) + if self.stats: + # Checkpoint 3.7: After base64 decode + checkpoint_3_7 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501 + self.stats.timing('checkpoint_3_7_after_base64_decode', checkpoint_3_7) # noqa: E501 + # Ensure normal context fields override whatever is in # extra_context. context = copy.deepcopy(self.extra_context) @@ -265,12 +310,20 @@ def decrypt_token(self, username, token): context['from'] = _from if version > 1: context['user_type'] = user_type + + if self.stats: + # Checkpoint 3.9: After context setup + checkpoint_3_9 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501 + self.stats.timing('checkpoint_3_9_after_context_setup', checkpoint_3_9) # noqa: E501 if self.stats: with self.stats.timer('kms_decrypt_token'): data = self.kms_client.decrypt( CiphertextBlob=token, EncryptionContext=context ) + # Checkpoint 4: After KMS decrypt + checkpoint_4 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501 + self.stats.timing('checkpoint_4_after_kms_decrypt', checkpoint_4) # noqa: E501 else: data = self.kms_client.decrypt( CiphertextBlob=token, @@ -294,9 +347,17 @@ def decrypt_token(self, username, token): raise TokenValidationError( 'Authentication error. Unsupported user_type.' ) + if self.stats: + # Checkpoint 5: After key validation + checkpoint_5 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501 + self.stats.timing('checkpoint_5_after_key_validation', checkpoint_5) # noqa: E501 plaintext = data['Plaintext'] payload = json.loads(plaintext) key_alias = self._get_key_alias_from_cache(key_arn) + if self.stats: + # Checkpoint 6: After JSON processing + checkpoint_6 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501 + self.stats.timing('checkpoint_6_after_json_processing', checkpoint_6) # noqa: E501 ret = {'payload': payload, 'key_alias': key_alias} except TokenValidationError: raise @@ -313,8 +374,21 @@ def decrypt_token(self, username, token): 'Authentication error. General error.' ) else: + if self.stats: + self.stats.incr('token_cache_hit') ret = self.TOKENS[token_key] + if self.stats: + # Checkpoint 7: After cache hit + checkpoint_7 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501 + self.stats.timing('checkpoint_7_after_cache_hit', checkpoint_7) # noqa: E501 + now = datetime.datetime.utcnow() + if self.stats: + # Total time from start to this point (before time validation) + pre_time_validation_duration = (now - time_start).total_seconds() * 1000 # noqa: E501 + self.stats.timing('pre_time_validation_duration', pre_time_validation_duration) # noqa: E501 + # Original total validation duration metric + self.stats.timing('decrypt_token_validation_duration', (now - time_start).total_seconds() * 1000) # noqa: E501 try: not_before = datetime.datetime.strptime( ret['payload']['not_before'], @@ -326,14 +400,14 @@ def decrypt_token(self, username, token): ) except Exception: logging.exception( - 'Failed to get not_before and not_after from token payload.' + 'Failed to get not_before and not_after from token payload.' # noqa: E501 ) raise TokenValidationError( 'Authentication error. Missing validity.' ) delta = (not_after - not_before).seconds / 60 if delta > self.auth_token_max_lifetime: - logging.warning('Token used which exceeds max token lifetime.') + logging.warning('Token used which exceeds max token lifetime.') # noqa: E501 raise TokenValidationError( 'Authentication error. Token lifetime exceeded.' ) @@ -342,7 +416,22 @@ def decrypt_token(self, username, token): raise TokenValidationError( 'Authentication error. Invalid time validity for token.' ) + if self.stats: + # Checkpoint 8: After time validation + checkpoint_8 = (datetime.datetime.utcnow() - time_start).total_seconds() * 1000 # noqa: E501 + self.stats.timing('checkpoint_8_after_time_validation', checkpoint_8) # noqa: E501 + + cache_set_start = datetime.datetime.utcnow() self.TOKENS[token_key] = ret + if self.stats: + cache_set_duration = (datetime.datetime.utcnow() - cache_set_start).total_seconds() * 1000 # noqa: E501 + self.stats.timing('cache_set_duration', cache_set_duration) # noqa: E501 + + duration = (datetime.datetime.utcnow() - now).total_seconds() * 1000 # noqa: E501 + if self.stats: + self.stats.timing('decrypt_token_duration_post_validation', duration) # noqa: E501 + self.stats.incr('token_cache_set') + self.stats.gauge('token_cache_size_at_set', len(self.TOKENS)) # noqa: E501 return self.TOKENS[token_key] diff --git a/setup.py b/setup.py index 2b6ce72..82f1833 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ from setuptools import setup, find_packages -VERSION = "0.6.3" +VERSION = "0.6.4.dev7" requirements = [ # Boto3 is the Amazon Web Services (AWS) Software Development Kit (SDK)