diff --git a/.gitignore b/.gitignore index af644d1..c22ef00 100644 --- a/.gitignore +++ b/.gitignore @@ -25,7 +25,7 @@ __pycache__ pip-log.txt # Unit test / coverage reports -.cache +.pytest_cache .coverage .tox .pytest_cache/ diff --git a/.travis.yml b/.travis.yml index 41b7f88..1715b59 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,31 +1,18 @@ -# https://travis-ci.org/evonove/django-oauth-toolkit -sudo: false -language: python +# https://travis-ci.org/jazzband/django-oauth-toolkit +dist: xenial -python: "3.6" +language: python -env: - - TOXENV=py27-django111 - - TOXENV=py34-django111 - - TOXENV=py36-django111 - - TOXENV=py36-django20 - - TOXENV=py36-djangomaster - - TOXENV=docs - - TOXENV=flake8 +python: + - "3.4" cache: directories: - $HOME/.cache/pip - $TRAVIS_BUILD_DIR/.tox -matrix: - fast_finish: true - - allow_failures: - - env: TOXENV=py36-djangomaster - install: - - pip install coveralls tox + - pip install coveralls tox tox-travis script: - tox diff --git a/CHANGELOG.md b/CHANGELOG.md index a5140d1..56cabbf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,18 @@ -### 1.1.3 [2018-10-12] +### 1.3.0 [unreleased] + +* Fix a race condition in creation of AccessToken with external oauth2 server. +* **Backwards-incompatible** squashed migrations: + If you are currently on a release < 1.2.0, you will need to first install 1.2.x then `manage.py migrate` before + upgrading to >= 1.3.0. + +### 1.2.0 [2018-06-03] + +* **Compatibility**: Python 3.4 is the new minimum required version. +* **Compatibility**: Django 2.0 is the new minimum required version. +* **New feature**: Added TokenMatchesOASRequirements Permissions. +* validators.URIValidator has been updated to match URLValidator behaviour more closely. +* Moved `redirect_uris` validation to the application clean() method. -* Fix a concurrency issue with Refresh Tokens (#638) -* Fix Refresh Token revocation when the Access Token does not exist (#625) ### 1.1.2 [2018-05-12] diff --git a/README.rst b/README.rst index 2df9e2a..a647f54 100644 --- a/README.rst +++ b/README.rst @@ -42,8 +42,8 @@ Please report any security issues to the JazzBand security team at " -X POST -d"username=foo&password=bar" http://localhost:8000/users/ +Some time has passed and your access token is about to expire, you can get renew the access token issued using the `refresh token`: + +:: + + curl -X POST -d "grant_type=refresh_token&refresh_token=&client_id=&client_secret=" http://localhost:8000/o/token/ + +Your response should be similar to your first access_token request, containing a new access_token and refresh_token: + +.. code-block:: javascript + + { + "access_token": "", + "token_type": "Bearer", + "expires_in": 36000, + "refresh_token": "", + "scope": "read write groups" + } + + + Step 5: Testing Restricted Access --------------------------------- diff --git a/docs/rest-framework/openapi.yaml b/docs/rest-framework/openapi.yaml new file mode 100644 index 0000000..5c2e9a5 --- /dev/null +++ b/docs/rest-framework/openapi.yaml @@ -0,0 +1,49 @@ +openapi: "3.0.0" +info: + title: songs + version: v1 +components: + securitySchemes: + song_auth: + type: oauth2 + flows: + implicit: + authorizationUrl: http://localhost:8000/o/authorize + scopes: + read: read about a song + create: create a new song + update: update an existing song + delete: delete a song + post: create a new song + widget: widget scope + scope2: scope too + scope3: another scope +paths: + /songs: + get: + security: + - song_auth: [read] + responses: + '200': + description: A list of songs. + post: + security: + - song_auth: [create] + - song_auth: [post, widget] + responses: + '201': + description: new song added + put: + security: + - song_auth: [update] + - song_auth: [put, widget] + responses: + '204': + description: song updated + delete: + security: + - song_auth: [delete] + - song_auth: [scope2, scope3] + responses: + '200': + description: song deleted diff --git a/docs/rest-framework/permissions.rst b/docs/rest-framework/permissions.rst index b84c0a0..1058aed 100644 --- a/docs/rest-framework/permissions.rst +++ b/docs/rest-framework/permissions.rst @@ -48,6 +48,7 @@ For example: When a request is performed both the `READ_SCOPE` \\ `WRITE_SCOPE` and 'music' scopes are required to be authorized for the current access token. + TokenHasResourceScope ---------------------- The `TokenHasResourceScope` permission class allows access only when the current access token has been authorized for **all** the scopes listed in the `required_scopes` field of the view but according of request's method. @@ -81,3 +82,36 @@ For example: required_scopes = ['music'] The `required_scopes` attribute is mandatory. + + +TokenMatchesOASRequirements +------------------------------ + +The `TokenMatchesOASRequirements` permission class allows the access based on a per-method basis +and with alternative lists of required scopes. This permission provides full functionality +required by REST API specifications like the +`OpenAPI Specification (OAS) security requirement object `_. + +The `required_alternate_scopes` attribute is a required map keyed by HTTP method name where each value is +a list of alternative lists of required scopes. + +In the follow example GET requires "read" scope, POST requires either "create" scope **OR** "post" and "widget" scopes, +etc. + +.. code-block:: python + + class SongView(views.APIView): + authentication_classes = [OAuth2Authentication] + permission_classes = [TokenMatchesOASRequirements] + required_alternate_scopes = { + "GET": [["read"]], + "POST": [["create"], ["post", "widget"]], + "PUT": [["update"], ["put", "widget"]], + "DELETE": [["delete"], ["scope2", "scope3"]], + } + +The following is a minimal OAS declaration that shows the same required alternate scopes. It is complete enough +to try it in the `swagger editor `_. + +.. literalinclude:: openapi.yaml + :language: YAML \ No newline at end of file diff --git a/docs/settings.rst b/docs/settings.rst index 506d57d..49a0608 100644 --- a/docs/settings.rst +++ b/docs/settings.rst @@ -185,3 +185,10 @@ RESOURCE_SERVER_TOKEN_CACHING_SECONDS The number of seconds an authorization token received from the introspection endpoint remains valid. If the expire time of the received token is less than ``RESOURCE_SERVER_TOKEN_CACHING_SECONDS`` the expire time will be used. + + +PKCE_REQUIRED +~~~~~~~~~~~~~ +Default: ``False`` + +Whether or not PKCE is required. Can be either a bool or a callable that takes a client id and returns a bool. diff --git a/docs/tutorial/tutorial_01.rst b/docs/tutorial/tutorial_01.rst index 9f32c87..eaaab05 100644 --- a/docs/tutorial/tutorial_01.rst +++ b/docs/tutorial/tutorial_01.rst @@ -85,7 +85,8 @@ the API, subject to approval by its users. Let's register your application. -Point your browser to http://localhost:8000/o/applications/ and add an Application instance. +You need to be logged in before registration. So, go to http://localhost:8000/admin and log in. After that +point your browser to http://localhost:8000/o/applications/ and add an Application instance. `Client id` and `Client Secret` are automatically generated; you have to provide the rest of the informations: * `User`: the owner of the Application (e.g. a developer, or the currently logged in user.) diff --git a/oauth2_provider/apps.py b/oauth2_provider/apps.py index 79453ee..887e4e3 100644 --- a/oauth2_provider/apps.py +++ b/oauth2_provider/apps.py @@ -3,4 +3,4 @@ class DOTConfig(AppConfig): name = "oauth2_provider" - verbose_name = "Django/GeoNode OAuth Toolkit" + verbose_name = "Django OAuth Toolkit" diff --git a/oauth2_provider/compat.py b/oauth2_provider/compat.py index 6e455b0..0c83cb3 100644 --- a/oauth2_provider/compat.py +++ b/oauth2_provider/compat.py @@ -1,27 +1,4 @@ """ The `compat` module provides support for backwards compatibility with older -versions of django and python. +versions of Django and Python. """ -# flake8: noqa -from __future__ import unicode_literals - - -# urlparse in python3 has been renamed to urllib.parse -try: - from urlparse import parse_qs, parse_qsl, urlparse, urlsplit, urlunparse, urlunsplit -except ImportError: - from urllib.parse import parse_qs, parse_qsl, urlparse, urlsplit, urlunsplit, urlunparse - -try: - from urllib import urlencode, quote_plus, unquote_plus -except ImportError: - from urllib.parse import urlencode, quote_plus, unquote_plus - -# bastb Django 1.10 has updated Middleware. This code imports the Mixin required to get old-style -# middleware working again -# More? -# https://docs.djangoproject.com/en/1.10/topics/http/middleware/#upgrading-pre-django-1-10-style-middleware -try: - from django.utils.deprecation import MiddlewareMixin -except ImportError: - MiddlewareMixin = object diff --git a/oauth2_provider/contrib/rest_framework/__init__.py b/oauth2_provider/contrib/rest_framework/__init__.py index 4b82672..a004c18 100644 --- a/oauth2_provider/contrib/rest_framework/__init__.py +++ b/oauth2_provider/contrib/rest_framework/__init__.py @@ -1,4 +1,6 @@ # flake8: noqa from .authentication import OAuth2Authentication -from .permissions import TokenHasScope, TokenHasReadWriteScope, TokenHasResourceScope -from .permissions import IsAuthenticatedOrTokenHasScope +from .permissions import ( + TokenHasScope, TokenHasReadWriteScope, TokenMatchesOASRequirements, + TokenHasResourceScope, IsAuthenticatedOrTokenHasScope +) diff --git a/oauth2_provider/contrib/rest_framework/authentication.py b/oauth2_provider/contrib/rest_framework/authentication.py index 30a2d52..2283619 100644 --- a/oauth2_provider/contrib/rest_framework/authentication.py +++ b/oauth2_provider/contrib/rest_framework/authentication.py @@ -39,7 +39,8 @@ def authenticate_header(self, request): www_authenticate_attributes = OrderedDict([ ("realm", self.www_authenticate_realm,), ]) - www_authenticate_attributes.update(request.oauth2_error) + oauth2_error = getattr(request, "oauth2_error", {}) + www_authenticate_attributes.update(oauth2_error) return "Bearer {attributes}".format( attributes=self._dict_to_string(www_authenticate_attributes), ) diff --git a/oauth2_provider/contrib/rest_framework/permissions.py b/oauth2_provider/contrib/rest_framework/permissions.py index 00a1ca0..7ba1c5c 100644 --- a/oauth2_provider/contrib/rest_framework/permissions.py +++ b/oauth2_provider/contrib/rest_framework/permissions.py @@ -2,7 +2,9 @@ from django.core.exceptions import ImproperlyConfigured from rest_framework.exceptions import PermissionDenied -from rest_framework.permissions import BasePermission, IsAuthenticated, SAFE_METHODS +from rest_framework.permissions import ( + SAFE_METHODS, BasePermission, IsAuthenticated +) from ...settings import oauth2_settings from .authentication import OAuth2Authentication @@ -65,7 +67,7 @@ class TokenHasReadWriteScope(TokenHasScope): def get_scopes(self, request, view): try: - required_scopes = super(TokenHasReadWriteScope, self).get_scopes(request, view) + required_scopes = super().get_scopes(request, view) except ImproperlyConfigured: required_scopes = [] @@ -85,9 +87,7 @@ class TokenHasResourceScope(TokenHasScope): def get_scopes(self, request, view): try: - view_scopes = ( - super(TokenHasResourceScope, self).get_scopes(request, view) - ) + view_scopes = super().get_scopes(request, view) except ImproperlyConfigured: view_scopes = [] @@ -121,3 +121,58 @@ def has_permission(self, request, view): token_has_scope = TokenHasScope() return (is_authenticated and not oauth2authenticated) or token_has_scope.has_permission(request, view) + + +class TokenMatchesOASRequirements(BasePermission): + """ + :attr:alternate_required_scopes: dict keyed by HTTP method name with value: iterable alternate scope lists + + This fulfills the [Open API Specification (OAS; formerly Swagger)](https://www.openapis.org/) + list of alternative Security Requirements Objects for oauth2 or openIdConnect: + When a list of Security Requirement Objects is defined on the Open API object or Operation Object, + only one of Security Requirement Objects in the list needs to be satisfied to authorize the request. + [1](https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.0.md#securityRequirementObject) + + For each method, a list of lists of allowed scopes is tried in order and the first to match succeeds. + + @example + required_alternate_scopes = { + 'GET': [['read']], + 'POST': [['create1','scope2'], ['alt-scope3'], ['alt-scope4','alt-scope5']], + } + + TODO: DRY: subclass TokenHasScope and iterate over values of required_scope? + """ + + def has_permission(self, request, view): + token = request.auth + + if not token: + return False + + if hasattr(token, "scope"): # OAuth 2 + required_alternate_scopes = self.get_required_alternate_scopes(request, view) + + m = request.method.upper() + if m in required_alternate_scopes: + log.debug("Required scopes alternatives to access resource: {0}" + .format(required_alternate_scopes[m])) + for alt in required_alternate_scopes[m]: + if token.is_valid(alt): + return True + return False + else: + log.warning("no scope alternates defined for method {0}".format(m)) + return False + + assert False, ("TokenMatchesOASRequirements requires the" + "`oauth2_provider.rest_framework.OAuth2Authentication` authentication " + "class to be used.") + + def get_required_alternate_scopes(self, request, view): + try: + return getattr(view, "required_alternate_scopes") + except AttributeError: + raise ImproperlyConfigured( + "TokenMatchesOASRequirements requires the view to" + " define the required_alternate_scopes attribute") diff --git a/oauth2_provider/exceptions.py b/oauth2_provider/exceptions.py index 6c81d5d..2155155 100644 --- a/oauth2_provider/exceptions.py +++ b/oauth2_provider/exceptions.py @@ -3,7 +3,7 @@ class OAuthToolkitError(Exception): Base class for exceptions """ def __init__(self, error=None, redirect_uri=None, *args, **kwargs): - super(OAuthToolkitError, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.oauthlib_error = error if redirect_uri: diff --git a/oauth2_provider/forms.py b/oauth2_provider/forms.py index c75cd52..41129c4 100644 --- a/oauth2_provider/forms.py +++ b/oauth2_provider/forms.py @@ -9,3 +9,5 @@ class AllowForm(forms.Form): client_id = forms.CharField(widget=forms.HiddenInput()) state = forms.CharField(required=False, widget=forms.HiddenInput()) response_type = forms.CharField(widget=forms.HiddenInput()) + code_challenge = forms.CharField(required=False, widget=forms.HiddenInput()) + code_challenge_method = forms.CharField(required=False, widget=forms.HiddenInput()) diff --git a/oauth2_provider/generators.py b/oauth2_provider/generators.py index 6e81249..a548088 100644 --- a/oauth2_provider/generators.py +++ b/oauth2_provider/generators.py @@ -1,7 +1,5 @@ -from __future__ import absolute_import, unicode_literals - -from oauthlib.common import generate_client_id as oauthlib_generate_client_id from oauthlib.common import UNICODE_ASCII_CHARACTER_SET +from oauthlib.common import generate_client_id as oauthlib_generate_client_id from .settings import oauth2_settings diff --git a/oauth2_provider/http.py b/oauth2_provider/http.py index 781f2f8..a44b82c 100644 --- a/oauth2_provider/http.py +++ b/oauth2_provider/http.py @@ -1,9 +1,12 @@ +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse + from django.core.exceptions import DisallowedRedirect from django.http import HttpResponse from django.utils.encoding import iri_to_uri -from .compat import urlparse - class OAuth2ResponseRedirect(HttpResponse): """ @@ -31,7 +34,3 @@ def validate_redirect(self, redirect_to): raise DisallowedRedirect( "Redirect to scheme {!r} is not permitted".format(parsed.scheme) ) - - -# Backwards compatibility (as of 1.0.0) -HttpResponseUriRedirect = OAuth2ResponseRedirect diff --git a/oauth2_provider/management/commands/createapplication.py b/oauth2_provider/management/commands/createapplication.py new file mode 100644 index 0000000..e63d542 --- /dev/null +++ b/oauth2_provider/management/commands/createapplication.py @@ -0,0 +1,85 @@ +from django.core.exceptions import ValidationError +from django.core.management.base import BaseCommand + +from oauth2_provider.models import get_application_model + +Application = get_application_model() + + +class Command(BaseCommand): + help = "Shortcut to create a new application in a programmatic way" + + def add_arguments(self, parser): + parser.add_argument( + 'client_type', + type=str, + help='The client type, can be confidential or public', + ) + parser.add_argument( + 'authorization_grant_type', + type=str, + help='The type of authorization grant to be used', + ) + parser.add_argument( + '--client-id', + type=str, + help='The ID of the new application', + ) + parser.add_argument( + '--user', + type=str, + help='The user the application belongs to', + ) + parser.add_argument( + '--redirect-uris', + type=str, + help='The redirect URIs, this must be a space separated string e.g "URI1 URI2', + ) + parser.add_argument( + '--client-secret', + type=str, + help='The secret for this application', + ) + parser.add_argument( + '--name', + type=str, + help='The name this application', + ) + parser.add_argument( + '--skip-authorization', + action='store_true', + help='The ID of the new application', + ) + + def handle(self, *args, **options): + # Extract all fields related to the application, this will work now and in the future + # and also with custom application models. + application_fields = [field.name for field in Application._meta.fields] + application_data = {} + for key, value in options.items(): + # Data in options must be cleaned because there are unneded key-value like + # verbosity and others. Also do not pass any None to the Application + # instance so default values will be generated for those fields + if key in application_fields and value: + if key == 'user': + application_data.update({'user_id': value}) + else: + application_data.update({key: value}) + + new_application = Application(**application_data) + + try: + new_application.full_clean() + except ValidationError as exc: + errors = "\n ".join(['- ' + err_key + ': ' + str(err_value) for err_key, + err_value in exc.message_dict.items()]) + self.stdout.write( + self.style.ERROR( + 'Please correct the following errors:\n %s' % errors + ) + ) + else: + new_application.save() + self.stdout.write( + self.style.SUCCESS('New application created successfully') + ) diff --git a/oauth2_provider/middleware.py b/oauth2_provider/middleware.py index f41f3f3..b94cb71 100644 --- a/oauth2_provider/middleware.py +++ b/oauth2_provider/middleware.py @@ -1,7 +1,6 @@ from django.contrib.auth import authenticate from django.utils.cache import patch_vary_headers - -from .compat import MiddlewareMixin +from django.utils.deprecation import MiddlewareMixin class OAuth2TokenMiddleware(MiddlewareMixin): @@ -23,6 +22,7 @@ class OAuth2TokenMiddleware(MiddlewareMixin): It also adds "Authorization" to the "Vary" header, so that django's cache middleware or a reverse proxy can create proper cache keys. """ + def process_request(self, request): # do something only if request contains a Bearer token if request.META.get("HTTP_AUTHORIZATION", "").startswith("Bearer"): diff --git a/oauth2_provider/migrations/0001_initial.py b/oauth2_provider/migrations/0001_initial.py index f415cb6..1d1a38e 100644 --- a/oauth2_provider/migrations/0001_initial.py +++ b/oauth2_provider/migrations/0001_initial.py @@ -1,15 +1,22 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals +from django.conf import settings +import django.db.models.deletion +from django.db import migrations, models -from oauth2_provider.settings import oauth2_settings -from django.db import models, migrations -import oauth2_provider.validators import oauth2_provider.generators -from django.conf import settings +import oauth2_provider.validators +from oauth2_provider.settings import oauth2_settings class Migration(migrations.Migration): - + """ + The following migrations are squashed here: + - 0001_initial.py + - 0002_08_updates.py + - 0003_auto_20160316_1503.py + - 0004_auto_20160525_1623.py + - 0005_auto_20170514_1141.py + - 0006_auto_20171214_2232.py + """ dependencies = [ migrations.swappable_dependency(settings.AUTH_USER_MODEL) ] @@ -18,14 +25,17 @@ class Migration(migrations.Migration): migrations.CreateModel( name='Application', fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), + ('id', models.BigAutoField(serialize=False, primary_key=True)), ('client_id', models.CharField(default=oauth2_provider.generators.generate_client_id, unique=True, max_length=100, db_index=True)), - ('redirect_uris', models.TextField(help_text='Allowed URIs list, space separated', blank=True, validators=[oauth2_provider.validators.validate_uris])), + ('redirect_uris', models.TextField(help_text='Allowed URIs list, space separated', blank=True)), ('client_type', models.CharField(max_length=32, choices=[('confidential', 'Confidential'), ('public', 'Public')])), ('authorization_grant_type', models.CharField(max_length=32, choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials')])), ('client_secret', models.CharField(default=oauth2_provider.generators.generate_client_secret, max_length=255, db_index=True, blank=True)), ('name', models.CharField(max_length=255, blank=True)), - ('user', models.ForeignKey(to=settings.AUTH_USER_MODEL, on_delete=models.CASCADE)), + ('user', models.ForeignKey(related_name="oauth2_provider_application", blank=True, to=settings.AUTH_USER_MODEL, null=True, on_delete=models.CASCADE)), + ('skip_authorization', models.BooleanField(default=False)), + ('created', models.DateTimeField(auto_now_add=True)), + ('updated', models.DateTimeField(auto_now=True)), ], options={ 'abstract': False, @@ -35,12 +45,16 @@ class Migration(migrations.Migration): migrations.CreateModel( name='AccessToken', fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('token', models.CharField(max_length=255, db_index=True)), + ('id', models.BigAutoField(serialize=False, primary_key=True)), + ('token', models.CharField(unique=True, max_length=255)), ('expires', models.DateTimeField()), ('scope', models.TextField(blank=True)), - ('application', models.ForeignKey(to=oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE)), - ('user', models.ForeignKey(to=settings.AUTH_USER_MODEL, on_delete=models.CASCADE)), + ('application', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=oauth2_settings.APPLICATION_MODEL)), + ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='oauth2_provider_accesstoken', to=settings.AUTH_USER_MODEL)), + ('created', models.DateTimeField(auto_now_add=True)), + ('updated', models.DateTimeField(auto_now=True)), + # Circular reference. Can't add it here. + #('source_refresh_token', models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to=oauth2_settings.REFRESH_TOKEN_MODEL, related_name="refreshed_access_token")), ], options={ 'abstract': False, @@ -50,13 +64,15 @@ class Migration(migrations.Migration): migrations.CreateModel( name='Grant', fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('code', models.CharField(max_length=255, db_index=True)), + ('id', models.BigAutoField(serialize=False, primary_key=True)), + ('code', models.CharField(unique=True, max_length=255)), ('expires', models.DateTimeField()), ('redirect_uri', models.CharField(max_length=255)), ('scope', models.TextField(blank=True)), ('application', models.ForeignKey(to=oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE)), - ('user', models.ForeignKey(to=settings.AUTH_USER_MODEL, on_delete=models.CASCADE)), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='oauth2_provider_grant', to=settings.AUTH_USER_MODEL)), + ('created', models.DateTimeField(auto_now_add=True)), + ('updated', models.DateTimeField(auto_now=True)), ], options={ 'abstract': False, @@ -66,15 +82,24 @@ class Migration(migrations.Migration): migrations.CreateModel( name='RefreshToken', fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('token', models.CharField(max_length=255, db_index=True)), - ('access_token', models.OneToOneField(related_name='refresh_token', to=oauth2_settings.ACCESS_TOKEN_MODEL, on_delete=models.CASCADE)), + ('id', models.BigAutoField(serialize=False, primary_key=True)), + ('token', models.CharField(max_length=255)), + ('access_token', models.OneToOneField(blank=True, null=True, related_name="refresh_token", to=oauth2_settings.ACCESS_TOKEN_MODEL, on_delete=models.SET_NULL)), ('application', models.ForeignKey(to=oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE)), - ('user', models.ForeignKey(to=settings.AUTH_USER_MODEL, on_delete=models.CASCADE)), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='oauth2_provider_refreshtoken', to=settings.AUTH_USER_MODEL)), + ('created', models.DateTimeField(auto_now_add=True)), + ('updated', models.DateTimeField(auto_now=True)), + ('revoked', models.DateTimeField(null=True)), ], options={ 'abstract': False, 'swappable': 'OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL', + 'unique_together': set([("token", "revoked")]), }, ), + migrations.AddField( + model_name='AccessToken', + name='source_refresh_token', + field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to=oauth2_settings.REFRESH_TOKEN_MODEL, related_name="refreshed_access_token"), + ), ] diff --git a/oauth2_provider/migrations/0002_08_updates.py b/oauth2_provider/migrations/0002_08_updates.py deleted file mode 100644 index 01e1a4a..0000000 --- a/oauth2_provider/migrations/0002_08_updates.py +++ /dev/null @@ -1,36 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals - -from oauth2_provider.settings import oauth2_settings -from django.db import models, migrations -import oauth2_provider.validators -import oauth2_provider.generators -from django.conf import settings - - -class Migration(migrations.Migration): - - dependencies = [ - ('oauth2_provider', '0001_initial'), - ] - - operations = [ - migrations.AddField( - model_name='Application', - name='skip_authorization', - field=models.BooleanField(default=False), - preserve_default=True, - ), - migrations.AlterField( - model_name='Application', - name='user', - field=models.ForeignKey(related_name='oauth2_provider_application', to=settings.AUTH_USER_MODEL, on_delete=models.CASCADE), - preserve_default=True, - ), - migrations.AlterField( - model_name='AccessToken', - name='user', - field=models.ForeignKey(blank=True, to=settings.AUTH_USER_MODEL, null=True, on_delete=models.CASCADE), - preserve_default=True, - ), - ] diff --git a/oauth2_provider/migrations/0002_auto_20190406_1805.py b/oauth2_provider/migrations/0002_auto_20190406_1805.py new file mode 100644 index 0000000..8ca177a --- /dev/null +++ b/oauth2_provider/migrations/0002_auto_20190406_1805.py @@ -0,0 +1,23 @@ +# Generated by Django 2.2 on 2019-04-06 18:05 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('oauth2_provider', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='grant', + name='code_challenge', + field=models.CharField(blank=True, default='', max_length=128), + ), + migrations.AddField( + model_name='grant', + name='code_challenge_method', + field=models.CharField(blank=True, choices=[('plain', 'plain'), ('S256', 'S256')], default='', max_length=10), + ), + ] diff --git a/oauth2_provider/migrations/0003_auto_20160316_1503.py b/oauth2_provider/migrations/0003_auto_20160316_1503.py deleted file mode 100644 index 49cfb4b..0000000 --- a/oauth2_provider/migrations/0003_auto_20160316_1503.py +++ /dev/null @@ -1,20 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals - -from django.db import migrations, models -from django.conf import settings - - -class Migration(migrations.Migration): - - dependencies = [ - ('oauth2_provider', '0002_08_updates'), - ] - - operations = [ - migrations.AlterField( - model_name='application', - name='user', - field=models.ForeignKey(related_name='oauth2_provider_application', blank=True, to=settings.AUTH_USER_MODEL, null=True, on_delete=models.CASCADE), - ), - ] diff --git a/oauth2_provider/migrations/0006_auto_20170903_1632.py b/oauth2_provider/migrations/0003_auto_20190413_2007.py similarity index 60% rename from oauth2_provider/migrations/0006_auto_20170903_1632.py rename to oauth2_provider/migrations/0003_auto_20190413_2007.py index dc2d7cb..4728861 100644 --- a/oauth2_provider/migrations/0006_auto_20170903_1632.py +++ b/oauth2_provider/migrations/0003_auto_20190413_2007.py @@ -1,6 +1,4 @@ -# -*- coding: utf-8 -*- -# Generated by Django 1.11.4 on 2017-09-03 16:32 -from __future__ import unicode_literals +# Generated by Django 2.2 on 2019-04-13 20:07 from django.db import migrations, models @@ -8,10 +6,15 @@ class Migration(migrations.Migration): dependencies = [ - ('oauth2_provider', '0005_auto_20170514_1141'), + ('oauth2_provider', '0002_auto_20190406_1805'), ] operations = [ + migrations.AddField( + model_name='application', + name='algorithm', + field=models.CharField(choices=[('RS256', 'RSA with SHA-2 256'), ('HS256', 'HMAC with SHA-2 256')], default='RS256', max_length=5), + ), migrations.AlterField( model_name='application', name='authorization_grant_type', diff --git a/oauth2_provider/migrations/0004_auto_20160525_1623.py b/oauth2_provider/migrations/0004_auto_20160525_1623.py deleted file mode 100644 index 5ada5db..0000000 --- a/oauth2_provider/migrations/0004_auto_20160525_1623.py +++ /dev/null @@ -1,29 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ('oauth2_provider', '0003_auto_20160316_1503'), - ] - - operations = [ - migrations.AlterField( - model_name='accesstoken', - name='token', - field=models.CharField(unique=True, max_length=255), - ), - migrations.AlterField( - model_name='grant', - name='code', - field=models.CharField(unique=True, max_length=255), - ), - migrations.AlterField( - model_name='refreshtoken', - name='token', - field=models.CharField(unique=True, max_length=255), - ), - ] diff --git a/oauth2_provider/migrations/0008_idtoken.py b/oauth2_provider/migrations/0004_idtoken.py similarity index 74% rename from oauth2_provider/migrations/0008_idtoken.py rename to oauth2_provider/migrations/0004_idtoken.py index 3f0ae10..e0d43b2 100644 --- a/oauth2_provider/migrations/0008_idtoken.py +++ b/oauth2_provider/migrations/0004_idtoken.py @@ -1,20 +1,15 @@ -# -*- coding: utf-8 -*- -# Generated by Django 1.11.4 on 2017-10-01 19:13 -from __future__ import unicode_literals +# Generated by Django 2.2 on 2019-04-16 14:36 from django.conf import settings from django.db import migrations, models import django.db.models.deletion -from oauth2_provider.settings import oauth2_settings - class Migration(migrations.Migration): dependencies = [ migrations.swappable_dependency(settings.AUTH_USER_MODEL), - migrations.swappable_dependency(oauth2_settings.APPLICATION_MODEL), - ('oauth2_provider', '0007_application_algorithm'), + ('oauth2_provider', '0003_auto_20190413_2007'), ] operations = [ @@ -27,7 +22,7 @@ class Migration(migrations.Migration): ('scope', models.TextField(blank=True)), ('created', models.DateTimeField(auto_now_add=True)), ('updated', models.DateTimeField(auto_now=True)), - ('application', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=oauth2_settings.APPLICATION_MODEL)), + ('application', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=settings.OAUTH2_PROVIDER_APPLICATION_MODEL)), ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='oauth2_provider_idtoken', to=settings.AUTH_USER_MODEL)), ], options={ diff --git a/oauth2_provider/migrations/0005_accesstoken_id_token.py b/oauth2_provider/migrations/0005_accesstoken_id_token.py new file mode 100644 index 0000000..a6ca7dd --- /dev/null +++ b/oauth2_provider/migrations/0005_accesstoken_id_token.py @@ -0,0 +1,20 @@ +# Generated by Django 2.2 on 2019-04-16 14:39 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('oauth2_provider', '0004_idtoken'), + ] + + operations = [ + migrations.AddField( + model_name='accesstoken', + name='id_token', + field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='access_token', to=settings.OAUTH2_PROVIDER_ID_TOKEN_MODEL), + ), + ] diff --git a/oauth2_provider/migrations/0005_auto_20170514_1141.py b/oauth2_provider/migrations/0005_auto_20170514_1141.py deleted file mode 100644 index 4eca6c8..0000000 --- a/oauth2_provider/migrations/0005_auto_20170514_1141.py +++ /dev/null @@ -1,102 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by Django 1.11.1 on 2017-05-14 11:41 -from __future__ import unicode_literals - -from oauth2_provider.settings import oauth2_settings -from django.conf import settings -from django.db import migrations, models -import django.db.models.deletion - - -class Migration(migrations.Migration): - - dependencies = [ - ('oauth2_provider', '0004_auto_20160525_1623'), - ] - - operations = [ - migrations.AlterField( - model_name='accesstoken', - name='application', - field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=oauth2_settings.APPLICATION_MODEL), - ), - migrations.AlterField( - model_name='accesstoken', - name='id', - field=models.BigAutoField(primary_key=True, serialize=False), - ), - migrations.AlterField( - model_name='accesstoken', - name='user', - field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='oauth2_provider_accesstoken', to=settings.AUTH_USER_MODEL), - ), - migrations.AlterField( - model_name='application', - name='id', - field=models.BigAutoField(primary_key=True, serialize=False), - ), - migrations.AlterField( - model_name='grant', - name='id', - field=models.BigAutoField(primary_key=True, serialize=False), - ), - migrations.AlterField( - model_name='grant', - name='user', - field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='oauth2_provider_grant', to=settings.AUTH_USER_MODEL), - ), - migrations.AlterField( - model_name='refreshtoken', - name='id', - field=models.BigAutoField(primary_key=True, serialize=False), - ), - migrations.AlterField( - model_name='refreshtoken', - name='user', - field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='oauth2_provider_refreshtoken', to=settings.AUTH_USER_MODEL), - ), - migrations.AddField( - model_name='accesstoken', - name='created', - field=models.DateTimeField(auto_now_add=True, default=django.utils.timezone.now), - preserve_default=False, - ), - migrations.AddField( - model_name='accesstoken', - name='updated', - field=models.DateTimeField(auto_now=True), - ), - migrations.AddField( - model_name='application', - name='created', - field=models.DateTimeField(auto_now_add=True, default=django.utils.timezone.now), - preserve_default=False, - ), - migrations.AddField( - model_name='application', - name='updated', - field=models.DateTimeField(auto_now=True), - ), - migrations.AddField( - model_name='grant', - name='created', - field=models.DateTimeField(auto_now_add=True, default=django.utils.timezone.now), - preserve_default=False, - ), - migrations.AddField( - model_name='grant', - name='updated', - field=models.DateTimeField(auto_now=True), - ), - migrations.AddField( - model_name='refreshtoken', - name='created', - field=models.DateTimeField(auto_now_add=True, default=django.utils.timezone.now), - preserve_default=False, - ), - migrations.AddField( - model_name='refreshtoken', - name='updated', - field=models.DateTimeField(auto_now=True), - ), - ] diff --git a/oauth2_provider/migrations/0006_auto_20171214_2232.py b/oauth2_provider/migrations/0006_auto_20171214_2232.py deleted file mode 100644 index b264671..0000000 --- a/oauth2_provider/migrations/0006_auto_20171214_2232.py +++ /dev/null @@ -1,44 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by Django 1.11.1 on 2017-05-14 11:41 -from __future__ import unicode_literals - -from oauth2_provider.settings import oauth2_settings -from django.conf import settings -from django.db import migrations, models -import django.db.models.deletion - - -class Migration(migrations.Migration): - - dependencies = [ - ('oauth2_provider', '0005_auto_20170514_1141'), - ] - - operations = [ - migrations.AddField( - model_name='accesstoken', - name='source_refresh_token', - field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to=oauth2_settings.REFRESH_TOKEN_MODEL, related_name='refreshed_access_token'), - preserve_default=False, - ), - migrations.AddField( - model_name='refreshtoken', - name='revoked', - field=models.DateTimeField(null=True, default=None), - preserve_default=False, - ), - migrations.AlterField( - model_name='refreshtoken', - name='token', - field=models.CharField(max_length=255), - ), - migrations.AlterField( - model_name='refreshtoken', - name='access_token', - field=models.OneToOneField(blank=True, null=True, related_name='refresh_token', to=oauth2_settings.ACCESS_TOKEN_MODEL, on_delete=models.SET_NULL), - ), - migrations.AlterUniqueTogether( - name='refreshtoken', - unique_together=set([('token', 'revoked')]), - ), - ] diff --git a/oauth2_provider/migrations/0007_application_algorithm.py b/oauth2_provider/migrations/0007_application_algorithm.py deleted file mode 100644 index 319d99e..0000000 --- a/oauth2_provider/migrations/0007_application_algorithm.py +++ /dev/null @@ -1,20 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by Django 1.11.4 on 2017-09-16 18:55 -from __future__ import unicode_literals - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ('oauth2_provider', '0006_auto_20170903_1632'), - ] - - operations = [ - migrations.AddField( - model_name='application', - name='algorithm', - field=models.CharField(choices=[('RS256', 'RSA with SHA-2 256'), ('HS256', 'HMAC with SHA-2 256')], default='RS256', max_length=5), - ), - ] diff --git a/oauth2_provider/migrations/0009_merge_20180606_1314.py b/oauth2_provider/migrations/0009_merge_20180606_1314.py deleted file mode 100644 index 91c0613..0000000 --- a/oauth2_provider/migrations/0009_merge_20180606_1314.py +++ /dev/null @@ -1,16 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by Django 1.11.11 on 2018-06-06 13:14 -from __future__ import unicode_literals - -from django.db import migrations - - -class Migration(migrations.Migration): - - dependencies = [ - ('oauth2_provider', '0006_auto_20171214_2232'), - ('oauth2_provider', '0008_idtoken'), - ] - - operations = [ - ] diff --git a/oauth2_provider/migrations/0010_auto_20190419_1604.py b/oauth2_provider/migrations/0010_auto_20190419_1604.py deleted file mode 100644 index c7c23bb..0000000 --- a/oauth2_provider/migrations/0010_auto_20190419_1604.py +++ /dev/null @@ -1,40 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by Django 1.11.20 on 2019-04-19 16:04 -from __future__ import unicode_literals - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ('oauth2_provider', '0009_merge_20180606_1314'), - ] - - operations = [ - migrations.AlterField( - model_name='accesstoken', - name='id', - field=models.AutoField(primary_key=True, serialize=False), - ), - migrations.AlterField( - model_name='application', - name='id', - field=models.AutoField(primary_key=True, serialize=False), - ), - migrations.AlterField( - model_name='grant', - name='id', - field=models.AutoField(primary_key=True, serialize=False), - ), - migrations.AlterField( - model_name='idtoken', - name='id', - field=models.AutoField(primary_key=True, serialize=False), - ), - migrations.AlterField( - model_name='refreshtoken', - name='id', - field=models.AutoField(primary_key=True, serialize=False), - ), - ] diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index f6186c7..499f4c4 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -1,6 +1,12 @@ -from __future__ import unicode_literals - +import json from datetime import timedelta +try: + from urllib.parse import urlparse, parse_qsl +except ImportError: + from urlparse import urlparse, parse_qsl +import logging + +from jwcrypto import jwk, jwt from django.apps import apps from django.conf import settings @@ -8,17 +14,16 @@ from django.db import models, transaction from django.urls import reverse from django.utils import timezone -from django.utils.encoding import python_2_unicode_compatible from django.utils.translation import ugettext_lazy as _ -from .compat import parse_qsl, urlparse from .generators import generate_client_id, generate_client_secret from .scopes import get_scopes_backend from .settings import oauth2_settings -from .validators import validate_uris +from .validators import RedirectURIValidator, WildcardSet + +logger = logging.getLogger(__name__) -@python_2_unicode_compatible class AbstractApplication(models.Model): """ An Application instance represents a Client on the Authorization server. @@ -66,7 +71,7 @@ class AbstractApplication(models.Model): (HS256_ALGORITHM, _("HMAC with SHA-2 256")), ) - id = models.AutoField(primary_key=True) + id = models.BigAutoField(primary_key=True) client_id = models.CharField( max_length=100, unique=True, default=generate_client_id, db_index=True ) @@ -76,9 +81,8 @@ class AbstractApplication(models.Model): null=True, blank=True, on_delete=models.CASCADE ) - help_text = _("Allowed URIs list, space separated") redirect_uris = models.TextField( - blank=True, help_text=help_text, validators=[validate_uris] + blank=True, help_text=_("Allowed URIs list, space separated"), ) client_type = models.CharField(max_length=32, choices=CLIENT_TYPES) authorization_grant_type = models.CharField( @@ -139,12 +143,29 @@ def redirect_uri_allowed(self, uri): def clean(self): from django.core.exceptions import ValidationError - if not self.redirect_uris \ - and self.authorization_grant_type \ - in (AbstractApplication.GRANT_AUTHORIZATION_CODE, - AbstractApplication.GRANT_IMPLICIT): - error = _("Redirect_uris could not be empty with {grant_type} grant_type") - raise ValidationError(error.format(grant_type=self.authorization_grant_type)) + + grant_types = ( + AbstractApplication.GRANT_AUTHORIZATION_CODE, + AbstractApplication.GRANT_IMPLICIT, + ) + + redirect_uris = self.redirect_uris.strip().split() + allowed_schemes = set(s.lower() for s in self.get_allowed_schemes()) + + if redirect_uris: + validator = RedirectURIValidator(WildcardSet()) + for uri in redirect_uris: + validator(uri) + scheme = urlparse(uri).scheme + if scheme not in allowed_schemes: + raise ValidationError(_( + "Unauthorized redirect scheme: {scheme}" + ).format(scheme=scheme)) + + elif self.authorization_grant_type in grant_types: + raise ValidationError(_( + "redirect_uris cannot be empty with grant_type {grant_type}" + ).format(grant_type=self.authorization_grant_type)) def get_absolute_url(self): return reverse("oauth2_provider:detail", args=[str(self.id)]) @@ -183,7 +204,6 @@ def natural_key(self): return (self.client_id,) -@python_2_unicode_compatible class AbstractGrant(models.Model): """ A Grant instance represents a token with a short lifetime that can @@ -198,8 +218,17 @@ class AbstractGrant(models.Model): :data:`settings.AUTHORIZATION_CODE_EXPIRE_SECONDS` * :attr:`redirect_uri` Self explained * :attr:`scope` Required scopes, optional + * :attr:`code_challenge` PKCE code challenge + * :attr:`code_challenge_method` PKCE code challenge transform algorithm """ - id = models.AutoField(primary_key=True) + CODE_CHALLENGE_PLAIN = "plain" + CODE_CHALLENGE_S256 = "S256" + CODE_CHALLENGE_METHODS = ( + (CODE_CHALLENGE_PLAIN, "plain"), + (CODE_CHALLENGE_S256, "S256") + ) + + id = models.BigAutoField(primary_key=True) user = models.ForeignKey( settings.AUTH_USER_MODEL, on_delete=models.CASCADE, related_name="%(app_label)s_%(class)s" @@ -215,6 +244,10 @@ class AbstractGrant(models.Model): created = models.DateTimeField(auto_now_add=True) updated = models.DateTimeField(auto_now=True) + code_challenge = models.CharField(max_length=128, blank=True, default="") + code_challenge_method = models.CharField( + max_length=10, blank=True, default="", choices=CODE_CHALLENGE_METHODS) + def is_expired(self): """ Check token expiration with timezone awareness @@ -239,7 +272,6 @@ class Meta(AbstractGrant.Meta): swappable = "OAUTH2_PROVIDER_GRANT_MODEL" -@python_2_unicode_compatible class AbstractAccessToken(models.Model): """ An AccessToken instance represents the actual access token to @@ -254,7 +286,7 @@ class AbstractAccessToken(models.Model): * :attr:`expires` Date and time of token expiration, in DateTime format * :attr:`scope` Allowed scopes """ - id = models.AutoField(primary_key=True) + id = models.BigAutoField(primary_key=True) user = models.ForeignKey( settings.AUTH_USER_MODEL, on_delete=models.CASCADE, blank=True, null=True, related_name="%(app_label)s_%(class)s" @@ -265,6 +297,10 @@ class AbstractAccessToken(models.Model): related_name="refreshed_access_token" ) token = models.CharField(max_length=255, unique=True, ) + id_token = models.OneToOneField( + oauth2_settings.ID_TOKEN_MODEL, on_delete=models.CASCADE, blank=True, null=True, + related_name="access_token" + ) application = models.ForeignKey( oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE, blank=True, null=True, ) @@ -333,7 +369,6 @@ class Meta(AbstractAccessToken.Meta): swappable = "OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL" -@python_2_unicode_compatible class AbstractRefreshToken(models.Model): """ A RefreshToken instance represents a token that can be swapped for a new @@ -348,7 +383,7 @@ class AbstractRefreshToken(models.Model): bounded to * :attr:`revoked` Timestamp of when this refresh token was revoked """ - id = models.AutoField(primary_key=True) + id = models.BigAutoField(primary_key=True) user = models.ForeignKey( settings.AUTH_USER_MODEL, on_delete=models.CASCADE, related_name="%(app_label)s_%(class)s" @@ -378,7 +413,10 @@ def revoke(self): if not self: return - access_token_model.objects.get(id=self.access_token_id).revoke() + try: + access_token_model.objects.get(id=self.access_token_id).revoke() + except access_token_model.DoesNotExist: + pass self.access_token = None self.revoked = timezone.now() self.save() @@ -396,7 +434,6 @@ class Meta(AbstractRefreshToken.Meta): swappable = "OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL" -@python_2_unicode_compatible class AbstractIDToken(models.Model): """ An IDToken instance represents the actual token to @@ -410,7 +447,7 @@ class AbstractIDToken(models.Model): * :attr:`expires` Date and time of token expiration, in DateTime format * :attr:`scope` Allowed scopes """ - id = models.AutoField(primary_key=True) + id = models.BigAutoField(primary_key=True) user = models.ForeignKey( settings.AUTH_USER_MODEL, on_delete=models.CASCADE, blank=True, null=True, related_name="%(app_label)s_%(class)s" @@ -472,6 +509,12 @@ def scopes(self): token_scopes = self.scope.split() return {name: desc for name, desc in all_scopes.items() if name in token_scopes} + @property + def claims(self): + key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) + jwt_token = jwt.JWT(key=key, jwt=self.token) + return json.loads(jwt_token.claims) + def __str__(self): return self.token @@ -527,7 +570,30 @@ def clear_expired(): with transaction.atomic(): if refresh_expire_at: - refresh_token_model.objects.filter(revoked__lt=refresh_expire_at).delete() - refresh_token_model.objects.filter(access_token__expires__lt=refresh_expire_at).delete() - access_token_model.objects.filter(refresh_token__isnull=True, expires__lt=now).delete() - grant_model.objects.filter(expires__lt=now).delete() + revoked = refresh_token_model.objects.filter( + revoked__lt=refresh_expire_at, + ) + expired = refresh_token_model.objects.filter( + access_token__expires__lt=refresh_expire_at, + ) + + logger.info('%s Revoked refresh tokens to be deleted', revoked.count()) + logger.info('%s Expired refresh tokens to be deleted', expired.count()) + + revoked.delete() + expired.delete() + else: + logger.info('refresh_expire_at is %s. No refresh tokens deleted.', + refresh_expire_at) + + access_tokens = access_token_model.objects.filter( + refresh_token__isnull=True, + expires__lt=now + ) + grants = grant_model.objects.filter(expires__lt=now) + + logger.info('%s Expired access tokens to be deleted', access_tokens.count()) + logger.info('%s Expired grant tokens to be deleted', grants.count()) + + access_tokens.delete() + grants.delete() diff --git a/oauth2_provider/oauth2_backends.py b/oauth2_provider/oauth2_backends.py index bf03a67..24f5459 100644 --- a/oauth2_provider/oauth2_backends.py +++ b/oauth2_provider/oauth2_backends.py @@ -1,11 +1,13 @@ -from __future__ import unicode_literals - import json +try: + from urllib.parse import urlparse, urlunparse +except ImportError: + from urlparse import urlparse, urlunparse + from oauthlib import oauth2 from oauthlib.common import quote, urlencode, urlencoded -from .compat import urlparse, urlunparse from .exceptions import FatalClientError, OAuthToolkitError from .settings import oauth2_settings @@ -87,7 +89,6 @@ def validate_authorization_request(self, request): """ try: uri, http_method, body, headers = self._extract_params(request) - scopes, credentials = self.server.validate_authorization_request( uri, http_method=http_method, body=body, headers=headers) @@ -183,6 +184,8 @@ def extract_body(self, request): """ try: body = json.loads(request.body.decode("utf-8")).items() + except AttributeError: + body = "" except ValueError: body = "" diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index 2abc0c5..467bae2 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -1,12 +1,15 @@ -from __future__ import unicode_literals - import base64 import binascii import json import hashlib import logging -from datetime import datetime, timedelta from collections import OrderedDict +from datetime import datetime, timedelta + +try: + from urllib.parse import unquote_plus +except ImportError: + from urllib import unquote_plus import requests from django.conf import settings @@ -24,7 +27,6 @@ from jwcrypto import jwk, jwt from jwcrypto.jwt import JWTExpired -from .compat import unquote_plus from .exceptions import FatalClientError from .models import ( AbstractApplication, @@ -60,7 +62,6 @@ class OAuth2Validator(RequestValidator): - def _extract_basic_auth(self, request): """ Return authentication string if request contains basic auth credentials, @@ -229,8 +230,7 @@ def client_authentication_required(self, request, *args, **kwargs): if request.client: return request.client.client_type == AbstractApplication.CLIENT_CONFIDENTIAL - return super(OAuth2Validator, self).client_authentication_required(request, - *args, **kwargs) + return super().client_authentication_required(request, *args, **kwargs) def authenticate_client(self, request, *args, **kwargs): """ @@ -348,20 +348,14 @@ def _get_token_from_authentication_server( scope = content.get("scope", "") expires = make_aware(expires) - try: - access_token = AccessToken.objects.select_related("application", "user").get(token=token) - except AccessToken.DoesNotExist: - access_token = AccessToken.objects.create( - token=token, - user=user, - application=None, - scope=scope, - expires=expires - ) - else: - access_token.expires = expires - access_token.scope = scope - access_token.save() + access_token, _created = AccessToken.objects.update_or_create( + token=token, + defaults={ + "user": user, + "application": None, + "scope": scope, + "expires": expires, + }) return access_token @@ -378,27 +372,11 @@ def validate_bearer_token(self, token, scopes, request): try: access_token = AccessToken.objects.select_related("application", "user").get(token=token) - # if there is a token but invalid then look up the token - if introspection_url and (introspection_token or introspection_credentials): - if not access_token.is_valid(scopes): - access_token = self._get_token_from_authentication_server( - token, - introspection_url, - introspection_token, - introspection_credentials - ) - if access_token and access_token.is_valid(scopes): - request.client = access_token.application - request.user = access_token.user - request.scopes = scopes - - # this is needed by django rest framework - request.access_token = access_token - return True - self._set_oauth2_error_on_request(request, access_token, scopes) - return False except AccessToken.DoesNotExist: - # there is no initial token, look up the token + access_token = None + + # if there is no token or it's invalid then introspect the token if there's an external OAuth server + if not access_token or not access_token.is_valid(scopes): if introspection_url and (introspection_token or introspection_credentials): access_token = self._get_token_from_authentication_server( token, @@ -406,15 +384,17 @@ def validate_bearer_token(self, token, scopes, request): introspection_token, introspection_credentials ) - if access_token and access_token.is_valid(scopes): - request.client = access_token.application - request.user = access_token.user - request.scopes = scopes - - # this is needed by django rest framework - request.access_token = access_token - return True - self._set_oauth2_error_on_request(request, None, scopes) + + if access_token and access_token.is_valid(scopes): + request.client = access_token.application + request.user = access_token.user + request.scopes = scopes + + # this is needed by django rest framework + request.access_token = access_token + return True + else: + self._set_oauth2_error_on_request(request, access_token, scopes) return False def validate_code(self, client_id, code, client, request, *args, **kwargs): @@ -442,8 +422,7 @@ def validate_response_type(self, client_id, response_type, client, request, *arg rfc:`8.4`, so validate the response_type only if it matches "code" or "token" """ if response_type == "code": - return client.allows_grant_type( - AbstractApplication.GRANT_AUTHORIZATION_CODE, AbstractApplication.GRANT_OPENID_HYBRID) + return client.allows_grant_type(AbstractApplication.GRANT_AUTHORIZATION_CODE) elif response_type == "token": return client.allows_grant_type(AbstractApplication.GRANT_IMPLICIT) elif response_type == "id_token": @@ -473,12 +452,38 @@ def get_default_scopes(self, client_id, request, *args, **kwargs): def validate_redirect_uri(self, client_id, redirect_uri, request, *args, **kwargs): return request.client.redirect_uri_allowed(redirect_uri) + def is_pkce_required(self, client_id, request): + """ + Enables or disables PKCE verification. + + Uses the setting PKCE_REQUIRED, which can be either a bool or a callable that + receives the client id and returns a bool. + """ + if callable(oauth2_settings.PKCE_REQUIRED): + return oauth2_settings.PKCE_REQUIRED(client_id) + return oauth2_settings.PKCE_REQUIRED + + def get_code_challenge(self, code, request): + grant = Grant.objects.get(code=code, application=request.client) + return grant.code_challenge or None + + def get_code_challenge_method(self, code, request): + grant = Grant.objects.get(code=code, application=request.client) + return grant.code_challenge_method or None + def save_authorization_code(self, client_id, code, request, *args, **kwargs): expires = timezone.now() + timedelta( seconds=oauth2_settings.AUTHORIZATION_CODE_EXPIRE_SECONDS) - g = Grant(application=request.client, user=request.user, code=code["code"], - expires=expires, redirect_uri=request.redirect_uri, - scope=" ".join(request.scopes)) + g = Grant( + application=request.client, + user=request.user, + code=code["code"], + expires=expires, + redirect_uri=request.redirect_uri, + scope=" ".join(request.scopes), + code_challenge=request.code_challenge or "", + code_challenge_method=request.code_challenge_method or "" + ) g.save() def get_authorization_code_scopes(self, client_id, code, redirect_uri, request): @@ -553,6 +558,14 @@ def save_bearer_token(self, token, request, *args, **kwargs): else: # revoke existing tokens if possible to allow reuse of grant if isinstance(refresh_token_instance, RefreshToken): + # First, to ensure we don't have concurrency issues, we refresh the refresh token + # from the db while acquiring a lock on it + # We also put it in the "request cache" + refresh_token_instance = RefreshToken.objects.select_for_update().get( + id=refresh_token_instance.id + ) + request.refresh_token_instance = refresh_token_instance + previous_access_token = AccessToken.objects.filter( source_refresh_token=refresh_token_instance ).first() @@ -576,17 +589,14 @@ def save_bearer_token(self, token, request, *args, **kwargs): source_refresh_token=refresh_token_instance, ) - refresh_token = RefreshToken( - user=request.user, - token=refresh_token_code, - application=request.client, - access_token=access_token - ) - refresh_token.save() + self._create_refresh_token(request, refresh_token_code, access_token) else: # make sure that the token data we're returning matches # the existing token token["access_token"] = previous_access_token.token + token["refresh_token"] = RefreshToken.objects.filter( + access_token=previous_access_token + ).first().token token["scope"] = previous_access_token.scope # No refresh token should be created, just access token @@ -597,17 +607,30 @@ def save_bearer_token(self, token, request, *args, **kwargs): token["expires_in"] = oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS def _create_access_token(self, expires, request, token, source_refresh_token=None): + id_token = token.get('id_token', None) + if id_token: + id_token = IDToken.objects.get(token=id_token) access_token = AccessToken( user=request.user, scope=token["scope"], expires=expires, token=token["access_token"], + id_token=id_token, application=request.client, source_refresh_token=source_refresh_token, ) access_token.save() return access_token + def _create_refresh_token(self, request, refresh_token_code, access_token): + refresh_token = RefreshToken( + user=request.user, + token=refresh_token_code, + application=request.client, + access_token=access_token + ) + refresh_token.save() + def revoke_token(self, token, token_type_hint, request, *args, **kwargs): """ Revoke an access or refresh token. @@ -675,6 +698,7 @@ def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs @transaction.atomic def _save_id_token(self, token, request, expires, *args, **kwargs): + scopes = request.scope or " ".join(request.scopes) if request.grant_type == "client_credentials": @@ -693,6 +717,7 @@ def get_jwt_bearer_token(self, token, token_handler, request): return self.get_id_token(token, token_handler, request) def get_id_token(self, token, token_handler, request): + key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) # TODO: http://openid.net/specs/openid-connect-core-1_0.html#HybridIDToken2 @@ -705,9 +730,8 @@ def get_id_token(self, token, token_handler, request): expiration_time = timezone.now() + timedelta(seconds=oauth2_settings.ID_TOKEN_EXPIRE_SECONDS) # Required ID Token claims claims = { - "iss": 'https://id.olist.com', # HTTPS URL + "iss": oauth2_settings.OIDC_ISS_ENDPOINT, "sub": str(request.user.id), - "preferred_username": str(request.user.username), "aud": request.client_id, "exp": int(dateformat.format(expiration_time, "U")), "iat": int(dateformat.format(datetime.utcnow(), "U")), @@ -722,14 +746,11 @@ def get_id_token(self, token, token_handler, request): # http://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken # http://openid.net/specs/openid-connect-core-1_0.html#ImplicitIDToken # if request.grant_type in 'authorization_code' and 'access_token' in token: - if (request.grant_type == "authorization_code" and "access_token" in token) or\ - request.response_type == "code id_token token" or\ - (request.response_type == "id_token token" and "access_token" in token): - access_token = token["access_token"] - sha256 = hashlib.sha256(access_token.encode("ascii")) + if (request.grant_type is "authorization_code" and "access_token" in token) or request.response_type == "code id_token token" or (request.response_type == "id_token token" and "access_token" in token): + acess_token = token["access_token"] + sha256 = hashlib.sha256(acess_token.encode("ascii")) bits128 = sha256.hexdigest()[:16] at_hash = base64.urlsafe_b64encode(bits128.encode("ascii")) - claims['access_token'] = access_token claims['at_hash'] = at_hash.decode("utf8") # TODO: create a function to check if we should include c_hash diff --git a/oauth2_provider/scopes.py b/oauth2_provider/scopes.py index d0eae57..d30f43e 100644 --- a/oauth2_provider/scopes.py +++ b/oauth2_provider/scopes.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import, unicode_literals - from .settings import oauth2_settings diff --git a/oauth2_provider/settings.py b/oauth2_provider/settings.py index 8759b3a..978ad21 100644 --- a/oauth2_provider/settings.py +++ b/oauth2_provider/settings.py @@ -15,8 +15,6 @@ OAuth2 Provider settings, checking for user settings first, then falling back to the defaults. """ -from __future__ import unicode_literals - import importlib from django.conf import settings @@ -35,7 +33,7 @@ "CLIENT_ID_GENERATOR_CLASS": "oauth2_provider.generators.ClientIdGenerator", "CLIENT_SECRET_GENERATOR_CLASS": "oauth2_provider.generators.ClientSecretGenerator", "CLIENT_SECRET_GENERATOR_LENGTH": 128, - "OAUTH2_SERVER_CLASS": "oauthlib.oauth2.Server", + "OAUTH2_SERVER_CLASS": "oauthlib.openid.connect.core.endpoints.pre_configured.Server", "OAUTH2_VALIDATOR_CLASS": "oauth2_provider.oauth2_validators.OAuth2Validator", "OAUTH2_BACKEND_CLASS": "oauth2_provider.oauth2_backends.OAuthLibCore", "SCOPES": {"read": "Reading scope", "write": "Writing scope"}, @@ -82,6 +80,9 @@ "RESOURCE_SERVER_AUTH_TOKEN": None, "RESOURCE_SERVER_INTROSPECTION_CREDENTIALS": None, "RESOURCE_SERVER_TOKEN_CACHING_SECONDS": 36000, + + # Whether or not PKCE is required + "PKCE_REQUIRED": False } # List of settings that cannot be empty diff --git a/oauth2_provider/urls.py b/oauth2_provider/urls.py index f711f8e..333f119 100644 --- a/oauth2_provider/urls.py +++ b/oauth2_provider/urls.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - from django.conf.urls import url from . import views @@ -31,7 +29,8 @@ oidc_urlpatterns = [ url(r"^\.well-known/openid-configuration/$", views.ConnectDiscoveryInfoView.as_view(), name="oidc-connect-discovery-info"), - url(r"^jwks/$", views.JwksInfoView.as_view(), name="jwks-info") + url(r"^jwks/$", views.JwksInfoView.as_view(), name="jwks-info"), + url(r"^userinfo/$", views.UserInfoView.as_view(), name="user-info") ] diff --git a/oauth2_provider/validators.py b/oauth2_provider/validators.py index d49e101..933cc4e 100644 --- a/oauth2_provider/validators.py +++ b/oauth2_provider/validators.py @@ -1,71 +1,51 @@ -from __future__ import unicode_literals - import re +try: + from urllib.parse import urlsplit +except ImportError: + from urlparse import urlsplit + from django.core.exceptions import ValidationError -from django.core.validators import RegexValidator +from django.core.validators import URLValidator from django.utils.encoding import force_text -from django.utils.translation import ugettext_lazy as _ -from .compat import urlsplit, urlunsplit -from .settings import oauth2_settings +class URIValidator(URLValidator): + scheme_re = r"^(?:[a-z][a-z0-9\.\-\+]*)://" -class URIValidator(RegexValidator): - regex = re.compile( - r"^(?:[a-z][a-z0-9\.\-\+]*)://" # scheme... - r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|" # domain... - r"(?!-)[A-Z\d-]{1,63}(? ACE - except UnicodeError: # invalid domain part - raise e - url = urlunsplit((scheme, netloc, path, query, fragment)) - super(URIValidator, self).__call__(url) - else: - raise - else: - url = value + dotless_domain_re = r"(?!-)[A-Z\d-]{1,63}(? 1: - raise ValidationError("Redirect URIs must not contain fragments") scheme, netloc, path, query, fragment = urlsplit(value) - if scheme.lower() not in self.allowed_schemes: - raise ValidationError("Redirect URI scheme is not allowed.") + if fragment and not self.allow_fragments: + raise ValidationError("Redirect URIs must not contain fragments") + +## +# WildcardSet is a special set that contains everything. +# This is required in order to move validation of the scheme from +# URLValidator (the base class of URIValidator), to OAuth2Application.clean(). -def validate_uris(value): +class WildcardSet(set): """ - This validator ensures that `value` contains valid blank-separated URIs" + A set that always returns True on `in`. """ - v = RedirectURIValidator(oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES) - uris = value.split() - if not uris: - raise ValidationError("Redirect URI cannot be empty") - for uri in uris: - v(uri) + def __contains__(self, item): + return True diff --git a/oauth2_provider/views/__init__.py b/oauth2_provider/views/__init__.py index 2124dc7..9f2ac4f 100644 --- a/oauth2_provider/views/__init__.py +++ b/oauth2_provider/views/__init__.py @@ -9,5 +9,5 @@ ScopedProtectedResourceView ) from .introspect import IntrospectTokenView -from .oidc import ConnectDiscoveryInfoView, JwksInfoView +from .oidc import ConnectDiscoveryInfoView, JwksInfoView, UserInfoView from .token import AuthorizedTokenDeleteView, AuthorizedTokensListView diff --git a/oauth2_provider/views/application.py b/oauth2_provider/views/application.py index b685992..b38c907 100644 --- a/oauth2_provider/views/application.py +++ b/oauth2_provider/views/application.py @@ -38,7 +38,7 @@ def get_form_class(self): def form_valid(self, form): form.instance.user = self.request.user - return super(ApplicationRegistration, self).form_valid(form) + return super().form_valid(form) class ApplicationDetail(ApplicationOwnerIsUserMixin, DetailView): diff --git a/oauth2_provider/views/base.py b/oauth2_provider/views/base.py index b46fead..832fa9f 100644 --- a/oauth2_provider/views/base.py +++ b/oauth2_provider/views/base.py @@ -1,11 +1,6 @@ import json import logging -try: - import urlparse -except ImportError: - import urllib.parse as urlparse - from django.contrib.auth.mixins import LoginRequiredMixin from django.http import HttpResponse from django.utils import timezone @@ -16,13 +11,17 @@ from ..exceptions import OAuthToolkitError from ..forms import AllowForm -from ..http import HttpResponseUriRedirect +from ..http import OAuth2ResponseRedirect from ..models import get_access_token_model, get_application_model from ..scopes import get_scopes_backend from ..settings import oauth2_settings from ..signals import app_authorized from .mixins import OAuthLibMixin +try: + from urllib.parse import urlparse, parse_qsl +except ImportError: + from urlparse import urlparse, parse_qsl log = logging.getLogger("oauth2_provider") @@ -37,6 +36,7 @@ class BaseAuthorizationView(LoginRequiredMixin, OAuthLibMixin, View): * Implicit grant """ + def dispatch(self, request, *args, **kwargs): self.oauth2_data = {} return super(BaseAuthorizationView, self).dispatch(request, *args, **kwargs) @@ -46,7 +46,7 @@ def error_response(self, error, application, **kwargs): Handle errors either by redirecting to redirect_uri with a json in the body containing error details or providing an error response """ - redirect, error_response = super(BaseAuthorizationView, self).error_response(error, **kwargs) + redirect, error_response = super().error_response(error, **kwargs) if redirect: return self.redirect(error_response["url"], application) @@ -61,7 +61,7 @@ def redirect(self, redirect_to, application): allowed_schemes = oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES else: allowed_schemes = application.get_allowed_schemes() - return HttpResponseUriRedirect(redirect_to, allowed_schemes) + return OAuth2ResponseRedirect(redirect_to, allowed_schemes) class AuthorizationView(BaseAuthorizationView, FormView): @@ -103,6 +103,8 @@ def get_initial(self): "client_id": self.oauth2_data.get("client_id", None), "state": self.oauth2_data.get("state", None), "response_type": self.oauth2_data.get("response_type", None), + "code_challenge": self.oauth2_data.get("code_challenge", None), + "code_challenge_method": self.oauth2_data.get("code_challenge_method", None), } return initial_data @@ -113,18 +115,20 @@ def form_valid(self, form): "client_id": form.cleaned_data.get("client_id"), "redirect_uri": form.cleaned_data.get("redirect_uri"), "response_type": form.cleaned_data.get("response_type", None), - "state": form.cleaned_data.get("state", None), + "state": form.cleaned_data.get("state", None) } - + if form.cleaned_data.get("code_challenge", False): + credentials["code_challenge"] = form.cleaned_data.get("code_challenge") + if form.cleaned_data.get("code_challenge_method", False): + credentials["code_challenge_method"] = form.cleaned_data.get("code_challenge_method") body = { "nonce": form.cleaned_data.get("nonce") } - scopes = form.cleaned_data.get("scope") allow = form.cleaned_data.get("allow") try: - uri, headers, body, status = self.create_authorization_response( + redirect_uri, headers, body, status = self.create_authorization_response( self.request.get_raw_uri(), request=self.request, scopes=scopes, @@ -135,13 +139,23 @@ def form_valid(self, form): except OAuthToolkitError as error: return self.error_response(error, application) - self.success_url = uri + self.success_url = redirect_uri log.debug("Success url for the request: {0}".format(self.success_url)) return self.redirect(self.success_url, application) def get(self, request, *args, **kwargs): try: scopes, credentials = self.validate_authorization_request(request) + # TODO: Remove the two following lines after oauthlib updates its implementation + # https://github.com/jazzband/django-oauth-toolkit/pull/707#issuecomment-485011945 + credentials["code_challenge"] = credentials.get( + "code_challenge", + request.GET.get("code_challenge", None) + ) + credentials["code_challenge_method"] = credentials.get( + "code_challenge_method", + request.GET.get("code_challenge_method", None) + ) except OAuthToolkitError as error: # Application is not available at this time. return self.error_response(error, application=None) @@ -154,21 +168,23 @@ def get(self, request, *args, **kwargs): # TODO: Cache this! application = get_application_model().objects.get(client_id=credentials["client_id"]) - uri_query = urlparse.urlparse(self.request.get_raw_uri()).query - uri_query_params = dict(urlparse.parse_qsl(uri_query, keep_blank_values=True, strict_parsing=True)) + uri_query = urlparse(self.request.get_raw_uri()).query + uri_query_params = dict(parse_qsl(uri_query, keep_blank_values=True, strict_parsing=True)) kwargs["application"] = application kwargs["client_id"] = credentials["client_id"] kwargs["redirect_uri"] = credentials["redirect_uri"] kwargs["response_type"] = credentials["response_type"] kwargs["state"] = credentials["state"] + kwargs["code_challenge"] = credentials["code_challenge"] + kwargs["code_challenge_method"] = credentials["code_challenge_method"] kwargs["nonce"] = uri_query_params.get('nonce', None) self.oauth2_data = kwargs # following two loc are here only because of https://code.djangoproject.com/ticket/17795 form = self.get_form(self.get_form_class()) kwargs["form"] = form - allowed_schemes = oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES + # Check to see if the user has already granted access and return # a successful response depending on "approval_prompt" url parameter require_approval = request.GET.get("approval_prompt", oauth2_settings.REQUEST_APPROVAL_PROMPT) @@ -179,14 +195,14 @@ def get(self, request, *args, **kwargs): # This is useful for in-house applications-> assume an in-house applications # are already approved. if application.skip_authorization: - uri, headers, body, status = self.create_authorization_response( + redirect_uri, headers, body, status = self.create_authorization_response( self.request.get_raw_uri(), request=self.request, scopes=" ".join(scopes), credentials=credentials, - allow=True) - # return self.redirect(uri, application) - return HttpResponseUriRedirect(uri, allowed_schemes) + allow=True + ) + return self.redirect(redirect_uri, application) elif require_approval == "auto": tokens = get_access_token_model().objects.filter( @@ -198,14 +214,14 @@ def get(self, request, *args, **kwargs): # check past authorizations regarded the same scopes as the current one for token in tokens: if token.allow_scopes(scopes): - uri, headers, body, status = self.create_authorization_response( + redirect_uri, headers, body, status = self.create_authorization_response( self.request.get_raw_uri(), request=self.request, scopes=" ".join(scopes), credentials=credentials, - allow=True) - # return self.redirect(uri, application) - return HttpResponseUriRedirect(uri, allowed_schemes) + allow=True + ) + return self.redirect(redirect_uri, application) except OAuthToolkitError as error: return self.error_response(error, application) @@ -238,10 +254,7 @@ def post(self, request, *args, **kwargs): app_authorized.send( sender=self, request=request, token=token) - body = json.loads(body) - if 'id_token' in body: - body['access_token'] = body['id_token'] - response = HttpResponse(content=json.dumps(body), status=status) + response = HttpResponse(content=body, status=status) for k, v in headers.items(): response[k] = v diff --git a/oauth2_provider/views/introspect.py b/oauth2_provider/views/introspect.py index 0f3780c..5d5fcea 100644 --- a/oauth2_provider/views/introspect.py +++ b/oauth2_provider/views/introspect.py @@ -1,5 +1,3 @@ -from __future__ import unicode_literals - import calendar import json diff --git a/oauth2_provider/views/mixins.py b/oauth2_provider/views/mixins.py index bc2ef86..68015d2 100644 --- a/oauth2_provider/views/mixins.py +++ b/oauth2_provider/views/mixins.py @@ -1,5 +1,3 @@ -from __future__ import unicode_literals - import logging from django.core.exceptions import ImproperlyConfigured @@ -205,13 +203,13 @@ class ProtectedResourceMixin(OAuthLibMixin): def dispatch(self, request, *args, **kwargs): # let preflight OPTIONS requests pass if request.method.upper() == "OPTIONS": - return super(ProtectedResourceMixin, self).dispatch(request, *args, **kwargs) + return super().dispatch(request, *args, **kwargs) # check if the request is valid and the protected resource may be accessed valid, r = self.verify_request(request) if valid: request.resource_owner = r.user - return super(ProtectedResourceMixin, self).dispatch(request, *args, **kwargs) + return super().dispatch(request, *args, **kwargs) else: return HttpResponseForbidden() @@ -233,7 +231,7 @@ def __new__(cls, *args, **kwargs): ' to be in OAUTH2_PROVIDER["SCOPES"] list in settings'.format(read_write_scopes) ) - return super(ReadWriteScopedResourceMixin, cls).__new__(cls, *args, **kwargs) + return super().__new__(cls, *args, **kwargs) def dispatch(self, request, *args, **kwargs): if request.method.upper() in SAFE_HTTP_METHODS: @@ -241,10 +239,10 @@ def dispatch(self, request, *args, **kwargs): else: self.read_write_scope = oauth2_settings.WRITE_SCOPE - return super(ReadWriteScopedResourceMixin, self).dispatch(request, *args, **kwargs) + return super().dispatch(request, *args, **kwargs) def get_scopes(self, *args, **kwargs): - scopes = super(ReadWriteScopedResourceMixin, self).get_scopes(*args, **kwargs) + scopes = super().get_scopes(*args, **kwargs) # this returns a copy so that self.required_scopes is not modified return scopes + [self.read_write_scope] diff --git a/oauth2_provider/views/oidc.py b/oauth2_provider/views/oidc.py index 8c20908..7f3c9d5 100644 --- a/oauth2_provider/views/oidc.py +++ b/oauth2_provider/views/oidc.py @@ -5,6 +5,9 @@ from django.http import JsonResponse from django.urls import reverse_lazy from django.views.generic import View + +from rest_framework.views import APIView + from jwcrypto import jwk from ..settings import oauth2_settings @@ -49,3 +52,13 @@ def get(self, request, *args, **kwargs): response = JsonResponse(data) response["Access-Control-Allow-Origin"] = "*" return response + + +class UserInfoView(APIView): + """ + View used to show Claims about the authenticated End-User + """ + def get(self, request, *args, **kwargs): + response = JsonResponse(request.auth.id_token.claims) + response["Access-Control-Allow-Origin"] = "*" + return response diff --git a/oauth2_provider/views/token.py b/oauth2_provider/views/token.py index ebb4285..399953f 100644 --- a/oauth2_provider/views/token.py +++ b/oauth2_provider/views/token.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import, unicode_literals - from django.contrib.auth.mixins import LoginRequiredMixin from django.urls import reverse_lazy from django.views.generic import DeleteView, ListView @@ -19,8 +17,9 @@ def get_queryset(self): """ Show only user"s tokens """ - return super(AuthorizedTokensListView, self).get_queryset()\ - .select_related("application").filter(user=self.request.user) + return super().get_queryset().select_related("application").filter( + user=self.request.user + ) class AuthorizedTokenDeleteView(LoginRequiredMixin, DeleteView): @@ -32,4 +31,4 @@ class AuthorizedTokenDeleteView(LoginRequiredMixin, DeleteView): model = get_access_token_model() def get_queryset(self): - return super(AuthorizedTokenDeleteView, self).get_queryset().filter(user=self.request.user) + return super().get_queryset().filter(user=self.request.user) diff --git a/setup.cfg b/setup.cfg index 06df903..aa4d14d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = geonode-oauth-toolkit -version = 1.1.4.1 +version = 1.1.4.1a0 description = OAuth2 Provider for Django/GeoNode author = Federico Frenguelli, Massimiliano Pippi, Alessio Fabiani author_email = synasius@gmail.com @@ -28,8 +28,8 @@ include_package_data = True zip_safe = False install_requires = django >= 1.11 - oauthlib >= 2.0.3 requests >= 2.13.0 + oauthlib >= 3.0.1 jwcrypto >= 0.4.2 [options.packages.find] diff --git a/tests/settings.py b/tests/settings.py index 151a7d7..f1ad8dd 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -130,3 +130,8 @@ "OIDC_USERINFO_ENDPOINT": "http://localhost/userinfo/", "OIDC_RSA_PRIVATE_KEY": "-----BEGIN RSA PRIVATE KEY-----\nMIICXQIBAAKBgQCbCYh5h2NmQuBqVO6G+/CO+cHm9VBzsb0MeA6bbQfDnbhstVOT\nj0hcnZJzDjYc6ajBZZf6gxVP9xrdm9Uh599VI3X5PFXLbMHrmzTAMzCGIyg+/fnP\n0gocYxmCX2+XKyj/Zvt1pUX8VAN2AhrJSfxNDKUHERTVEV9bRBJg4F0C3wIDAQAB\nAoGAP+i4nNw+Ec/8oWh8YSFm4xE6qKG0NdTtSMAOyWwy+KTB+vHuT1QPsLn1vj77\n+IQrX/moogg6F1oV9YdA3vat3U7rwt1sBGsRrLhA+Spp9WEQtglguNo4+QfVo2ju\nYBa2rG+h75qjiA3xnU//F3rvwnAsOWv0NUVdVeguyR+u6okCQQDBUmgWeH2WHmUn\n2nLNCz+9wj28rqhfOr9Ptem2gqk+ywJmuIr4Y5S1OdavOr2UZxOcEwncJ/MLVYQq\nMH+x4V5HAkEAzU2GMR5OdVLcxfVTjzuIC76paoHVWnLibd1cdANpPmE6SM+pf5el\nfVSwuH9Fmlizu8GiPCxbJUoXB/J1tGEKqQJBALhClEU+qOzpoZ6/voYi/6kdN3zc\nuEy0EN6n09AKb8gS9QH1STgAqh+ltjMkeMe3C2DKYK5/QU9/Pc58lWl1FkcCQG67\nZamQgxjcvJ85FvymS1aqW45KwNysIlzHjFo2jMlMf7dN6kobbPMQftDENLJvLWIT\nqoFyGycdsxZiPAIyZSECQQCZFn3Dl6hnJxWZH8Fsa9hj79kZ/WVkIXGmtdgt0fNr\ndTnvCVtA59ne4LEVie/PMH/odQWY0SxVm/76uBZv/1vY\n-----END RSA PRIVATE KEY-----" } + +OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL = 'oauth2_provider.AccessToken' +OAUTH2_PROVIDER_APPLICATION_MODEL = 'oauth2_provider.Application' +OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL = 'oauth2_provider.RefreshToken' +OAUTH2_PROVIDER_ID_TOKEN_MODEL = 'oauth2_provider.IDToken' diff --git a/tests/test_application_views.py b/tests/test_application_views.py index 580f2e6..64e112d 100644 --- a/tests/test_application_views.py +++ b/tests/test_application_views.py @@ -1,5 +1,3 @@ -from __future__ import unicode_literals - from django.contrib.auth import get_user_model from django.test import TestCase from django.urls import reverse diff --git a/tests/test_authorization_code.py b/tests/test_authorization_code.py index 2abbf34..527f69e 100644 --- a/tests/test_authorization_code.py +++ b/tests/test_authorization_code.py @@ -1,16 +1,21 @@ -from __future__ import unicode_literals - import base64 import datetime +import hashlib import json +try: + from urllib.parse import parse_qs, urlencode, urlparse +except ImportError: + from urlparse import parse_qs, urlparse + from urllib import urlencode + from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse from django.utils import timezone +from django.utils.crypto import get_random_string from oauthlib.oauth2.rfc6749 import errors as oauthlib_errors -from oauth2_provider.compat import parse_qs, urlencode, urlparse from oauth2_provider.models import ( get_access_token_model, get_application_model, get_grant_model, get_refresh_token_model @@ -612,6 +617,40 @@ def get_auth(self, scope="read write"): query_dict = parse_qs(urlparse(response["Location"]).query) return query_dict["code"].pop() + def generate_pkce_codes(self, algorithm, length=43): + """ + Helper method to generate pkce codes + """ + code_verifier = get_random_string(length) + if algorithm == "S256": + code_challenge = base64.urlsafe_b64encode( + hashlib.sha256(code_verifier.encode()).digest() + ).decode().rstrip("=") + else: + code_challenge = code_verifier + return code_verifier, code_challenge + + def get_pkce_auth(self, code_challenge, code_challenge_method): + """ + Helper method to retrieve a valid authorization code using pkce + """ + oauth2_settings.PKCE_REQUIRED = True + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + "code_challenge": code_challenge, + "code_challenge_method": code_challenge_method, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + query_dict = parse_qs(urlparse(response["Location"]).query) + oauth2_settings.PKCE_REQUIRED = False + return query_dict["code"].pop() + def test_basic_auth(self): """ Request an access token using basic authentication for client authentication @@ -674,7 +713,7 @@ def test_refresh(self): # check refresh token cannot be used twice response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 401) + self.assertEqual(response.status_code, 400) content = json.loads(response.content.decode("utf-8")) self.assertTrue("invalid_grant" in content.values()) @@ -711,19 +750,24 @@ def test_refresh_with_grace_period(self): "refresh_token": content["refresh_token"], "scope": content["scope"], } + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertTrue("access_token" in content) first_access_token = content["access_token"] + first_refresh_token = content["refresh_token"] - # check refresh token returns same data if used twice, see #497 + # check access token returns same data if used twice, see #497 response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertTrue("access_token" in content) self.assertEqual(content["access_token"], first_access_token) + # refresh token should be the same as well + self.assertTrue("refresh_token" in content) + self.assertEqual(content["refresh_token"], first_refresh_token) oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 0 def test_refresh_invalidates_old_tokens(self): @@ -810,7 +854,7 @@ def test_refresh_bad_scopes(self): "scope": "read write nuke", } response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 401) + self.assertEqual(response.status_code, 400) def test_refresh_fail_repeating_requests(self): """ @@ -838,7 +882,7 @@ def test_refresh_fail_repeating_requests(self): response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 401) + self.assertEqual(response.status_code, 400) def test_refresh_repeating_requests(self): """ @@ -877,7 +921,7 @@ def test_refresh_repeating_requests(self): rt.save() response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 401) + self.assertEqual(response.status_code, 400) oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 0 def test_refresh_repeating_requests_non_rotating_tokens(self): @@ -926,7 +970,7 @@ def test_basic_auth_bad_authcode(self): auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 401) + self.assertEqual(response.status_code, 400) def test_basic_auth_bad_granttype(self): """ @@ -962,7 +1006,7 @@ def test_basic_auth_grant_expired(self): auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 401) + self.assertEqual(response.status_code, 400) def test_basic_auth_bad_secret(self): """ @@ -1079,6 +1123,307 @@ def test_id_token_public(self): self.assertIn("id_token", content) self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + def test_public_pkce_S256_authorize_get(self): + """ + Request an access token using client_type: public + and PKCE enabled. Tests if the authorize get is successfull + for the S256 algorithm + """ + self.client.login(username="test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + code_verifier, code_challenge = self.generate_pkce_codes("S256") + oauth2_settings.PKCE_REQUIRED = True + + query_string = urlencode({ + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + "code_challenge": code_challenge, + "code_challenge_method": "S256" + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + oauth2_settings.PKCE_REQUIRED = False + + def test_public_pkce_plain_authorize_get(self): + """ + Request an access token using client_type: public + and PKCE enabled. Tests if the authorize get is successfull + for the plain algorithm + """ + self.client.login(username="test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + code_verifier, code_challenge = self.generate_pkce_codes("plain") + oauth2_settings.PKCE_REQUIRED = True + + query_string = urlencode({ + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + "code_challenge": code_challenge, + "code_challenge_method": "plain" + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + oauth2_settings.PKCE_REQUIRED = False + + def test_public_pkce_S256(self): + """ + Request an access token using client_type: public + and PKCE enabled with the S256 algorithm + """ + self.client.login(username="test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + code_verifier, code_challenge = self.generate_pkce_codes("S256") + authorization_code = self.get_pkce_auth(code_challenge, "S256") + oauth2_settings.PKCE_REQUIRED = True + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id, + "code_verifier": code_verifier + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "read write") + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + oauth2_settings.PKCE_REQUIRED = False + + def test_public_pkce_plain(self): + """ + Request an access token using client_type: public + and PKCE enabled with the plain algorithm + """ + self.client.login(username="test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + code_verifier, code_challenge = self.generate_pkce_codes("plain") + authorization_code = self.get_pkce_auth(code_challenge, "plain") + oauth2_settings.PKCE_REQUIRED = True + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id, + "code_verifier": code_verifier + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "read write") + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + oauth2_settings.PKCE_REQUIRED = False + + def test_public_pkce_invalid_algorithm(self): + """ + Request an access token using client_type: public + and PKCE enabled with an invalid algorithm + """ + self.client.login(username="test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + code_verifier, code_challenge = self.generate_pkce_codes("invalid") + oauth2_settings.PKCE_REQUIRED = True + + query_string = urlencode({ + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + "code_challenge": code_challenge, + "code_challenge_method": "invalid", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + self.assertIn("error=invalid_request", response["Location"]) + oauth2_settings.PKCE_REQUIRED = False + + def test_public_pkce_missing_code_challenge(self): + """ + Request an access token using client_type: public + and PKCE enabled but with the code_challenge missing + """ + self.client.login(username="test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.skip_authorization = True + self.application.save() + code_verifier, code_challenge = self.generate_pkce_codes("S256") + oauth2_settings.PKCE_REQUIRED = True + + query_string = urlencode({ + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + "code_challenge_method": "S256" + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + self.assertIn("error=invalid_request", response["Location"]) + oauth2_settings.PKCE_REQUIRED = False + + def test_public_pkce_missing_code_challenge_method(self): + """ + Request an access token using client_type: public + and PKCE enabled but with the code_challenge_method missing + """ + self.client.login(username="test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + code_verifier, code_challenge = self.generate_pkce_codes("S256") + oauth2_settings.PKCE_REQUIRED = True + + query_string = urlencode({ + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + "code_challenge": code_challenge + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + oauth2_settings.PKCE_REQUIRED = False + + def test_public_pkce_S256_invalid_code_verifier(self): + """ + Request an access token using client_type: public + and PKCE enabled with the S256 algorithm and an invalid code_verifier + """ + self.client.login(username="test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + code_verifier, code_challenge = self.generate_pkce_codes("S256") + authorization_code = self.get_pkce_auth(code_challenge, "S256") + oauth2_settings.PKCE_REQUIRED = True + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id, + "code_verifier": "invalid" + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 400) + oauth2_settings.PKCE_REQUIRED = False + + def test_public_pkce_plain_invalid_code_verifier(self): + """ + Request an access token using client_type: public + and PKCE enabled with the plain algorithm and an invalid code_verifier + """ + self.client.login(username="test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + code_verifier, code_challenge = self.generate_pkce_codes("plain") + authorization_code = self.get_pkce_auth(code_challenge, "plain") + oauth2_settings.PKCE_REQUIRED = True + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id, + "code_verifier": "invalid" + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 400) + oauth2_settings.PKCE_REQUIRED = False + + def test_public_pkce_S256_missing_code_verifier(self): + """ + Request an access token using client_type: public + and PKCE enabled with the S256 algorithm and the code_verifier missing + """ + self.client.login(username="test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + code_verifier, code_challenge = self.generate_pkce_codes("S256") + authorization_code = self.get_pkce_auth(code_challenge, "S256") + oauth2_settings.PKCE_REQUIRED = True + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 400) + oauth2_settings.PKCE_REQUIRED = False + + def test_public_pkce_plain_missing_code_verifier(self): + """ + Request an access token using client_type: public + and PKCE enabled with the plain algorithm and the code_verifier missing + """ + self.client.login(username="test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + code_verifier, code_challenge = self.generate_pkce_codes("plain") + authorization_code = self.get_pkce_auth(code_challenge, "plain") + oauth2_settings.PKCE_REQUIRED = True + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 400) + oauth2_settings.PKCE_REQUIRED = False + def test_malicious_redirect_uri(self): """ Request an access token using client_type: public and ensure redirect_uri is diff --git a/tests/test_client_credential.py b/tests/test_client_credential.py index 7ec49ed..299e826 100644 --- a/tests/test_client_credential.py +++ b/tests/test_client_credential.py @@ -1,14 +1,16 @@ -from __future__ import unicode_literals - import json +try: + from urllib.parse import quote_plus +except ImportError: + from urlparse import quote_plus + from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse from django.views.generic import View from oauthlib.oauth2 import BackendApplicationServer -from oauth2_provider.compat import quote_plus from oauth2_provider.models import get_access_token_model, get_application_model from oauth2_provider.oauth2_backends import OAuthLibCore from oauth2_provider.oauth2_validators import OAuth2Validator diff --git a/tests/test_commands.py b/tests/test_commands.py new file mode 100644 index 0000000..8f1ddc2 --- /dev/null +++ b/tests/test_commands.py @@ -0,0 +1,127 @@ +from io import StringIO + +from django.contrib.auth import get_user_model +from django.core.management import call_command +from django.core.management.base import CommandError +from django.test import TestCase + +from oauth2_provider.models import get_application_model + +Application = get_application_model() + + +class CreateApplicationTest(TestCase): + + def test_command_creates_application(self): + output = StringIO() + self.assertEqual(Application.objects.count(), 0) + call_command( + 'createapplication', + 'confidential', + 'authorization-code', + '--redirect-uris=http://example.com http://example2.com', + stdout=output, + ) + self.assertEqual(Application.objects.count(), 1) + self.assertIn('New application created successfully', output.getvalue()) + + def test_missing_required_args(self): + self.assertEqual(Application.objects.count(), 0) + with self.assertRaises(CommandError) as ctx: + call_command( + 'createapplication', + '--redirect-uris=http://example.com http://example2.com', + ) + + self.assertIn('client_type', ctx.exception.args[0]) + self.assertIn('authorization_grant_type', ctx.exception.args[0]) + self.assertEqual(Application.objects.count(), 0) + + def test_command_creates_application_with_skipped_auth(self): + self.assertEqual(Application.objects.count(), 0) + call_command( + 'createapplication', + 'confidential', + 'authorization-code', + '--redirect-uris=http://example.com http://example2.com', + '--skip-authorization', + ) + app = Application.objects.get() + + self.assertTrue(app.skip_authorization) + + def test_application_created_normally_with_no_skipped_auth(self): + call_command( + 'createapplication', + 'confidential', + 'authorization-code', + '--redirect-uris=http://example.com http://example2.com', + ) + app = Application.objects.get() + + self.assertFalse(app.skip_authorization) + + def test_application_created_with_name(self): + call_command( + 'createapplication', + 'confidential', + 'authorization-code', + '--redirect-uris=http://example.com http://example2.com', + '--name=TEST', + ) + app = Application.objects.get() + + self.assertEqual(app.name, 'TEST') + + def test_application_created_with_client_secret(self): + call_command( + 'createapplication', + 'confidential', + 'authorization-code', + '--redirect-uris=http://example.com http://example2.com', + '--client-secret=SECRET', + ) + app = Application.objects.get() + + self.assertEqual(app.client_secret, 'SECRET') + + def test_application_created_with_client_id(self): + call_command( + 'createapplication', + 'confidential', + 'authorization-code', + '--redirect-uris=http://example.com http://example2.com', + '--client-id=someId', + ) + app = Application.objects.get() + + self.assertEqual(app.client_id, 'someId') + + def test_application_created_with_user(self): + User = get_user_model() + user = User.objects.create() + call_command( + 'createapplication', + 'confidential', + 'authorization-code', + '--redirect-uris=http://example.com http://example2.com', + '--user=%s' % user.pk, + ) + app = Application.objects.get() + + self.assertEqual(app.user, user) + + def test_validation_failed_message(self): + output = StringIO() + call_command( + 'createapplication', + 'confidential', + 'authorization-code', + '--redirect-uris=http://example.com http://example2.com', + '--user=783', + stdout=output, + ) + + self.assertIn('user', output.getvalue()) + self.assertIn('783', output.getvalue()) + self.assertIn('does not exist', output.getvalue()) diff --git a/tests/test_generator.py b/tests/test_generator.py index 3e810c6..211713b 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -1,5 +1,3 @@ -from __future__ import unicode_literals - from django.test import TestCase from oauth2_provider.generators import ( diff --git a/tests/test_hybrid.py b/tests/test_hybrid.py index 43f7483..cc7417d 100644 --- a/tests/test_hybrid.py +++ b/tests/test_hybrid.py @@ -1,16 +1,19 @@ -from __future__ import unicode_literals - import base64 import datetime import json +try: + from urllib.parse import parse_qs, urlencode, urlparse +except ImportError: + from urlparse import parse_qs, urlparse + from urllib import urlencode + from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse from django.utils import timezone from oauthlib.oauth2.rfc6749 import errors as oauthlib_errors -from oauth2_provider.compat import parse_qs, urlencode, urlparse from oauth2_provider.models import ( get_access_token_model, get_application_model, get_grant_model, get_refresh_token_model @@ -96,6 +99,7 @@ def test_request_is_not_overwritten_code_id_token(self): "state": "random_state_string", "scope": "openid read write", "redirect_uri": "http://example.org", + "nonce": "nonce", }) url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) @@ -111,6 +115,7 @@ def test_request_is_not_overwritten_code_id_token_token(self): "state": "random_state_string", "scope": "openid read write", "redirect_uri": "http://example.org", + "nonce": "nonce", }) url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) @@ -218,6 +223,7 @@ def test_id_token_pre_auth_valid_client(self): "state": "random_state_string", "scope": "openid", "redirect_uri": "http://example.org", + "nonce": "nonce", }) url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) @@ -414,6 +420,7 @@ def test_code_post_auth_allow_code_id_token(self): "redirect_uri": "http://example.org", "response_type": "code id_token", "allow": True, + "nonce": "nonce", } response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) @@ -436,6 +443,7 @@ def test_code_post_auth_allow_code_id_token_token(self): "redirect_uri": "http://example.org", "response_type": "code id_token token", "allow": True, + "nonce": "nonce", } response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) @@ -459,6 +467,7 @@ def test_id_token_code_post_auth_allow(self): "redirect_uri": "http://example.org", "response_type": "code id_token", "allow": True, + "nonce": "nonce", } response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) @@ -579,6 +588,7 @@ def test_code_post_auth_allow_custom_redirect_uri_scheme_code_id_token(self): "redirect_uri": "custom-scheme://example.com", "response_type": "code id_token", "allow": True, + "nonce": "nonce", } response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) @@ -602,6 +612,7 @@ def test_code_post_auth_allow_custom_redirect_uri_scheme_code_id_token_token(sel "redirect_uri": "custom-scheme://example.com", "response_type": "code id_token token", "allow": True, + "nonce": "nonce", } response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) @@ -671,6 +682,7 @@ def test_code_post_auth_redirection_uri_with_querystring_code_id_token(self): "redirect_uri": "http://example.com?foo=bar", "response_type": "code id_token", "allow": True, + "nonce": "nonce", } response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) @@ -694,6 +706,7 @@ def test_code_post_auth_redirection_uri_with_querystring_code_id_token_token(sel "redirect_uri": "http://example.com?foo=bar", "response_type": "code id_token token", "allow": True, + "nonce": "nonce", } response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) @@ -722,7 +735,7 @@ def test_code_post_auth_failing_redirection_uri_with_querystring(self): response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) - self.assertEqual("http://example.com?foo=bar&error=access_denied", response["Location"]) + self.assertEqual("http://example.com?foo=bar&error=access_denied&state=random_state_string", response["Location"]) def test_code_post_auth_fails_when_redirect_uri_path_is_invalid(self): """ @@ -755,6 +768,7 @@ def get_auth(self, scope="read write"): "redirect_uri": "http://example.org", "response_type": "code id_token", "allow": True, + "nonce": "nonce", } response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) @@ -797,7 +811,7 @@ def test_basic_auth_bad_authcode(self): auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 401) + self.assertEqual(response.status_code, 400) def test_basic_auth_bad_granttype(self): """ @@ -833,7 +847,7 @@ def test_basic_auth_grant_expired(self): auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 401) + self.assertEqual(response.status_code, 400) def test_basic_auth_bad_secret(self): """ diff --git a/tests/test_implicit.py b/tests/test_implicit.py index 2bd2c76..353c516 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -1,4 +1,9 @@ -from __future__ import unicode_literals + +try: + from urllib.parse import parse_qs, urlencode, urlparse +except ImportError: + from urlparse import parse_qs, urlparse + from urllib import urlencode import json @@ -8,7 +13,6 @@ from jwcrypto import jwk, jwt -from oauth2_provider.compat import parse_qs, urlencode, urlparse from oauth2_provider.models import get_application_model from oauth2_provider.settings import oauth2_settings from oauth2_provider.views import ProtectedResourceView diff --git a/tests/test_introspection_auth.py b/tests/test_introspection_auth.py index 1c02c32..fd7504b 100644 --- a/tests/test_introspection_auth.py +++ b/tests/test_introspection_auth.py @@ -1,5 +1,3 @@ -from __future__ import unicode_literals - import calendar import datetime diff --git a/tests/test_introspection_view.py b/tests/test_introspection_view.py index 4c2695a..c4c6d15 100644 --- a/tests/test_introspection_view.py +++ b/tests/test_introspection_view.py @@ -1,5 +1,3 @@ -from __future__ import unicode_literals - import calendar import datetime diff --git a/tests/test_mixins.py b/tests/test_mixins.py index a4a1165..79988c9 100644 --- a/tests/test_mixins.py +++ b/tests/test_mixins.py @@ -1,5 +1,3 @@ -from __future__ import unicode_literals - from django.core.exceptions import ImproperlyConfigured from django.test import RequestFactory, TestCase from django.views.generic import View diff --git a/tests/test_models.py b/tests/test_models.py index 13afb09..ec8e2f9 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,26 +1,30 @@ -from __future__ import unicode_literals +from datetime import datetime as dt +import pytest from django.contrib.auth import get_user_model -from django.core.exceptions import ValidationError +from django.core.exceptions import ImproperlyConfigured, ValidationError from django.test import TestCase from django.test.utils import override_settings from django.utils import timezone from oauth2_provider.models import ( - get_access_token_model, get_application_model, - get_grant_model, get_refresh_token_model + clear_expired, get_access_token_model, get_application_model, + get_grant_model, get_refresh_token_model, get_id_token_model ) from oauth2_provider.settings import oauth2_settings +from .models import SampleRefreshToken Application = get_application_model() Grant = get_grant_model() AccessToken = get_access_token_model() RefreshToken = get_refresh_token_model() UserModel = get_user_model() +IDToken = get_id_token_model() class TestModels(TestCase): + def setUp(self): self.user = UserModel.objects.create_user("test_user", "test@example.com", "123456") @@ -120,6 +124,7 @@ def test_scopes_property(self): OAUTH2_PROVIDER_GRANT_MODEL="tests.SampleGrant" ) class TestCustomModels(TestCase): + def setUp(self): self.user = UserModel.objects.create_user("test_user", "test@example.com", "123456") @@ -262,6 +267,7 @@ def test_expires_can_be_none(self): class TestAccessTokenModel(TestCase): + def setUp(self): self.user = UserModel.objects.create_user("test_user", "test@example.com", "123456") @@ -291,3 +297,102 @@ class TestRefreshTokenModel(TestCase): def test_str(self): refresh_token = RefreshToken(token="test_token") self.assertEqual("%s" % refresh_token, refresh_token.token) + + +class TestClearExpired(TestCase): + + def setUp(self): + self.user = UserModel.objects.create_user("test_user", "test@example.com", "123456") + app1 = Application.objects.create( + name="Test Application", + redirect_uris=( + "http://localhost http://example.com http://example.org custom-scheme://example.com" + ), + user=self.user, + client_type=Application.CLIENT_CONFIDENTIAL, + authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, + ) + app2 = Application.objects.create( + name="Test Application", + redirect_uris=( + "http://localhost http://example.com http://example.org custom-scheme://example.com" + ), + user=self.user, + client_type=Application.CLIENT_CONFIDENTIAL, + authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, + ) + id1 = IDToken.objects.create( + token="666", + expires=dt.now(), + scope=2, + application=app1, + user=self.user, + created=dt.now(), + updated=dt.now(), + ) + id2 = IDToken.objects.create( + token="999", + expires=dt.now(), + scope=2, + application=app2, + user=self.user, + created=dt.now(), + updated=dt.now(), + ) + refresh_token1 = SampleRefreshToken.objects.create( + token="test_token", + application=app1, + user=self.user, + ) + refresh_token2 = SampleRefreshToken.objects.create( + token="test_token2", + application=app2, + user=self.user, + ) + # Insert two tokens on database. + AccessToken.objects.create( + id=1, + token="555", + expires=dt.now(), + scope=2, + application=app1, + id_token=id1, + user=self.user, + created=dt.now(), + updated=dt.now(), + refresh_token=refresh_token1, + ) + AccessToken.objects.create( + id=2, + token="666", + expires=dt.now(), + scope=2, + application=app2, + user=self.user, + id_token=id2, + created=dt.now(), + updated=dt.now(), + refresh_token=refresh_token2, + ) + + def test_clear_expired_tokens(self): + oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = 60 + assert clear_expired() is None + + def test_clear_expired_tokens_incorect_timetype(self): + oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = "A" + with pytest.raises(ImproperlyConfigured) as excinfo: + clear_expired() + result = excinfo.value.__class__.__name__ + assert result == "ImproperlyConfigured" + + def test_clear_expired_tokens_with_tokens(self): + self.client.login(username="test_user", password="123456") + oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = 0 + ttokens = AccessToken.objects.count() + expiredt = AccessToken.objects.filter(expires__lte=dt.now()).count() + assert ttokens == 2 + assert expiredt == 2 + clear_expired() + expiredt = AccessToken.objects.filter(expires__lte=dt.now()).count() + assert expiredt == 0 diff --git a/tests/test_oauth2_backends.py b/tests/test_oauth2_backends.py index d844da5..2381e9c 100644 --- a/tests/test_oauth2_backends.py +++ b/tests/test_oauth2_backends.py @@ -65,7 +65,7 @@ def test_create_token_response_gets_extra_credentials(self): payload = "grant_type=password&username=john&password=123456" request = self.factory.post("/o/token/", payload, content_type="application/x-www-form-urlencoded") - with mock.patch("oauthlib.oauth2.Server.create_token_response") as create_token_response: + with mock.patch("oauthlib.openid.connect.core.endpoints.pre_configured.Server.create_token_response") as create_token_response: mocked = mock.MagicMock() create_token_response.return_value = mocked, mocked, mocked core = self.MyOAuthLibCore() diff --git a/tests/test_oauth2_validators.py b/tests/test_oauth2_validators.py index 578e733..dd07d37 100644 --- a/tests/test_oauth2_validators.py +++ b/tests/test_oauth2_validators.py @@ -1,5 +1,5 @@ -import datetime import contextlib +import datetime from django.contrib.auth import get_user_model from django.test import TransactionTestCase @@ -270,6 +270,23 @@ def test_save_bearer_token__with_no_refresh_token__creates_new_access_token_only self.assertEqual(0, RefreshToken.objects.count()) self.assertEqual(1, AccessToken.objects.count()) + def test_save_bearer_token__with_new_token__calls_methods_to_create_access_and_refresh_tokens(self): + token = { + "scope": "foo bar", + "refresh_token": "abc", + "access_token": "123", + } + # Mock private methods to create access and refresh tokens + create_access_token_mock = mock.MagicMock() + create_refresh_token_mock = mock.MagicMock() + self.validator._create_refresh_token = create_refresh_token_mock + self.validator._create_access_token = create_access_token_mock + + self.validator.save_bearer_token(token, self.request) + + create_access_token_mock.assert_called_once() + create_refresh_token_mock.asert_called_once() + class TestOAuth2ValidatorProvidesErrorData(TransactionTestCase): """These test cases check that the recommended error codes are returned diff --git a/tests/test_password.py b/tests/test_password.py index 9a295c9..a6f3f5d 100644 --- a/tests/test_password.py +++ b/tests/test_password.py @@ -1,5 +1,3 @@ -from __future__ import unicode_literals - import json from django.contrib.auth import get_user_model @@ -78,7 +76,7 @@ def test_bad_credentials(self): auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 401) + self.assertEqual(response.status_code, 400) class TestPasswordProtectedResource(BaseTest): diff --git a/tests/test_rest_framework.py b/tests/test_rest_framework.py index d5a18bf..0251d98 100644 --- a/tests/test_rest_framework.py +++ b/tests/test_rest_framework.py @@ -2,17 +2,20 @@ from django.conf.urls import include, url from django.contrib.auth import get_user_model +from django.core.exceptions import ImproperlyConfigured from django.http import HttpResponse from django.test import TestCase from django.test.utils import override_settings from django.utils import timezone from rest_framework import permissions +from rest_framework.authentication import BaseAuthentication from rest_framework.test import APIRequestFactory, force_authenticate from rest_framework.views import APIView from oauth2_provider.contrib.rest_framework import ( IsAuthenticatedOrTokenHasScope, OAuth2Authentication, - TokenHasReadWriteScope, TokenHasResourceScope, TokenHasScope + TokenHasReadWriteScope, TokenHasResourceScope, + TokenHasScope, TokenMatchesOASRequirements ) from oauth2_provider.models import get_access_token_model, get_application_model from oauth2_provider.settings import oauth2_settings @@ -38,6 +41,9 @@ def get(self, request): def post(self, request): return HttpResponse({"a": 1, "b": 2, "c": 3}) + def put(self, request): + return HttpResponse({"a": 1, "b": 2, "c": 3}) + class OAuth2View(MockView): authentication_classes = [OAuth2Authentication] @@ -45,7 +51,7 @@ class OAuth2View(MockView): class ScopedView(OAuth2View): permission_classes = [permissions.IsAuthenticated, TokenHasScope] - required_scopes = ["scope1"] + required_scopes = ["scope1", "another"] class AuthenticatedOrScopedView(OAuth2View): @@ -62,13 +68,56 @@ class ResourceScopedView(OAuth2View): required_scopes = ["resource1"] +class MethodScopeAltView(OAuth2View): + permission_classes = [TokenMatchesOASRequirements] + required_alternate_scopes = { + "GET": [["read"]], + "POST": [["create"]], + "PUT": [["update", "put"], ["update", "edit"]], + "DELETE": [["delete"], ["deleter", "write"]], + } + + +class MethodScopeAltViewBad(OAuth2View): + permission_classes = [TokenMatchesOASRequirements] + + +class MissingAuthentication(BaseAuthentication): + def authenticate(self, request): + return ("junk", "junk",) + + +class BrokenOAuth2View(MockView): + authentication_classes = [MissingAuthentication] + + +class TokenHasScopeViewWrongAuth(BrokenOAuth2View): + permission_classes = [TokenHasScope] + + +class MethodScopeAltViewWrongAuth(BrokenOAuth2View): + permission_classes = [TokenMatchesOASRequirements] + +class AuthenticationNone(OAuth2Authentication): + def authenticate(self, request): + return None + +class AuthenticationNoneOAuth2View(MockView): + authentication_classes = [AuthenticationNone] + + urlpatterns = [ url(r"^oauth2/", include("oauth2_provider.urls")), url(r"^oauth2-test/$", OAuth2View.as_view()), url(r"^oauth2-scoped-test/$", ScopedView.as_view()), + url(r"^oauth2-scoped-missing-auth/$", TokenHasScopeViewWrongAuth.as_view()), url(r"^oauth2-read-write-test/$", ReadWriteScopedView.as_view()), url(r"^oauth2-resource-scoped-test/$", ResourceScopedView.as_view()), url(r"^oauth2-authenticated-or-scoped-test/$", AuthenticatedOrScopedView.as_view()), + url(r"^oauth2-method-scope-test/.*$", MethodScopeAltView.as_view()), + url(r"^oauth2-method-scope-fail/$", MethodScopeAltViewBad.as_view()), + url(r"^oauth2-method-scope-missing-auth/$", MethodScopeAltViewWrongAuth.as_view()), + url(r"^oauth2-authentication-none/$", AuthenticationNoneOAuth2View.as_view()), ] @@ -142,13 +191,19 @@ def test_authentication_or_scope_denied(self): self.assertEqual(response.status_code, 403) def test_scoped_permission_allow(self): - self.access_token.scope = "scope1" + self.access_token.scope = "scope1 another" self.access_token.save() auth = self._create_authorization_header(self.access_token.token) response = self.client.get("/oauth2-scoped-test/", HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) + def test_scope_missing_scope_attr(self): + auth = self._create_authorization_header("fake-token") + with self.assertRaises(AssertionError) as e: + self.client.get("/oauth2-scoped-missing-auth/", HTTP_AUTHORIZATION=auth) + self.assertTrue("`oauth2_provider.rest_framework.OAuth2Authentication`" in str(e.exception)) + def test_authenticated_or_scoped_permission_allow(self): self.access_token.scope = "scope1" self.access_token.save() @@ -255,7 +310,7 @@ def test_required_scope_in_response(self): auth = self._create_authorization_header(self.access_token.token) response = self.client.get("/oauth2-scoped-test/", HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 403) - self.assertEqual(response.data["required_scopes"], ["scope1"]) + self.assertEqual(response.data["required_scopes"], ["scope1", "another"]) def test_required_scope_not_in_response_by_default(self): self.access_token.scope = "scope2" @@ -265,3 +320,95 @@ def test_required_scope_not_in_response_by_default(self): response = self.client.get("/oauth2-scoped-test/", HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 403) self.assertNotIn("required_scopes", response.data) + + def test_method_scope_alt_permission_get_allow(self): + self.access_token.scope = "read" + self.access_token.save() + + auth = self._create_authorization_header(self.access_token.token) + response = self.client.get("/oauth2-method-scope-test/", HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + def test_method_scope_alt_permission_post_allow(self): + self.access_token.scope = "create" + self.access_token.save() + + auth = self._create_authorization_header(self.access_token.token) + response = self.client.post("/oauth2-method-scope-test/", HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + def test_method_scope_alt_permission_put_allow(self): + self.access_token.scope = "edit update" + self.access_token.save() + + auth = self._create_authorization_header(self.access_token.token) + response = self.client.put("/oauth2-method-scope-test/123", HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + def test_method_scope_alt_permission_put_fail(self): + self.access_token.scope = "edit" + self.access_token.save() + + auth = self._create_authorization_header(self.access_token.token) + response = self.client.put("/oauth2-method-scope-test/123", HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 403) + + def test_method_scope_alt_permission_get_deny(self): + self.access_token.scope = "write" + self.access_token.save() + + auth = self._create_authorization_header(self.access_token.token) + response = self.client.get("/oauth2-method-scope-test/", HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 403) + + def test_method_scope_alt_permission_post_deny(self): + self.access_token.scope = "read" + self.access_token.save() + + auth = self._create_authorization_header(self.access_token.token) + response = self.client.post("/oauth2-method-scope-test/", HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 403) + + def test_method_scope_alt_no_token(self): + self.access_token.scope = "" + self.access_token.save() + + auth = self._create_authorization_header(self.access_token.token) + self.access_token = None + response = self.client.post("/oauth2-method-scope-test/", HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 403) + + def test_method_scope_alt_missing_attr(self): + self.access_token.scope = "read" + self.access_token.save() + + auth = self._create_authorization_header(self.access_token.token) + with self.assertRaises(ImproperlyConfigured): + self.client.post("/oauth2-method-scope-fail/", HTTP_AUTHORIZATION=auth) + + def test_method_scope_alt_missing_patch_method(self): + self.access_token.scope = "update" + self.access_token.save() + + auth = self._create_authorization_header(self.access_token.token) + response = self.client.patch("/oauth2-method-scope-test/", HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 403) + + def test_method_scope_alt_empty_scope(self): + self.access_token.scope = "" + self.access_token.save() + + auth = self._create_authorization_header(self.access_token.token) + response = self.client.patch("/oauth2-method-scope-test/", HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 403) + + def test_method_scope_alt_missing_scope_attr(self): + auth = self._create_authorization_header("fake-token") + with self.assertRaises(AssertionError) as e: + self.client.get("/oauth2-method-scope-missing-auth/", HTTP_AUTHORIZATION=auth) + self.assertTrue("`oauth2_provider.rest_framework.OAuth2Authentication`" in str(e.exception)) + + def test_authentication_none(self): + auth = self._create_authorization_header(self.access_token.token) + response = self.client.get("/oauth2-authentication-none/", HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) diff --git a/tests/test_scopes.py b/tests/test_scopes.py index daccfed..2529254 100644 --- a/tests/test_scopes.py +++ b/tests/test_scopes.py @@ -1,13 +1,15 @@ -from __future__ import unicode_literals - import json +try: + from urllib.parse import parse_qs, urlparse +except ImportError: + from urlparse import parse_qs, urlparse + from django.contrib.auth import get_user_model from django.core.exceptions import ImproperlyConfigured from django.test import RequestFactory, TestCase from django.urls import reverse -from oauth2_provider.compat import parse_qs, urlparse from oauth2_provider.models import ( get_access_token_model, get_application_model, get_grant_model ) diff --git a/tests/test_scopes_backend.py b/tests/test_scopes_backend.py index 06d45b0..5f62961 100644 --- a/tests/test_scopes_backend.py +++ b/tests/test_scopes_backend.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import, unicode_literals - from oauth2_provider.scopes import SettingsScopes diff --git a/tests/test_token_revocation.py b/tests/test_token_revocation.py index c752064..04144e8 100644 --- a/tests/test_token_revocation.py +++ b/tests/test_token_revocation.py @@ -1,13 +1,15 @@ -from __future__ import unicode_literals - import datetime +try: + from urllib.parse import urlencode +except ImportError: + from urlparse import urlencode + from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse from django.utils import timezone -from oauth2_provider.compat import urlencode from oauth2_provider.models import ( get_access_token_model, get_application_model, get_refresh_token_model ) @@ -153,6 +155,31 @@ def test_revoke_refresh_token(self): self.assertIsNotNone(refresh_token.revoked) self.assertFalse(AccessToken.objects.filter(id=rtok.access_token.id).exists()) + def test_revoke_refresh_token_with_revoked_access_token(self): + tok = AccessToken.objects.create( + user=self.test_user, token="1234567890", + application=self.application, + expires=timezone.now() + datetime.timedelta(days=1), + scope="read write" + ) + rtok = RefreshToken.objects.create( + user=self.test_user, token="999999999", + application=self.application, access_token=tok + ) + for token in (tok.token, rtok.token): + query_string = urlencode({ + "client_id": self.application.client_id, + "client_secret": self.application.client_secret, + "token": token, + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:revoke-token"), qs=query_string) + response = self.client.post(url) + self.assertEqual(response.status_code, 200) + + self.assertFalse(AccessToken.objects.filter(id=tok.id).exists()) + refresh_token = RefreshToken.objects.filter(id=rtok.id).first() + self.assertIsNotNone(refresh_token.revoked) + def test_revoke_token_with_wrong_hint(self): """ From the revocation rfc, `Section 4.1.2`_ : diff --git a/tests/test_token_view.py b/tests/test_token_view.py index 5c0a92d..67fa1a5 100644 --- a/tests/test_token_view.py +++ b/tests/test_token_view.py @@ -1,5 +1,3 @@ -from __future__ import unicode_literals - import datetime from django.contrib.auth import get_user_model diff --git a/tests/test_validators.py b/tests/test_validators.py index d9d3297..82930a9 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -1,43 +1,62 @@ -from __future__ import unicode_literals - from django.core.validators import ValidationError from django.test import TestCase from oauth2_provider.settings import oauth2_settings -from oauth2_provider.validators import validate_uris +from oauth2_provider.validators import RedirectURIValidator class TestValidators(TestCase): def test_validate_good_uris(self): - good_uris = "http://example.com/ http://example.org/?key=val http://example" - # Check ValidationError not thrown - validate_uris(good_uris) + validator = RedirectURIValidator(allowed_schemes=["https"]) + good_uris = [ + "https://example.com/", + "https://example.org/?key=val", + "https://example", + "https://localhost", + "https://1.1.1.1", + "https://127.0.0.1", + "https://255.255.255.255", + ] + for uri in good_uris: + # Check ValidationError not thrown + validator(uri) def test_validate_custom_uri_scheme(self): - oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["my-scheme", "http"] - good_uris = "my-scheme://example.com http://example.com" - # Check ValidationError not thrown - validate_uris(good_uris) - - def test_validate_whitespace_separators(self): - # Check that whitespace can be used as a separator - good_uris = "http://example\r\nhttp://example\thttp://example" - # Check ValidationError not thrown - validate_uris(good_uris) + validator = RedirectURIValidator(allowed_schemes=["my-scheme", "https", "git+ssh"]) + good_uris = [ + "my-scheme://example.com", + "my-scheme://example", + "my-scheme://localhost", + "https://example.com", + "HTTPS://example.com", + "git+ssh://example.com", + ] + for uri in good_uris: + # Check ValidationError not thrown + validator(uri) def test_validate_bad_uris(self): - bad_uri = "http://example.com/#fragment" - self.assertRaises(ValidationError, validate_uris, bad_uri) - bad_uri = "http:/example.com" - self.assertRaises(ValidationError, validate_uris, bad_uri) - # Bad IPv6 URL, urlparse behaves differently for these - bad_uri = "https://[\">" - self.assertRaises(ValidationError, validate_uris, bad_uri) - bad_uri = "my-scheme://example.com" - self.assertRaises(ValidationError, validate_uris, bad_uri) - bad_uri = "sdklfsjlfjljdflksjlkfjsdkl" - self.assertRaises(ValidationError, validate_uris, bad_uri) - bad_uri = " " - self.assertRaises(ValidationError, validate_uris, bad_uri) - bad_uri = "" - self.assertRaises(ValidationError, validate_uris, bad_uri) + validator = RedirectURIValidator(allowed_schemes=["https"]) + oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https", "good"] + bad_uris = [ + "http:/example.com", + "HTTP://localhost", + "HTTP://example.com", + "HTTP://example.com.", + "http://example.com/#fragment", + "123://example.com", + "http://fe80::1", + "git+ssh://example.com", + "my-scheme://example.com", + "uri-without-a-scheme", + "https://example.com/#fragment", + "good://example.com/#fragment", + " ", + "", + # Bad IPv6 URL, urlparse behaves differently for these + 'https://[">', + ] + + for uri in bad_uris: + with self.assertRaises(ValidationError): + validator(uri) diff --git a/tests/utils.py b/tests/utils.py index 29bdb58..9e29c48 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,3 @@ -from __future__ import unicode_literals - import base64 diff --git a/tox.ini b/tox.ini index 1ec8f55..f1156d5 100644 --- a/tox.ini +++ b/tox.ini @@ -1,27 +1,28 @@ [tox] envlist = - py27-django{111}, - py35-django{111,20,master}, - py36-django{111,20,master}, - docs, - flake8 + py34-django20, + py35-django{20,21,master}, + py36-django{20,21,master}, + py37-django{20,21,master}, + py36-docs, + py36-flake8 [pytest] django_find_project = false [testenv] -commands = - pip install https://github.com/oauthlib/oauthlib/archive/master.tar.gz - pytest --cov=oauth2_provider --cov-report= --cov-append {posargs} +commands = + pytest --cov=oauth2_provider --cov-report= --cov-append {posargs} -s setenv = DJANGO_SETTINGS_MODULE = tests.settings PYTHONPATH = {toxinidir} PYTHONWARNINGS = all deps = - django111: Django>=1.11,<2.0 django20: Django>=2.0,<2.1 + django21: Django>=2.1,<2.2 djangomaster: https://github.com/django/django/archive/master.tar.gz - djangorestframework >=3.5 + djangorestframework + oauthlib>=3.0.1 coverage pytest pytest-cov @@ -30,22 +31,22 @@ deps = py27: mock jwcrypto -[testenv:docs] +[testenv:py36-docs] basepython = python changedir = docs whitelist_externals = make commands = make html deps = sphinx + oauthlib>=3.0.1 -[testenv:flake8] +[testenv:py36-flake8] skip_install = True commands = - flake8 {toxinidir} {posargs} - isort {toxinidir} -c + flake8 {toxinidir} deps = flake8 + flake8-isort flake8-quotes - isort [coverage:run] source = oauth2_provider @@ -58,9 +59,10 @@ application-import-names = oauth2_provider inline-quotes = double [isort] -lines_after_imports = 2 +balanced_wrapping = True +default_section = THIRDPARTY known_first_party = oauth2_provider +line_length = 80 +lines_after_imports = 2 multi_line_output = 5 skip = oauth2_provider/migrations/, .tox/ -line_length = 80 -balanced_wrapping = True