Skip to content

Commit 511b5e1

Browse files
committed
Add endpoint for JWT refresh tokens
1 parent ca2e600 commit 511b5e1

File tree

6 files changed

+328
-1
lines changed

6 files changed

+328
-1
lines changed

changelog.d/3270.added.md

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add endpoint for JWT refresh tokens

python/nav/web/api/v1/urls.py

+1
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,5 @@
7373
name="prefix-usage-detail",
7474
),
7575
re_path(r'^', include(router.urls)),
76+
re_path(r'^refresh/$', views.JWTRefreshViewSet.as_view(), name='jwt-refresh'),
7677
]

python/nav/web/api/v1/views.py

+50
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,21 @@
4545
from oidc_auth.authentication import JSONWebTokenAuthentication
4646

4747
from nav.models import manage, event, cabling, rack, profiles
48+
from nav.models.api import JWTRefreshToken
4849
from nav.models.fields import INFINITY, UNRESOLVED
4950
from nav.web.servicecheckers import load_checker_classes
5051
from nav.util import auth_token, is_valid_cidr
5152

5253
from nav.buildconf import VERSION
5354
from nav.web.api.v1 import serializers, alert_serializers
5455
from nav.web.status2 import STATELESS_THRESHOLD
56+
from nav.web.jwtgen import (
57+
generate_access_token,
58+
generate_refresh_token,
59+
hash_token,
60+
decode_token,
61+
is_active,
62+
)
5563
from nav.macaddress import MacPrefix
5664
from .auth import (
5765
APIPermission,
@@ -1153,3 +1161,45 @@ class ModuleViewSet(NAVAPIMixin, viewsets.ReadOnlyModelViewSet):
11531161
'device__serial',
11541162
)
11551163
serializer_class = serializers.ModuleSerializer
1164+
1165+
1166+
class JWTRefreshViewSet(APIView):
1167+
"""
1168+
Accepts a valid refresh token.
1169+
Returns a new refresh token and an access token.
1170+
"""
1171+
1172+
def post(self, request):
1173+
incoming_token = request.data.get('refresh_token')
1174+
token_hash = hash_token(incoming_token)
1175+
try:
1176+
# If hash exists in the database, then we know it is a real token
1177+
db_token = JWTRefreshToken.objects.get(hash=token_hash)
1178+
except JWTRefreshToken.DoesNotExist:
1179+
return Response("Invalid token", status=status.HTTP_403_FORBIDDEN)
1180+
1181+
claims = decode_token(incoming_token)
1182+
if not is_active(claims['exp'], claims['nbf']):
1183+
return Response("Inactive token", status=status.HTTP_403_FORBIDDEN)
1184+
1185+
if db_token.revoked:
1186+
return Response(
1187+
"This token has been revoked", status=status.HTTP_403_FORBIDDEN
1188+
)
1189+
1190+
access_token = generate_access_token(claims)
1191+
refresh_token = generate_refresh_token(claims)
1192+
1193+
new_claims = decode_token(refresh_token)
1194+
new_hash = hash_token(refresh_token)
1195+
db_token.hash = new_hash
1196+
db_token.expires = datetime.fromtimestamp(new_claims['exp'])
1197+
db_token.activates = datetime.fromtimestamp(new_claims['nbf'])
1198+
db_token.last_used = datetime.now()
1199+
db_token.save()
1200+
1201+
response_data = {
1202+
'access_token': access_token,
1203+
'refresh_token': refresh_token,
1204+
}
1205+
return Response(response_data)

python/nav/web/jwtgen.py

+13
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from datetime import datetime, timedelta, timezone
22
from typing import Any, Optional
3+
import hashlib
34

45
import jwt
56

