Skip to content

Commit a0295d4

Browse files
authored
Periodically refresh JWT public key set (#88)
* Periodic JWT Public Key Refresh Implementation
1 parent d23819c commit a0295d4

File tree

2 files changed

+79
-1
lines changed

2 files changed

+79
-1
lines changed

src/handlers/handler_token.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import base64
2222
import logging
2323
import os
24+
from datetime import datetime, timedelta, timezone
2425
from typing import Dict, Any, cast
2526

2627
import jwt
@@ -41,10 +42,31 @@ class HandlerToken:
4142
HandlerToken manages token provider URL and public keys for JWT verification.
4243
"""
4344

45+
_REFRESH_INTERVAL = timedelta(minutes=28)
46+
4447
def __init__(self, config):
4548
self.provider_url: str = config.get(TOKEN_PROVIDER_URL, "")
4649
self.public_keys_url: str = config.get(TOKEN_PUBLIC_KEYS_URL) or config.get(TOKEN_PUBLIC_KEY_URL)
4750
self.public_keys: list[RSAPublicKey] = []
51+
self._last_loaded_at: datetime | None = None
52+
53+
def _refresh_keys_if_needed(self) -> None:
54+
"""
55+
Refresh the public keys if the refresh interval has passed.
56+
"""
57+
logger.debug("Checking if the token public keys need refresh")
58+
59+
if self._last_loaded_at is None:
60+
return
61+
now = datetime.now(timezone.utc)
62+
if now - self._last_loaded_at < self._REFRESH_INTERVAL:
63+
logger.debug("Token public keys are up to date, no refresh needed")
64+
return
65+
try:
66+
logger.debug("Token public keys are stale, refreshing now")
67+
self.load_public_keys()
68+
except RuntimeError:
69+
logger.warning("Token public key refresh failed, using existing keys")
4870

4971
def load_public_keys(self) -> "HandlerToken":
5072
"""
@@ -75,6 +97,7 @@ def load_public_keys(self) -> "HandlerToken":
7597
cast(RSAPublicKey, serialization.load_der_public_key(base64.b64decode(raw_key))) for raw_key in raw_keys
7698
]
7799
logger.debug("Loaded %d token public keys", len(self.public_keys))
100+
self._last_loaded_at = datetime.now(timezone.utc)
78101

79102
return self
80103
except (requests.RequestException, ValueError, KeyError, UnsupportedAlgorithm) as exc:
@@ -91,6 +114,8 @@ def decode_jwt(self, token_encoded: str) -> Dict[str, Any]:
91114
Raises:
92115
jwt.PyJWTError: If verification fails for all public keys.
93116
"""
117+
self._refresh_keys_if_needed()
118+
94119
logger.debug("Decoding JWT")
95120
for public_key in self.public_keys:
96121
try:

tests/handlers/test_handler_token.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,22 @@
1515
#
1616

1717
import json
18-
from unittest.mock import patch
18+
from datetime import datetime, timedelta, timezone
19+
from unittest.mock import patch, Mock
1920

2021
import pytest
22+
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
2123

2224
from src.handlers.handler_token import HandlerToken
2325

2426

27+
@pytest.fixture
28+
def token_handler():
29+
"""Create a HandlerToken instance for testing."""
30+
config = {"token_public_keys_url": "https://example.com/keys"}
31+
return HandlerToken(config)
32+
33+
2534
def test_get_token_endpoint(event_gate_module, make_event):
2635
event = make_event("/token")
2736
resp = event_gate_module.lambda_handler(event, None)
@@ -89,3 +98,47 @@ def test_extract_token_empty():
8998
def test_extract_token_direct_bearer_header():
9099
token = HandlerToken.extract_token({"Bearer": " tok123 "})
91100
assert token == "tok123"
101+
102+
103+
## Checking the freshness of public keys
104+
def test_refresh_keys_not_needed_when_keys_fresh(token_handler):
105+
"""Keys loaded less than 30 minutes ago should not trigger refresh."""
106+
token_handler._last_loaded_at = datetime.now(timezone.utc) - timedelta(minutes=10)
107+
token_handler.public_keys = [Mock(spec=RSAPublicKey)]
108+
109+
with patch.object(token_handler, "load_public_keys") as mock_load:
110+
token_handler._refresh_keys_if_needed()
111+
mock_load.assert_not_called()
112+
113+
114+
def test_refresh_keys_triggered_when_keys_stale(token_handler):
115+
"""Keys loaded more than 30 minutes ago should trigger refresh."""
116+
token_handler._last_loaded_at = datetime.now(timezone.utc) - timedelta(minutes=29)
117+
token_handler.public_keys = [Mock(spec=RSAPublicKey)]
118+
119+
with patch.object(token_handler, "load_public_keys") as mock_load:
120+
token_handler._refresh_keys_if_needed()
121+
mock_load.assert_called_once()
122+
123+
124+
def test_refresh_keys_handles_load_failure_gracefully(token_handler):
125+
"""If key refresh fails, should log warning and continue with existing keys."""
126+
old_key = Mock(spec=RSAPublicKey)
127+
token_handler.public_keys = [old_key]
128+
token_handler._last_loaded_at = datetime.now(timezone.utc) - timedelta(minutes=29)
129+
130+
with patch.object(token_handler, "load_public_keys", side_effect=RuntimeError("Network error")):
131+
token_handler._refresh_keys_if_needed()
132+
assert token_handler.public_keys == [old_key]
133+
134+
135+
def test_decode_jwt_triggers_refresh_check(token_handler):
136+
"""Decoding JWT should check if keys need refresh before decoding."""
137+
dummy_key = Mock(spec=RSAPublicKey)
138+
token_handler.public_keys = [dummy_key]
139+
token_handler._last_loaded_at = datetime.now(timezone.utc) - timedelta(minutes=10)
140+
141+
with patch.object(token_handler, "_refresh_keys_if_needed") as mock_refresh:
142+
with patch("jwt.decode", return_value={"sub": "TestUser"}):
143+
token_handler.decode_jwt("dummy-token")
144+
mock_refresh.assert_called_once()

0 commit comments

Comments
 (0)