Skip to content

Commit d23819c

Browse files
committed
HandlerToken class created to manage token related operations.
1 parent e00a6f7 commit d23819c

File tree

10 files changed

+431
-343
lines changed

10 files changed

+431
-343
lines changed

src/event_gate_lambda.py

Lines changed: 13 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,19 @@
1515
#
1616

1717
"""Event Gate Lambda function implementation."""
18-
import base64
1918
import json
2019
import logging
2120
import os
2221
import sys
23-
from typing import Any, Dict, cast
22+
from typing import Any, Dict
2423

2524
import boto3
2625
import jwt
27-
import requests
2826
import urllib3
29-
from cryptography.exceptions import UnsupportedAlgorithm
30-
from cryptography.hazmat.primitives import serialization
31-
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
3227
from jsonschema import validate
3328
from jsonschema.exceptions import ValidationError
3429

30+
from src.handlers.handler_token import HandlerToken
3531
from src.writers import writer_eventbridge, writer_kafka, writer_postgres
3632
from src.utils.conf_path import CONF_DIR, INVALID_CONF_ENV
3733

@@ -65,51 +61,28 @@
6561
logger.debug("Loaded TOPICS")
6662

6763
with open(os.path.join(_CONF_DIR, "config.json"), "r", encoding="utf-8") as file:
68-
CONFIG = json.load(file)
64+
config = json.load(file)
6965
logger.debug("Loaded main CONFIG")
7066

7167
aws_s3 = boto3.Session().resource("s3", verify=False) # nosec Boto verify disabled intentionally
7268
logger.debug("Initialized AWS S3 Client")
7369

74-
if CONFIG["access_config"].startswith("s3://"):
75-
name_parts = CONFIG["access_config"].split("/")
70+
if config["access_config"].startswith("s3://"):
71+
name_parts = config["access_config"].split("/")
7672
BUCKET_NAME = name_parts[2]
7773
BUCKET_OBJECT_KEY = "/".join(name_parts[3:])
7874
ACCESS = json.loads(aws_s3.Bucket(BUCKET_NAME).Object(BUCKET_OBJECT_KEY).get()["Body"].read().decode("utf-8"))
7975
else:
80-
with open(CONFIG["access_config"], "r", encoding="utf-8") as file:
76+
with open(config["access_config"], "r", encoding="utf-8") as file:
8177
ACCESS = json.load(file)
8278
logger.debug("Loaded ACCESS definitions")
8379

84-
# Initialize token public keys
85-
TOKEN_PROVIDER_URL = CONFIG.get("token_provider_url")
86-
TOKEN_PUBLIC_KEYS_URL = CONFIG.get("token_public_keys_url") or CONFIG.get("token_public_key_url")
87-
88-
try:
89-
response_json = requests.get(TOKEN_PUBLIC_KEYS_URL, verify=False, timeout=5).json()
90-
raw_keys: list[str] = []
91-
if isinstance(response_json, dict):
92-
if "keys" in response_json and isinstance(response_json["keys"], list):
93-
for item in response_json["keys"]:
94-
if "key" in item:
95-
raw_keys.append(item["key"].strip())
96-
elif "key" in response_json:
97-
raw_keys.append(response_json["key"].strip())
98-
99-
if not raw_keys:
100-
raise KeyError(f"No public keys found in {TOKEN_PUBLIC_KEYS_URL} endpoint response")
101-
102-
TOKEN_PUBLIC_KEYS: list[RSAPublicKey] = [
103-
cast(RSAPublicKey, serialization.load_der_public_key(base64.b64decode(raw_key))) for raw_key in raw_keys
104-
]
105-
logger.debug("Loaded %d TOKEN_PUBLIC_KEYS", len(TOKEN_PUBLIC_KEYS))
106-
except (requests.RequestException, ValueError, KeyError, UnsupportedAlgorithm) as exc:
107-
logger.exception("Failed to fetch or deserialize token public key from %s", TOKEN_PUBLIC_KEYS_URL)
108-
raise RuntimeError("Token public key initialization failed") from exc
80+
# Initialize token handler and load token public keys
81+
handler_token = HandlerToken(config).load_public_keys()
10982

11083
# Initialize EventGate writers
111-
writer_eventbridge.init(logger, CONFIG)
112-
writer_kafka.init(logger, CONFIG)
84+
writer_eventbridge.init(logger, config)
85+
writer_kafka.init(logger, config)
11386
writer_postgres.init(logger)
11487

11588

@@ -141,12 +114,6 @@ def get_api() -> Dict[str, Any]:
141114
return {"statusCode": 200, "body": API}
142115

143116

