Skip to content

Commit

Permalink
Added middleware to refresh access tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
daggaz committed Feb 28, 2023
1 parent 80adc32 commit 3d132a9
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 54 deletions.
36 changes: 33 additions & 3 deletions django_auth_adfs/backend.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import logging
from datetime import datetime, timedelta

import jwt
from django.contrib.auth import get_user_model
from django.contrib.auth import get_user_model, logout
from django.contrib.auth.backends import ModelBackend
from django.contrib.auth.models import Group
from django.core.exceptions import (ImproperlyConfigured, ObjectDoesNotExist,
PermissionDenied)
from requests import HTTPError

from django_auth_adfs import signals
from django_auth_adfs.config import provider_config, settings
Expand Down Expand Up @@ -398,10 +400,38 @@ def authenticate(self, request=None, authorization_code=None, **kwargs):
provider_config.load_config()

adfs_response = self.exchange_auth_code(authorization_code, request)
access_token = adfs_response["access_token"]
user = self.process_access_token(access_token, adfs_response)
user = self._process_adfs_response(request, adfs_response)
return user

def _process_adfs_response(self, request, adfs_response):
user = self.process_access_token(adfs_response['access_token'], adfs_response)
request.session['adfs_access_token'] = adfs_response['access_token']
expiry = datetime.now() + timedelta(seconds=adfs_response['expires_in'])
request.session['adfs_token_expiry'] = expiry.isoformat()
if 'refresh_token' in adfs_response:
request.session['adfs_refresh_token'] = adfs_response['refresh_token']
request.session.save()
return user

def process_request(self, request):
now = datetime.now() + settings.REFRESH_THRESHOLD
expiry = datetime.fromisoformat(request.session['adfs_token_expiry'])
if now > expiry:
try:
self._refresh_access_token(request, request.session['adfs_refresh_token'])
except (PermissionDenied, HTTPError):
logout(request)

def _refresh_access_token(self, request, refresh_token):
provider_config.load_config()
response = provider_config.session.post(
provider_config.token_endpoint,
data=f'grant_type=refresh_token&refresh_token={refresh_token}'
)
response.raise_for_status()
adfs_response = response.json()
self._process_adfs_response(request, adfs_response)


