|
15 | 15 | # |
16 | 16 |
|
17 | 17 | import json |
18 | | -from unittest.mock import patch |
| 18 | +from datetime import datetime, timedelta, timezone |
| 19 | +from unittest.mock import patch, Mock |
19 | 20 |
|
20 | 21 | import pytest |
| 22 | +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey |
21 | 23 |
|
22 | 24 | from src.handlers.handler_token import HandlerToken |
23 | 25 |
|
24 | 26 |
|
| 27 | +@pytest.fixture |
| 28 | +def token_handler(): |
| 29 | + """Create a HandlerToken instance for testing.""" |
| 30 | + config = {"token_public_keys_url": "https://example.com/keys"} |
| 31 | + return HandlerToken(config) |
| 32 | + |
| 33 | + |
25 | 34 | def test_get_token_endpoint(event_gate_module, make_event): |
26 | 35 | event = make_event("/token") |
27 | 36 | resp = event_gate_module.lambda_handler(event, None) |
@@ -89,3 +98,47 @@ def test_extract_token_empty(): |
89 | 98 | def test_extract_token_direct_bearer_header(): |
90 | 99 | token = HandlerToken.extract_token({"Bearer": " tok123 "}) |
91 | 100 | assert token == "tok123" |
| 101 | + |
| 102 | + |
| 103 | +## Checking the freshness of public keys |
| 104 | +def test_refresh_keys_not_needed_when_keys_fresh(token_handler): |
| 105 | + """Keys loaded less than 30 minutes ago should not trigger refresh.""" |
| 106 | + token_handler._last_loaded_at = datetime.now(timezone.utc) - timedelta(minutes=10) |
| 107 | + token_handler.public_keys = [Mock(spec=RSAPublicKey)] |
| 108 | + |
| 109 | + with patch.object(token_handler, "load_public_keys") as mock_load: |
| 110 | + token_handler._refresh_keys_if_needed() |
| 111 | + mock_load.assert_not_called() |
| 112 | + |
| 113 | + |
| 114 | +def test_refresh_keys_triggered_when_keys_stale(token_handler): |
| 115 | + """Keys loaded more than 30 minutes ago should trigger refresh.""" |
| 116 | + token_handler._last_loaded_at = datetime.now(timezone.utc) - timedelta(minutes=29) |
| 117 | + token_handler.public_keys = [Mock(spec=RSAPublicKey)] |
| 118 | + |
| 119 | + with patch.object(token_handler, "load_public_keys") as mock_load: |
| 120 | + token_handler._refresh_keys_if_needed() |
| 121 | + mock_load.assert_called_once() |
| 122 | + |
| 123 | + |
| 124 | +def test_refresh_keys_handles_load_failure_gracefully(token_handler): |
| 125 | + """If key refresh fails, should log warning and continue with existing keys.""" |
| 126 | + old_key = Mock(spec=RSAPublicKey) |
| 127 | + token_handler.public_keys = [old_key] |
| 128 | + token_handler._last_loaded_at = datetime.now(timezone.utc) - timedelta(minutes=29) |
| 129 | + |
| 130 | + with patch.object(token_handler, "load_public_keys", side_effect=RuntimeError("Network error")): |
| 131 | + token_handler._refresh_keys_if_needed() |
| 132 | + assert token_handler.public_keys == [old_key] |
| 133 | + |
| 134 | + |
| 135 | +def test_decode_jwt_triggers_refresh_check(token_handler): |
| 136 | + """Decoding JWT should check if keys need refresh before decoding.""" |
| 137 | + dummy_key = Mock(spec=RSAPublicKey) |
| 138 | + token_handler.public_keys = [dummy_key] |
| 139 | + token_handler._last_loaded_at = datetime.now(timezone.utc) - timedelta(minutes=10) |
| 140 | + |
| 141 | + with patch.object(token_handler, "_refresh_keys_if_needed") as mock_refresh: |
| 142 | + with patch("jwt.decode", return_value={"sub": "TestUser"}): |
| 143 | + token_handler.decode_jwt("dummy-token") |
| 144 | + mock_refresh.assert_called_once() |
0 commit comments