144-
def get_token() -> Dict[str, Any]:
145-
"""Return 303 redirect to token provider endpoint."""
146-
logger.debug("Handling GET Token")
147-
return {"statusCode": 303, "headers": {"Location": TOKEN_PROVIDER_URL}}
148-
149-
150117
def get_topics() -> Dict[str, Any]:
151118
"""Return list of available topic names."""
152119
logger.debug("Handling GET Topics")
@@ -180,7 +147,7 @@ def post_topic_message(topic_name: str, topic_message: Dict[str, Any], token_enc
180147
"""
181148
logger.debug("Handling POST %s", topic_name)
182149
try:
183-
token = decode_jwt_all(token_encoded)
150+
token: Dict[str, Any] = handler_token.decode_jwt(token_encoded)
184151
except jwt.PyJWTError: # type: ignore[attr-defined]
185152
return _error_response(401, "auth", "Invalid or missing token")
186153

@@ -222,55 +189,6 @@ def post_topic_message(topic_name: str, topic_message: Dict[str, Any], token_enc
222189
}
223190

224191

225-
def decode_jwt_all(token_encoded: str) -> Dict[str, Any]:
226-
"""Decode JWT using any of the loaded public keys.
227-
228-
Args:
229-
token_encoded: Encoded bearer JWT token string.
230-
"""
231-
for public_key in TOKEN_PUBLIC_KEYS:
232-
try:
233-
return jwt.decode(token_encoded, public_key, algorithms=["RS256"])
234-
except jwt.PyJWTError:
235-
continue
236-
raise jwt.PyJWTError("Verification failed for all public keys")
237-
238-
239-
def extract_token(event_headers: Dict[str, str]) -> str:
240-
"""Extract bearer token from headers (case-insensitive).
241-
242-
Supports:
243-
- Custom 'bearer' header (any casing) whose value is the raw token
244-
- Standard 'Authorization: Bearer <token>' header (case-insensitive scheme & key)
245-
Returns empty string if token not found or malformed.
246-
"""
247-
if not event_headers:
248-
return ""
249-
250-
# Normalize keys to lowercase for case-insensitive lookup
251-
lowered = {str(k).lower(): v for k, v in event_headers.items()}
252-
253-
# Direct bearer header (raw token)
254-
if "bearer" in lowered and isinstance(lowered["bearer"], str):
255-
token_candidate = lowered["bearer"].strip()
256-
if token_candidate:
257-
return token_candidate
258-
259-
# Authorization header with Bearer scheme
260-
auth_val = lowered.get("authorization", "")
261-
if not isinstance(auth_val, str): # defensive
262-
return ""
263-
auth_val = auth_val.strip()
264-
if not auth_val:
265-
return ""
266-
267-
# Case-insensitive match for 'Bearer ' prefix
268-
if not auth_val.lower().startswith("bearer "):
269-
return ""
270-
token_part = auth_val[7:].strip() # len('Bearer ')==7
271-
return token_part
272-
273-
274192
def lambda_handler(event: Dict[str, Any], context: Any): # pylint: disable=unused-argument,too-many-return-statements
275193
"""AWS Lambda entry point.
276194
@@ -281,7 +199,7 @@ def lambda_handler(event: Dict[str, Any], context: Any): # pylint: disable=unus
281199
if resource == "/api":
282200
return get_api()
283201
if resource == "/token":
284-
return get_token()
202+
return handler_token.get_token()
285203
if resource == "/topics":
286204
return get_topics()
287205
if resource == "/topics/{topic_name}":
@@ -292,7 +210,7 @@ def lambda_handler(event: Dict[str, Any], context: Any): # pylint: disable=unus
292210
return post_topic_message(
293211
event["pathParameters"]["topic_name"].lower(),
294212
json.loads(event["body"]),
295-
extract_token(event.get("headers", {})),
213+
handler_token.extract_token(event.get("headers", {})),
296214
)
297215
if resource == "/terminate":
298216
sys.exit("TERMINATING") # pragma: no cover - deliberate termination path

src/handlers/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#
2+
# Copyright 2025 ABSA Group Limited
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#

src/handlers/handler_token.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
#
2+
# Copyright 2025 ABSA Group Limited
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
"""
18+
This module provides the HandlerToken class for managing the token related operations.
19+
"""
20+
21+
import base64
22+
import logging
23+
import os
24+
from typing import Dict, Any, cast
25+
26+
import jwt
27+
import requests
28+
from cryptography.exceptions import UnsupportedAlgorithm
29+
from cryptography.hazmat.primitives import serialization
30+
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
31+
32+
from src.utils.constants import TOKEN_PROVIDER_URL, TOKEN_PUBLIC_KEYS_URL, TOKEN_PUBLIC_KEY_URL
33+
34+
logger = logging.getLogger(__name__)
35+
log_level = os.environ.get("LOG_LEVEL", "INFO")
36+
logger.setLevel(log_level)
37+
38+
39+
class HandlerToken:
40+
"""
41+
HandlerToken manages token provider URL and public keys for JWT verification.
42+
"""
43+
44+
def __init__(self, config):
45+
self.provider_url: str = config.get(TOKEN_PROVIDER_URL, "")
46+
self.public_keys_url: str = config.get(TOKEN_PUBLIC_KEYS_URL) or config.get(TOKEN_PUBLIC_KEY_URL)
47+
self.public_keys: list[RSAPublicKey] = []
48+
49+
def load_public_keys(self) -> "HandlerToken":
50+
"""
51+
Load token public keys from the configured URL.
52+
Returns:
53+
HandlerToken: The current instance with loaded public keys.
54+
Raises:
55+
RuntimeError: If fetching or deserializing the public keys fails.
56+
"""
57+
logger.debug("Loading token public keys from %s", self.public_keys_url)
58+
59+
try:
60+
response_json = requests.get(self.public_keys_url, verify=False, timeout=5).json()
61+
raw_keys: list[str] = []
62+
63+
if isinstance(response_json, dict):
64+
if "keys" in response_json and isinstance(response_json["keys"], list):
65+
for item in response_json["keys"]:
66+
if "key" in item:
67+
raw_keys.append(item["key"].strip())
68+
elif "key" in response_json:
69+
raw_keys.append(response_json["key"].strip())
70+
71+
if not raw_keys:
72+
raise KeyError(f"No public keys found in {self.public_keys_url} endpoint response")
73+
74+
self.public_keys = [
75+
cast(RSAPublicKey, serialization.load_der_public_key(base64.b64decode(raw_key))) for raw_key in raw_keys
76+
]
77+
logger.debug("Loaded %d token public keys", len(self.public_keys))
78+
79+
return self
80+
except (requests.RequestException, ValueError, KeyError, UnsupportedAlgorithm) as exc:
81+
logger.exception("Failed to fetch or deserialize token public key from %s", self.public_keys_url)
82+
raise RuntimeError("Token public key initialization failed") from exc
83+
84+
def decode_jwt(self, token_encoded: str) -> Dict[str, Any]:
85+
"""
86+
Decode and verify a JWT using the loaded public keys.
87+
Args:
88+
token_encoded (str): The encoded JWT token.
89+
Returns:
90+
Dict[str, Any]: The decoded JWT payload.
91+
Raises:
92+
jwt.PyJWTError: If verification fails for all public keys.
93+
"""
94+
logger.debug("Decoding JWT")
95+
for public_key in self.public_keys:
96+
try:
97+
return jwt.decode(token_encoded, public_key, algorithms=["RS256"])
98+
except jwt.PyJWTError:
99+
continue
100+
raise jwt.PyJWTError("Verification failed for all public keys")
101+
102+
def get_token(self) -> Dict[str, Any]:
103+
"""
104+
Returns: A 303 redirect response to the token provider URL.
105+
"""
106+
logger.debug("Handling GET Token")
107+
return {"statusCode": 303, "headers": {"Location": self.provider_url}}
108+
109+
@staticmethod
110+
def extract_token(event_headers: Dict[str, str]) -> str:
111+
"""
112+
Extracts the bearer (custom/standard) token from event headers.
113+
Args:
114+
event_headers (Dict[str, str]): The event headers.
115+
Returns:
116+
str: The extracted bearer token, or an empty string if not found.
117+
"""
118+
if not event_headers:
119+
return ""
120+
121+
# Normalize keys to lowercase for case-insensitive lookup
122+
lowered = {str(k).lower(): v for k, v in event_headers.items()}
123+
124+
# Direct bearer header (raw token)
125+
if "bearer" in lowered and isinstance(lowered["bearer"], str):
126+
token_candidate = lowered["bearer"].strip()
127+
if token_candidate:
128+
return token_candidate
129+
130+
# Authorization header with Bearer scheme
131+
auth_val = lowered.get("authorization", "")
132+
if not isinstance(auth_val, str): # defensive
133+
return ""
134+
auth_val = auth_val.strip()
135+
if not auth_val:
136+
return ""
137+
138+
# Case-insensitive match for 'Bearer ' prefix
139+
if not auth_val.lower().startswith("bearer "):
140+
return ""
141+
token_part = auth_val[7:].strip() # len('Bearer ')==7
142+
return token_part

src/utils/constants.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#
2+
# Copyright 2025 ABSA Group Limited
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
"""
18+
This module contains all constants and enums used across the project.
19+
"""
20+
21+
# Token related constants
22+
TOKEN_PROVIDER_URL = "token_provider_url"
23+
TOKEN_PUBLIC_KEY_URL = "token_public_key_url"
24+
TOKEN_PUBLIC_KEYS_URL = "token_public_keys_url"

0 commit comments

Comments
 (0)