Skip to content

Commit a9347e2

Browse files
authored
Token verification with previous public token (#83)
* Token verification with previous public token * HandlerToken class created to manage token related operations.
1 parent 39c21a2 commit a9347e2

File tree

13 files changed

+517
-268
lines changed

13 files changed

+517
-268
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ Example (sanitized):
8080
{
8181
"access_config": "s3://<bucket>/access.json",
8282
"token_provider_url": "https://<token-ui.example>",
83-
"token_public_key_url": "https://<token-api.example>/public-key",
83+
"token_public_keys_url": "https://<token-api.example>/token/public-keys",
8484
"kafka_bootstrap_server": "broker1:9092,broker2:9092",
8585
"event_bus_arn": "arn:aws:events:region:acct:event-bus/your-bus"
8686
}
@@ -137,7 +137,7 @@ Use when Kafka access needs Kerberos / SASL_SSL or custom `librdkafka` build.
137137
| Code coverage | [Code Coverage](./DEVELOPER.md#code-coverage) |
138138

139139
## Security & Authorization
140-
- JWT tokens must be RS256 signed; the public key is fetched at cold start from `token_public_key_url` (DER base64 inside JSON `{ "key": "..." }`).
140+
- JWT tokens must be RS256 signed; current and previous public keys are fetched at cold start from `token_public_keys_url` as DER base64 values (list `keys[*].key`, with single-key fallback `{ "key": "..." }`).
141141
- Subject claim (`sub`) is matched against `ACCESS[topicName]`.
142142
- Authorization header forms accepted:
143143
- `Authorization: Bearer <token>` (preferred)

conf/config.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"access_config": "s3://<redacted>/access.json",
33
"token_provider_url": "https://<redacted>",
4-
"token_public_key_url": "https://<redacted>",
4+
"token_public_keys_url": "https://<redacted>",
55
"kafka_bootstrap_server": "localhost:9092",
66
"event_bus_arn": "arn:aws:events:<redacted>"
77
}

src/event_gate_lambda.py

Lines changed: 14 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#
1616

1717
"""Event Gate Lambda function implementation."""
18-
import base64
1918
import json
2019
import logging
2120
import os
@@ -24,13 +23,11 @@
2423

2524
import boto3
2625
import jwt
27-
import requests
2826
import urllib3
29-
from cryptography.exceptions import UnsupportedAlgorithm
30-
from cryptography.hazmat.primitives import serialization
3127
from jsonschema import validate
3228
from jsonschema.exceptions import ValidationError
3329

30+
from src.handlers.handler_token import HandlerToken
3431
from src.writers import writer_eventbridge, writer_kafka, writer_postgres
3532
from src.utils.conf_path import CONF_DIR, INVALID_CONF_ENV
3633

@@ -64,35 +61,28 @@
6461
logger.debug("Loaded TOPICS")
6562

6663
with open(os.path.join(_CONF_DIR, "config.json"), "r", encoding="utf-8") as file:
67-
CONFIG = json.load(file)
64+
config = json.load(file)
6865
logger.debug("Loaded main CONFIG")
6966

7067
aws_s3 = boto3.Session().resource("s3", verify=False) # nosec Boto verify disabled intentionally
7168
logger.debug("Initialized AWS S3 Client")
7269

73-
if CONFIG["access_config"].startswith("s3://"):
74-
name_parts = CONFIG["access_config"].split("/")
70+
if config["access_config"].startswith("s3://"):
71+
name_parts = config["access_config"].split("/")
7572
BUCKET_NAME = name_parts[2]
7673
BUCKET_OBJECT_KEY = "/".join(name_parts[3:])
7774
ACCESS = json.loads(aws_s3.Bucket(BUCKET_NAME).Object(BUCKET_OBJECT_KEY).get()["Body"].read().decode("utf-8"))
7875
else:
79-
with open(CONFIG["access_config"], "r", encoding="utf-8") as file:
76+
with open(config["access_config"], "r", encoding="utf-8") as file:
8077
ACCESS = json.load(file)
8178
logger.debug("Loaded ACCESS definitions")
8279

83-
TOKEN_PROVIDER_URL = CONFIG["token_provider_url"]
84-
# Add timeout to avoid hanging requests; wrap in robust error handling so failures are explicit
85-
try:
86-
response_json = requests.get(CONFIG["token_public_key_url"], verify=False, timeout=5).json() # nosec external
87-
token_public_key_encoded = response_json["key"]
88-
TOKEN_PUBLIC_KEY: Any = serialization.load_der_public_key(base64.b64decode(token_public_key_encoded))
89-
logger.debug("Loaded TOKEN_PUBLIC_KEY")
90-
except (requests.RequestException, ValueError, KeyError, UnsupportedAlgorithm) as exc:
91-
logger.exception("Failed to fetch or deserialize token public key from %s", CONFIG.get("token_public_key_url"))
92-
raise RuntimeError("Token public key initialization failed") from exc
93-
94-
writer_eventbridge.init(logger, CONFIG)
95-
writer_kafka.init(logger, CONFIG)
80+
# Initialize token handler and load token public keys
81+
handler_token = HandlerToken(config).load_public_keys()
82+
83+
# Initialize EventGate writers
84+
writer_eventbridge.init(logger, config)
85+
writer_kafka.init(logger, config)
9686
writer_postgres.init(logger)
9787

9888

@@ -124,12 +114,6 @@ def get_api() -> Dict[str, Any]:
124114
return {"statusCode": 200, "body": API}
125115

126116

127-
def get_token() -> Dict[str, Any]:
128-
"""Return 303 redirect to token provider endpoint."""
129-
logger.debug("Handling GET Token")
130-
return {"statusCode": 303, "headers": {"Location": TOKEN_PROVIDER_URL}}
131-
132-
133117
def get_topics() -> Dict[str, Any]:
134118
"""Return list of available topic names."""
135119
logger.debug("Handling GET Topics")
@@ -163,7 +147,7 @@ def post_topic_message(topic_name: str, topic_message: Dict[str, Any], token_enc
163147
"""
164148
logger.debug("Handling POST %s", topic_name)
165149
try:
166-
token = jwt.decode(token_encoded, TOKEN_PUBLIC_KEY, algorithms=["RS256"]) # type: ignore[arg-type]
150+
token: Dict[str, Any] = handler_token.decode_jwt(token_encoded)
167151
except jwt.PyJWTError: # type: ignore[attr-defined]
168152
return _error_response(401, "auth", "Invalid or missing token")
169153

@@ -205,41 +189,6 @@ def post_topic_message(topic_name: str, topic_message: Dict[str, Any], token_enc
205189
}
206190

207191

208-
def extract_token(event_headers: Dict[str, str]) -> str:
209-
"""Extract bearer token from headers (case-insensitive).
210-
211-
Supports:
212-
- Custom 'bearer' header (any casing) whose value is the raw token
213-
- Standard 'Authorization: Bearer <token>' header (case-insensitive scheme & key)
214-
Returns empty string if token not found or malformed.
215-
"""
216-
if not event_headers:
217-
return ""
218-
219-
# Normalize keys to lowercase for case-insensitive lookup
220-
lowered = {str(k).lower(): v for k, v in event_headers.items()}
221-
222-
# Direct bearer header (raw token)
223-
if "bearer" in lowered and isinstance(lowered["bearer"], str):
224-
token_candidate = lowered["bearer"].strip()
225-
if token_candidate:
226-
return token_candidate
227-
228-
# Authorization header with Bearer scheme
229-
auth_val = lowered.get("authorization", "")
230-
if not isinstance(auth_val, str): # defensive
231-
return ""
232-
auth_val = auth_val.strip()
233-
if not auth_val:
234-
return ""
235-
236-
# Case-insensitive match for 'Bearer ' prefix
237-
if not auth_val.lower().startswith("bearer "):
238-
return ""
239-
token_part = auth_val[7:].strip() # len('Bearer ')==7
240-
return token_part
241-
242-
243192
def lambda_handler(event: Dict[str, Any], context: Any): # pylint: disable=unused-argument,too-many-return-statements
244193
"""AWS Lambda entry point.
245194
@@ -250,7 +199,7 @@ def lambda_handler(event: Dict[str, Any], context: Any): # pylint: disable=unus
250199
if resource == "/api":
251200
return get_api()
252201
if resource == "/token":
253-
return get_token()
202+
return handler_token.get_token_provider_info()
254203
if resource == "/topics":
255204
return get_topics()
256205
if resource == "/topics/{topic_name}":
@@ -261,7 +210,7 @@ def lambda_handler(event: Dict[str, Any], context: Any): # pylint: disable=unus
261210
return post_topic_message(
262211
event["pathParameters"]["topic_name"].lower(),
263212
json.loads(event["body"]),
264-
extract_token(event.get("headers", {})),
213+
handler_token.extract_token(event.get("headers", {})),
265214
)
266215
if resource == "/terminate":
267216
sys.exit("TERMINATING") # pragma: no cover - deliberate termination path

src/handlers/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#
2+
# Copyright 2025 ABSA Group Limited
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#

src/handlers/handler_token.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
#
2+
# Copyright 2025 ABSA Group Limited
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
"""
18+
This module provides the HandlerToken class for managing the token related operations.
19+
"""
20+
21+
import base64
22+
import logging
23+
import os
24+
from datetime import datetime, timedelta, timezone
25+
from typing import Dict, Any, cast
26+
27+
import jwt
28+
import requests
29+
from cryptography.exceptions import UnsupportedAlgorithm
30+
from cryptography.hazmat.primitives import serialization
31+
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
32+
33+
from src.utils.constants import TOKEN_PROVIDER_URL_KEY, TOKEN_PUBLIC_KEYS_URL_KEY, TOKEN_PUBLIC_KEY_URL_KEY
34+
35+
logger = logging.getLogger(__name__)
36+
log_level = os.environ.get("LOG_LEVEL", "INFO")
37+
logger.setLevel(log_level)
38+
39+
40+
class HandlerToken:
41+
"""
42+
HandlerToken manages token provider URL and public keys for JWT verification.
43+
"""
44+
45+
_REFRESH_INTERVAL = timedelta(minutes=28)
46+
47+
def __init__(self, config):
48+
self.provider_url: str = config.get(TOKEN_PROVIDER_URL_KEY, "")
49+
self.public_keys_url: str = config.get(TOKEN_PUBLIC_KEYS_URL_KEY) or config.get(TOKEN_PUBLIC_KEY_URL_KEY)
50+
self.public_keys: list[RSAPublicKey] = []
51+
self._last_loaded_at: datetime | None = None
52+
53+
def _refresh_keys_if_needed(self) -> None:
54+
"""
55+
Refresh the public keys if the refresh interval has passed.
56+
"""
57+
logger.debug("Checking if the token public keys need refresh")
58+
59+
if self._last_loaded_at is None:
60+
return
61+
now = datetime.now(timezone.utc)
62+
if now - self._last_loaded_at < self._REFRESH_INTERVAL:
63+
logger.debug("Token public keys are up to date, no refresh needed")
64+
return
65+
try:
66+
logger.debug("Token public keys are stale, refreshing now")
67+
self.load_public_keys()
68+
except RuntimeError:
69+
logger.warning("Token public key refresh failed, using existing keys")
70+
71+
def load_public_keys(self) -> "HandlerToken":
72+
"""
73+
Load token public keys from the configured URL.
74+
Returns:
75+
HandlerToken: The current instance with loaded public keys.
76+
Raises:
77+
RuntimeError: If fetching or deserializing the public keys fails.
78+
"""
79+
logger.debug("Loading token public keys from %s", self.public_keys_url)
80+
81+
try:
82+
response_json = requests.get(self.public_keys_url, verify=False, timeout=5).json()
83+
raw_keys: list[str] = []
84+
85+
if isinstance(response_json, dict):
86+
if "keys" in response_json and isinstance(response_json["keys"], list):
87+
for item in response_json["keys"]:
88+
if "key" in item:
89+
raw_keys.append(item["key"].strip())
90+
elif "key" in response_json:
91+
raw_keys.append(response_json["key"].strip())
92+
93+
if not raw_keys:
94+
raise KeyError(f"No public keys found in {self.public_keys_url} endpoint response")
95+
96+
self.public_keys = [
97+
cast(RSAPublicKey, serialization.load_der_public_key(base64.b64decode(raw_key))) for raw_key in raw_keys
98+
]
99+
logger.debug("Loaded %d token public keys", len(self.public_keys))
100+
self._last_loaded_at = datetime.now(timezone.utc)
101+
102+
return self
103+
except (requests.RequestException, ValueError, KeyError, UnsupportedAlgorithm) as exc:
104+
logger.exception("Failed to fetch or deserialize token public key from %s", self.public_keys_url)
105+
raise RuntimeError("Token public key initialization failed") from exc
106+
107+
def decode_jwt(self, token_encoded: str) -> Dict[str, Any]:
108+
"""
109+
Decode and verify a JWT using the loaded public keys.
110+
Args:
111+
token_encoded (str): The encoded JWT token.
112+
Returns:
113+
Dict[str, Any]: The decoded JWT payload.
114+
Raises:
115+
jwt.PyJWTError: If verification fails for all public keys.
116+
"""
117+
self._refresh_keys_if_needed()
118+
119+
logger.debug("Decoding JWT")
120+
for public_key in self.public_keys:
121+
try:
122+
return jwt.decode(token_encoded, public_key, algorithms=["RS256"])
123+
except jwt.PyJWTError:
124+
continue
125+
raise jwt.PyJWTError("Verification failed for all public keys")
126+
127+
def get_token_provider_info(self) -> Dict[str, Any]:
128+
"""
129+
Returns: A 303 redirect response to the token provider URL.
130+
"""
131+
logger.debug("Handling GET Token")
132+
return {"statusCode": 303, "headers": {"Location": self.provider_url}}
133+
134+
@staticmethod
135+
def extract_token(event_headers: Dict[str, str]) -> str:
136+
"""
137+
Extracts the bearer (custom/standard) token from event headers.
138+
Args:
139+
event_headers (Dict[str, str]): The event headers.
140+
Returns:
141+
str: The extracted bearer token, or an empty string if not found.
142+
"""
143+
if not event_headers:
144+
return ""
145+
146+
# Normalize keys to lowercase for case-insensitive lookup
147+
lowered = {str(k).lower(): v for k, v in event_headers.items()}
148+
149+
# Direct bearer header (raw token)
150+
if "bearer" in lowered and isinstance(lowered["bearer"], str):
151+
token_candidate = lowered["bearer"].strip()
152+
if token_candidate:
153+
return token_candidate
154+
155+
# Authorization header with Bearer scheme
156+
auth_val = lowered.get("authorization", "")
157+
if not isinstance(auth_val, str): # defensive
158+
return ""
159+
auth_val = auth_val.strip()
160+
if not auth_val:
161+
return ""
162+
163+
# Case-insensitive match for 'Bearer ' prefix
164+
if not auth_val.lower().startswith("bearer "):
165+
return ""
166+
token_part = auth_val[7:].strip() # len('Bearer ')==7
167+
return token_part

src/utils/constants.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#
2+
# Copyright 2025 ABSA Group Limited
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
"""
18+
This module contains all constants and enums used across the project.
19+
"""
20+
21+
# Token related configuration keys
22+
TOKEN_PROVIDER_URL_KEY = "token_provider_url"
23+
TOKEN_PUBLIC_KEY_URL_KEY = "token_public_key_url"
24+
TOKEN_PUBLIC_KEYS_URL_KEY = "token_public_keys_url"

0 commit comments

Comments
 (0)