@@ -64,3 +65,15 @@ def is_active(exp: float, nbf: float) -> bool:
6465
expires = datetime.fromtimestamp(exp, tz=timezone.utc)
6566
activates = datetime.fromtimestamp(nbf, tz=timezone.utc)
6667
return now >= activates and now < expires
68+
69+
70+
def hash_token(token: str) -> str:
71+
"""Hashes a token with SHA256"""
72+
hash_object = hashlib.sha256(token.encode('utf-8'))
73+
hex_dig = hash_object.hexdigest()
74+
return hex_dig
75+
76+
77+
def decode_token(token: str) -> dict[str, Any]:
78+
"""Decodes a token in JWT format and returns the data of the decoded token"""
79+
return jwt.decode(token, options={'verify_signature': False})
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
import hashlib
2+
from typing import Generator
3+
4+
import jwt
5+
import pytest
6+
from datetime import datetime, timedelta, timezone
7+
8+
from unittest.mock import MagicMock, Mock, patch
9+
10+
from django.urls import reverse
11+
from nav.models.api import JWTRefreshToken
12+
13+
14+
def test_token_not_in_database_should_be_rejected(db, api_client, url, active_token):
15+
token_hash = hashlib.sha256(active_token.encode('utf-8')).hexdigest()
16+
assert not JWTRefreshToken.objects.filter(hash=token_hash).exists()
17+
response = api_client.post(
18+
url,
19+
follow=True,
20+
data={
21+
'refresh_token': active_token,
22+
},
23+
)
24+
assert response.status_code == 403
25+
26+
27+
def test_inactive_token_should_be_rejected(db, api_client, url, inactive_token):
28+
now = datetime.now()
29+
token_hash = hashlib.sha256(inactive_token.encode('utf-8')).hexdigest()
30+
db_token = JWTRefreshToken(
31+
name="testtoken",
32+
hash=token_hash,
33+
expires=now - timedelta(hours=1),
34+
activates=now - timedelta(hours=2),
35+
)
36+
db_token.save()
37+
38+
response = api_client.post(
39+
url,
40+
follow=True,
41+
data={
42+
'refresh_token': inactive_token,
43+
},
44+
)
45+
46+
assert response.status_code == 403
47+
48+
49+
def test_valid_token_should_be_accepted(db, api_client, url, active_token):
50+
data = jwt.decode(active_token, options={'verify_signature': False})
51+
token_hash = token_hash = hashlib.sha256(active_token.encode('utf-8')).hexdigest()
52+
db_token = JWTRefreshToken(
53+
name="testtoken",
54+
hash=token_hash,
55+
expires=datetime.fromtimestamp(data['exp']),
56+
activates=datetime.fromtimestamp(data['nbf']),
57+
)
58+
db_token.save()
59+
response = api_client.post(
60+
url,
61+
follow=True,
62+
data={
63+
'refresh_token': active_token,
64+
},
65+
)
66+
assert response.status_code == 200
67+
68+
69+
def test_valid_token_should_be_replaced_by_new_token_in_db(
70+
db, api_client, url, active_token
71+
):
72+
token_hash = token_hash = hashlib.sha256(active_token.encode('utf-8')).hexdigest()
73+
data = jwt.decode(active_token, options={'verify_signature': False})
74+
db_token = JWTRefreshToken(
75+
name="testtoken",
76+
hash=token_hash,
77+
expires=datetime.fromtimestamp(data['exp']),
78+
activates=datetime.fromtimestamp(data['nbf']),
79+
)
80+
db_token.save()
81+
response = api_client.post(
82+
url,
83+
follow=True,
84+
data={
85+
'refresh_token': active_token,
86+
},
87+
)
88+
assert response.status_code == 200
89+
assert not JWTRefreshToken.objects.filter(hash=token_hash).exists()
90+
new_token = response.data.get("refresh_token")
91+
new_hash = hashlib.sha256(new_token.encode('utf-8')).hexdigest()
92+
assert JWTRefreshToken.objects.filter(hash=new_hash).exists()
93+
94+
95+
def test_should_include_access_and_refresh_token_in_response(
96+
db, api_client, url, active_token
97+
):
98+
token_hash = hashlib.sha256(active_token.encode('utf-8')).hexdigest()
99+
data = jwt.decode(active_token, options={'verify_signature': False})
100+
db_token = JWTRefreshToken(
101+
name="testtoken",
102+
hash=token_hash,
103+
expires=datetime.fromtimestamp(data['exp']),
104+
activates=datetime.fromtimestamp(data['nbf']),
105+
)
106+
db_token.save()
107+
response = api_client.post(
108+
url,
109+
follow=True,
110+
data={
111+
'refresh_token': active_token,
112+
},
113+
)
114+
assert response.status_code == 200
115+
assert "access_token" in response.data
116+
assert "refresh_token" in response.data
117+
118+
119+
def test_revoked_token_should_be_rejected(db, api_client, url, active_token):
120+
token_hash = hashlib.sha256(active_token.encode('utf-8')).hexdigest()
121+
data = jwt.decode(active_token, options={'verify_signature': False})
122+
db_token = JWTRefreshToken(
123+
name="testtoken",
124+
hash=token_hash,
125+
expires=datetime.fromtimestamp(data['exp']),
126+
activates=datetime.fromtimestamp(data['nbf']),
127+
revoked=True,
128+
)
129+
db_token.save()
130+
response = api_client.post(
131+
url,
132+
follow=True,
133+
data={
134+
'refresh_token': active_token,
135+
},
136+
)
137+
assert response.status_code == 403
138+
139+
140+
def test_last_used_should_be_updated_after_token_is_used(
141+
db, api_client, url, active_token
142+
):
143+
token_hash = hashlib.sha256(active_token.encode('utf-8')).hexdigest()
144+
data = jwt.decode(active_token, options={'verify_signature': False})
145+
db_token = JWTRefreshToken(
146+
name="testtoken",
147+
hash=token_hash,
148+
expires=datetime.fromtimestamp(data['exp']),
149+
activates=datetime.fromtimestamp(data['nbf']),
150+
)
151+
db_token.save()
152+
assert db_token.last_used is None
153+
response = api_client.post(
154+
url,
155+
follow=True,
156+
data={
157+
'refresh_token': active_token,
158+
},
159+
)
160+
new_token = response.data.get("refresh_token")
161+
new_hash = hashlib.sha256(new_token.encode('utf-8')).hexdigest()
162+
assert JWTRefreshToken.objects.get(hash=new_hash).last_used is not None
163+
164+
165+
@pytest.fixture()
166+
def inactive_token(nav_name, rsa_private_key) -> str:
167+
now = datetime.now(timezone.utc)
168+
claims = {
169+
'exp': (now - timedelta(hours=1)).timestamp(),
170+
'nbf': (now - timedelta(hours=2)).timestamp(),
171+
'iat': (now - timedelta(hours=2)).timestamp(),
172+
'aud': nav_name,
173+
'iss': nav_name,
174+
'token_type': 'refresh_token',
175+
}
176+
token = jwt.encode(claims, rsa_private_key, algorithm="RS256")
177+
return token
178+
179+
180+
@pytest.fixture()
181+
def active_token(nav_name, rsa_private_key) -> str:
182+
now = datetime.now(timezone.utc)
183+
claims = {
184+
'exp': (now + timedelta(hours=1)).timestamp(),
185+
'nbf': now.timestamp(),
186+
'iat': now.timestamp(),
187+
'aud': nav_name,
188+
'iss': nav_name,
189+
'token_type': 'refresh_token',
190+
}
191+
token = jwt.encode(claims, rsa_private_key, algorithm="RS256")
192+
return token
193+
194+
195+
@pytest.fixture()
196+
def url():
197+
return reverse('api:1:jwt-refresh')
198+
199+
200+
@pytest.fixture(scope="module", autouse=True)
201+
def jwtconf_mock(rsa_private_key, nav_name) -> Generator[MagicMock, None, None]:
202+
"""Mocks the get_nave_name and get_nav_private_key functions for
203+
the JWTConf class
204+
"""
205+
with patch("nav.web.jwtgen.JWTConf") as _jwtconf_mock:
206+
instance = _jwtconf_mock.return_value
207+
instance.get_nav_name = Mock(return_value=nav_name)
208+
instance.get_nav_private_key = Mock(return_value=rsa_private_key)
209+
yield _jwtconf_mock
210+
211+
212+
@pytest.fixture(scope="module")
213+
def nav_name() -> str:
214+
return "nav"