class AdfsAccessTokenBackend(AdfsBaseBackend):
"""
Expand Down
1 change: 1 addition & 0 deletions django_auth_adfs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(self):
self.USERNAME_CLAIM = "winaccountname"
self.GUEST_USERNAME_CLAIM = None
self.JWT_LEEWAY = 0
self.REFRESH_THRESHOLD = timedelta(minutes=5)
self.CUSTOM_FAILED_RESPONSE_VIEW = lambda request, error_message, status: render(
request, 'django_auth_adfs/login_failed.html', {'error_message': error_message}, status=status
)
Expand Down
16 changes: 16 additions & 0 deletions django_auth_adfs/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from re import compile

from django.conf import settings as django_settings
from django.contrib import auth
from django.contrib.auth.views import redirect_to_login
from django.urls import reverse

from django_auth_adfs.backend import AdfsAuthCodeBackend
from django_auth_adfs.exceptions import MFARequired
from django_auth_adfs.config import settings

Expand Down Expand Up @@ -49,3 +51,17 @@ def __call__(self, request):
return redirect_to_login('django_auth_adfs:login-force-mfa')

return self.get_response(request)


def adfs_refresh_middleware(get_response):
def middleware(request):
try:
backend_str = request.session[auth.BACKEND_SESSION_KEY]
except KeyError:
pass
else:
backend = auth.load_backend(backend_str)
if isinstance(backend, AdfsAuthCodeBackend):
backend.process_request(request)
return get_response()
return middleware
102 changes: 51 additions & 51 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import base64

from django.urls import reverse

from django_auth_adfs.exceptions import MFARequired

try:
Expand All @@ -16,7 +18,6 @@
from mock import Mock, patch

from django_auth_adfs import signals
from django_auth_adfs.backend import AdfsAuthCodeBackend
from django_auth_adfs.config import ProviderConfig, Settings

from .models import Profile
Expand All @@ -34,14 +35,13 @@ def setUp(self):

@mock_adfs("2012")
def test_post_authenticate_signal_send(self):
backend = AdfsAuthCodeBackend()
backend.authenticate(self.request, authorization_code="dummycode")
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"})
self.assertEqual(self.signal_handler.call_count, 1)

@mock_adfs("2012")
def test_with_auth_code_2012(self):
backend = AdfsAuthCodeBackend()
user = backend.authenticate(self.request, authorization_code="dummycode")
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"})
user = response.wsgi_request.user
self.assertIsInstance(user, User)
self.assertEqual(user.first_name, "John")
self.assertEqual(user.last_name, "Doe")
Expand All @@ -52,8 +52,8 @@ def test_with_auth_code_2012(self):

@mock_adfs("2016")
def test_with_auth_code_2016(self):
backend = AdfsAuthCodeBackend()
user = backend.authenticate(self.request, authorization_code="dummycode")
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"})
user = response.wsgi_request.user
self.assertIsInstance(user, User)
self.assertEqual(user.first_name, "John")
self.assertEqual(user.last_name, "Doe")
Expand All @@ -64,9 +64,15 @@ def test_with_auth_code_2016(self):

@mock_adfs("2016", mfa_error=True)
def test_mfa_error_backends(self):
with self.assertRaises(MFARequired):
backend = AdfsAuthCodeBackend()
backend.authenticate(self.request, authorization_code="dummycode")
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"})
self.assertEqual(response.status_code, 302)
self.assertEqual(
response['Location'],
"https://adfs.example.com/adfs/oauth2/authorize/?response_type=code&"
"client_id=your-configured-client-id&resource=your-adfs-RPT-name&"
"redirect_uri=http%3A%2F%2Ftestserver%2Foauth2%2Fcallback&state=Lw%3D%3D&scope=openid&"
"amr_values=ngcmfa"
)

@mock_adfs("azure")
def test_with_auth_code_azure(self):
Expand All @@ -77,8 +83,8 @@ def test_with_auth_code_azure(self):
with patch("django_auth_adfs.config.django_settings", settings):
with patch("django_auth_adfs.config.settings", Settings()):
with patch("django_auth_adfs.backend.provider_config", ProviderConfig()):
backend = AdfsAuthCodeBackend()
user = backend.authenticate(self.request, authorization_code="dummycode")
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"})
user = response.wsgi_request.user
self.assertIsInstance(user, User)
self.assertEqual(user.first_name, "John")
self.assertEqual(user.last_name, "Doe")
Expand All @@ -100,9 +106,8 @@ def test_with_auth_code_azure_guest_block(self):
with patch('django_auth_adfs.backend.settings', Settings()):
with patch("django_auth_adfs.config.settings", Settings()):
with patch("django_auth_adfs.backend.provider_config", ProviderConfig()):
with self.assertRaises(PermissionDenied, msg=''):
backend = AdfsAuthCodeBackend()
_ = backend.authenticate(self.request, authorization_code="dummycode")
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"})
self.assertEqual(response.status_code, 401)

@mock_adfs("azure", guest=True)
def test_with_auth_code_azure_guest_no_block(self):
Expand All @@ -117,8 +122,8 @@ def test_with_auth_code_azure_guest_no_block(self):
with patch('django_auth_adfs.backend.settings', Settings()):
with patch("django_auth_adfs.config.settings", Settings()):
with patch("django_auth_adfs.backend.provider_config", ProviderConfig()):
backend = AdfsAuthCodeBackend()
user = backend.authenticate(self.request, authorization_code="dummycode")
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"})
user = response.wsgi_request.user
self.assertIsInstance(user, User)
self.assertEqual(user.first_name, "John")
self.assertEqual(user.last_name, "Doe")
Expand All @@ -139,8 +144,8 @@ def test_version_two_endpoint_calls_correct_url(self):
with patch('django_auth_adfs.backend.settings', Settings()):
with patch("django_auth_adfs.config.settings", Settings()):
with patch("django_auth_adfs.backend.provider_config", ProviderConfig()):
backend = AdfsAuthCodeBackend()
user = backend.authenticate(self.request, authorization_code="dummycode")
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"})
user = response.wsgi_request.user
self.assertIsInstance(user, User)
self.assertEqual(user.first_name, "John")
self.assertEqual(user.last_name, "Doe")
Expand All @@ -151,14 +156,15 @@ def test_version_two_endpoint_calls_correct_url(self):

@mock_adfs("2016")
def test_empty(self):
backend = AdfsAuthCodeBackend()
self.assertIsNone(backend.authenticate(self.request))
response = self.client.get(reverse('django_auth_adfs:callback'))
user = response.wsgi_request.user
self.assertTrue(user.is_anonymous)

@mock_adfs("2016")
def test_group_claim(self):
backend = AdfsAuthCodeBackend()
with patch("django_auth_adfs.backend.settings.GROUPS_CLAIM", "nonexisting"):
user = backend.authenticate(self.request, authorization_code="dummycode")
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"})
user = response.wsgi_request.user
self.assertIsInstance(user, User)
self.assertEqual(user.first_name, "John")
self.assertEqual(user.last_name, "Doe")
Expand All @@ -167,9 +173,9 @@ def test_group_claim(self):

@mock_adfs("2016")
def test_no_group_claim(self):
backend = AdfsAuthCodeBackend()
with patch("django_auth_adfs.backend.settings.GROUPS_CLAIM", None):
user = backend.authenticate(self.request, authorization_code="dummycode")
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"})
user = response.wsgi_request.user
self.assertIsInstance(user, User)
self.assertEqual(user.first_name, "John")
self.assertEqual(user.last_name, "Doe")
Expand All @@ -181,9 +187,9 @@ def test_group_claim_with_mirror_groups(self):
# Remove one group
Group.objects.filter(name="group1").delete()

backend = AdfsAuthCodeBackend()
with patch("django_auth_adfs.backend.settings.MIRROR_GROUPS", True):
user = backend.authenticate(self.request, authorization_code="dummycode")
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"})
user = response.wsgi_request.user
self.assertIsInstance(user, User)
self.assertEqual(user.first_name, "John")
self.assertEqual(user.last_name, "Doe")
Expand All @@ -197,9 +203,9 @@ def test_group_claim_without_mirror_groups(self):
# Remove one group
Group.objects.filter(name="group1").delete()

backend = AdfsAuthCodeBackend()
with patch("django_auth_adfs.backend.settings.MIRROR_GROUPS", False):
user = backend.authenticate(self.request, authorization_code="dummycode")
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"})
user = response.wsgi_request.user
self.assertIsInstance(user, User)
self.assertEqual(user.first_name, "John")
self.assertEqual(user.last_name, "Doe")
Expand All @@ -210,9 +216,9 @@ def test_group_claim_without_mirror_groups(self):

@mock_adfs("2016", empty_keys=True)
def test_empty_keys(self):
backend = AdfsAuthCodeBackend()
with patch("django_auth_adfs.config.provider_config.signing_keys", []):
self.assertRaises(PermissionDenied, backend.authenticate, self.request, authorization_code='testcode')
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"})
self.assertEqual(response.status_code, 401)

@mock_adfs("2016")
def test_group_removal(self):
Expand All @@ -227,9 +233,8 @@ def test_group_removal(self):
self.assertEqual(user.groups.all()[0].name, "group3")
self.assertEqual(len(user.groups.all()), 1)

backend = AdfsAuthCodeBackend()

user = backend.authenticate(self.request, authorization_code="dummycode")
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"})
user = response.wsgi_request.user
self.assertIsInstance(user, User)
self.assertEqual(user.first_name, "John")
self.assertEqual(user.last_name, "Doe")
Expand All @@ -253,9 +258,8 @@ def test_group_removal_overlap(self):
self.assertEqual(user.groups.all()[1].name, "group3")
self.assertEqual(len(user.groups.all()), 2)

backend = AdfsAuthCodeBackend()

user = backend.authenticate(self.request, authorization_code="dummycode")
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"})
user = response.wsgi_request.user
self.assertIsInstance(user, User)
self.assertEqual(user.first_name, "John")
self.assertEqual(user.last_name, "Doe")
Expand All @@ -272,9 +276,8 @@ def test_group_to_flag_mapping(self):
}
with patch("django_auth_adfs.backend.settings.GROUP_TO_FLAG_MAPPING", group_to_flag_mapping):
with patch("django_auth_adfs.backend.settings.BOOLEAN_CLAIM_MAPPING", {}):
backend = AdfsAuthCodeBackend()

user = backend.authenticate(self.request, authorization_code="dummycode")
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"})
user = response.wsgi_request.user
self.assertIsInstance(user, User)
self.assertEqual(user.first_name, "John")
self.assertEqual(user.last_name, "Doe")
Expand All @@ -289,9 +292,8 @@ def test_boolean_claim_mapping(self):
"is_superuser": "user_is_superuser",
}
with patch("django_auth_adfs.backend.settings.BOOLEAN_CLAIM_MAPPING", boolean_claim_mapping):
backend = AdfsAuthCodeBackend()

user = backend.authenticate(self.request, authorization_code="dummycode")
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"})
user = response.wsgi_request.user
self.assertIsInstance(user, User)
self.assertEqual(user.first_name, "John")
self.assertEqual(user.last_name, "Doe")
Expand All @@ -312,9 +314,8 @@ def test_extended_model_claim_mapping_missing_instance(self):
},
}
with patch("django_auth_adfs.backend.settings.CLAIM_MAPPING", claim_mapping):
backend = AdfsAuthCodeBackend()

user = backend.authenticate(self.request, authorization_code="dummycode")
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"})
user = response.wsgi_request.user
self.assertIsInstance(user, User)
self.assertEqual(user.first_name, "John")
self.assertEqual(user.last_name, "Doe")
Expand All @@ -340,9 +341,8 @@ def create_profile(sender, instance, created, **kwargs):
},
}
with patch("django_auth_adfs.backend.settings.CLAIM_MAPPING", claim_mapping):
backend = AdfsAuthCodeBackend()

user = backend.authenticate(self.request, authorization_code="dummycode")
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"})
user = response.wsgi_request.user
self.assertIsInstance(user, User)
self.assertEqual(user.first_name, "John")
self.assertEqual(user.last_name, "Doe")
Expand Down Expand Up @@ -493,5 +493,5 @@ def test_nonexisting_user(self):
settings.AUTH_ADFS["CREATE_NEW_USERS"] = False
with patch("django_auth_adfs.config.django_settings", settings),\
patch("django_auth_adfs.backend.settings", Settings()):
backend = AdfsAuthCodeBackend()
self.assertRaises(PermissionDenied, backend.authenticate, self.request, authorization_code='testcode')
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"})
self.assertEqual(response.status_code, 401)

0 comments on commit 3d132a9

Please sign in to comment.