Skip to content
32 changes: 23 additions & 9 deletions oauth2_provider/oauth2_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,19 +202,32 @@ def _load_application(self, client_id, request):
If request.client was not set, load application instance for given
client_id and store it in request.client
"""

# we want to be sure that request has the client attribute!
assert hasattr(request, "client"), '"request" instance has no "client" attribute'

if request.client:
# check for cached client, to save the db hit if this has already been loaded
if not isinstance(request.client, Application):
log.debug("request.client is not an Application, something else set request.client erroneously, resetting request.client.")
request.client = None
elif request.client.client_id != client_id:
log.debug("request.client client_id does not match the given client_id, resetting request.client.")
request.client = None
elif not request.client.is_usable(request):
log.debug("request.client is a valid Application, but is not usable, resetting request.client.")
request.client = None
else:
log.debug("request.client is a valid Application, reusing it.")
return request.client
try:
request.client = request.client or Application.objects.get(client_id=client_id)
# Check that the application can be used (defaults to always True)
if not request.client.is_usable(request):
log.debug("Failed body authentication: Application %r is disabled" % (client_id))
# cache wasn't hit, load from db
log.debug("cache not hit, Loading application from database for client_id %r", client_id)
client = Application.objects.get(client_id=client_id)
if not client.is_usable(request):
log.debug("Failed to load application: Application %r is not usable" % (client_id))
return None
log.debug("Loaded application %r from database", client)
request.client = client
return request.client
except Application.DoesNotExist:
log.debug("Failed body authentication: Application %r does not exist" % (client_id))
log.debug("Failed to load application: Application %r does not exist" % (client_id))
return None

def _set_oauth2_error_on_request(self, request, access_token, scopes):
Expand Down Expand Up @@ -277,6 +290,7 @@ def client_authentication_required(self, request, *args, **kwargs):
pass

self._load_application(request.client_id, request)
log.debug("Determining if client authentication is required for client %r", request.client)
if request.client:
return request.client.client_type == AbstractApplication.CLIENT_CONFIDENTIAL

Expand Down
21 changes: 21 additions & 0 deletions tests/test_authorization_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,27 @@ def test_request_body_params(self):
self.assertEqual(content["scope"], "read write")
self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS)

def test_request_body_params_client_typo(self):
"""
Request an access token using client_type: public
"""
self.client.login(username="test_user", password="123456")
authorization_code = self.get_auth()

token_request_data = {
"grant_type": "authorization_code",
"code": authorization_code,
"redirect_uri": "http://example.org",
"client": self.application.client_id,
"client_secret": CLEARTEXT_SECRET,
}

response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data)
self.assertEqual(response.status_code, 401)

content = json.loads(response.content.decode("utf-8"))
self.assertEqual(content["error"], "invalid_client")

def test_public(self):
"""
Request an access token using client_type: public
Expand Down
48 changes: 46 additions & 2 deletions tests/test_oauth2_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,52 @@ def test_client_authentication_required(self):
self.request.client = ""
self.assertTrue(self.validator.client_authentication_required(self.request))

def test_load_application_fails_when_request_has_no_client(self):
self.assertRaises(AssertionError, self.validator.authenticate_client_id, "client_id", {})
def test_load_application_loads_client_id_when_request_has_no_client(self):
self.request.client = None
application = self.validator._load_application("client_id", self.request)
self.assertEqual(application, self.application)

def test_load_application_uses_cached_when_request_has_valid_client_matching_client_id(self):
self.request.client = self.application
application = self.validator._load_application("client_id", self.request)
self.assertIs(application, self.application)
self.assertIs(self.request.client, self.application)

def test_load_application_succeeds_when_request_has_invalid_client_valid_client_id(self):
self.request.client = 'invalid_client'
application = self.validator._load_application("client_id", self.request)
self.assertEqual(application, self.application)
self.assertEqual(self.request.client, self.application)

def test_load_application_overwrites_client_on_client_id_mismatch(self):
another_application = Application.objects.create(
client_id="another_client_id",
client_secret=CLEARTEXT_SECRET,
user=self.user,
client_type=Application.CLIENT_PUBLIC,
authorization_grant_type=Application.GRANT_PASSWORD,
)
self.request.client = another_application
application = self.validator._load_application("client_id", self.request)
self.assertEqual(application, self.application)
self.assertEqual(self.request.client, self.application)
another_application.delete()

@mock.patch.object(Application, "is_usable")
def test_load_application_returns_none_when_client_not_usable_cached(self, mock_is_usable):
mock_is_usable.return_value = False
self.request.client = self.application
application = self.validator._load_application("client_id", self.request)
self.assertIsNone(application)
self.assertIsNone(self.request.client)

@mock.patch.object(Application, "is_usable")
def test_load_application_returns_none_when_client_not_usable_db_lookup(self, mock_is_usable):
mock_is_usable.return_value = False
self.request.client = None
application = self.validator._load_application("client_id", self.request)
self.assertIsNone(application)
self.assertIsNone(self.request.client)

def test_rotate_refresh_token__is_true(self):
self.assertTrue(self.validator.rotate_refresh_token(mock.MagicMock()))
Expand Down
Loading