tests/unittests/web/jwtgen_test.py

+49-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
1+
from typing import Any
12
import pytest
23
from unittest.mock import Mock, patch
34
from datetime import datetime
45

56
import jwt
67

7-
from nav.web.jwtgen import generate_access_token, generate_refresh_token
8+
from nav.web.jwtgen import (
9+
generate_access_token,
10+
generate_refresh_token,
11+
hash_token,
12+
decode_token,
13+
)
814

915

1016
class TestTokenGeneration:
@@ -55,6 +61,16 @@ def test_token_type_should_be_refresh_token(self):
5561
assert data['token_type'] == "refresh_token"
5662

5763

64+
class TestHashToken:
65+
def test_should_return_correct_hash(self, token_string, token_hash):
66+
assert hash_token(token_string) == token_hash
67+
68+
69+
class TestDecodeToken:
70+
def test_should_return_expected_data(self, token_string, token_data):
71+
assert decode_token(token_string) == token_data
72+
73+
5874
@pytest.fixture(scope="module", autouse=True)
5975
def jwtconf_mock(rsa_private_key, nav_name) -> str:
6076
"""Mocks the get_nav_name and get_nav_private_key functions for
@@ -70,3 +86,35 @@ def jwtconf_mock(rsa_private_key, nav_name) -> str:
7086
@pytest.fixture(scope="module")
7187
def nav_name() -> str:
7288
yield "nav"
89+
90+
91+
@pytest.fixture(scope="module")
92+
def token_string() -> str:
93+
"""String representation of a token. Matching data is in `token_data`
94+
and expected hash is in `token_hash`
95+
"""
96+
token = (
97+
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpYXQ"
98+
"iOjE3NDA0Nzg4NTQsImV4cCI6MTc0MDU2NTI1NH0.2GbcpbwzVOAV7"
99+
"nv4lAS_ISrw-g9WKvhKKnpN9dhSL6s"
100+
)
101+
return token
102+
103+
104+
@pytest.fixture(scope="module")
105+
def token_hash() -> str:
106+
"""SHA256 hash of a token. Matching data is in `token_data`
107+
and the actual token string is in `token_string`
108+
"""
109+
return "91d0d189dde6a7423b884f8bb285b17f9706d21e6d0ce45aac028a22b3067395"
110+
111+
112+
@pytest.fixture(scope="module")
113+
def token_data() -> dict[str, Any]:
114+
"""Payload of a token. The actual token string is in `token_string`
115+
and hash of the token in `token_hash`
116+
"""
117+
return {
118+
"iat": 1740478854,
119+
"exp": 1740565254,
120+
}

0 commit comments

Comments
 (0)