diff --git a/.gitignore b/.gitignore index c22ef00..3643335 100644 --- a/.gitignore +++ b/.gitignore @@ -25,6 +25,7 @@ __pycache__ pip-log.txt # Unit test / coverage reports +.cache .pytest_cache .coverage .tox diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 2aef56d..0000000 --- a/.travis.yml +++ /dev/null @@ -1,59 +0,0 @@ -# https://travis-ci.org/jazzband/django-oauth-toolkit -dist: bionic - -language: python - -cache: - directories: - - $HOME/.cache/pip - - $TRAVIS_BUILD_DIR/.tox - -# Make sure to coordinate changes to envlist in tox.ini. -matrix: - allow_failures: - - env: TOXENV=py36-djangomaster - - env: TOXENV=py37-djangomaster - - env: TOXENV=py38-djangomaster - - include: - - python: 3.7 - env: TOXENV=py37-flake8 - - python: 3.7 - env: TOXENV=py37-docs - - - python: 3.8 - env: TOXENV=py38-django30 - - python: 3.8 - env: TOXENV=py38-django22 - - python: 3.8 - env: TOXENV=py38-django21 - - python: 3.8 - env: TOXENV=py38-djangomaster - - - python: 3.7 - env: TOXENV=py37-django30 - - python: 3.7 - env: TOXENV=py37-django22 - - python: 3.7 - env: TOXENV=py37-django21 - - python: 3.7 - env: TOXENV=py37-djangomaster - - - python: 3.6 - env: TOXENV=py36-django22 - - python: 3.6 - env: TOXENV=py36-django21 - - - python: 3.5 - env: TOXENV=py35-django22 - - python: 3.5 - env: TOXENV=py35-django21 - -install: - - pip install coveralls tox tox-travis - -script: - - tox - -after_script: - - coveralls diff --git a/AUTHORS b/AUTHORS index cbcefa2..5058928 100644 --- a/AUTHORS +++ b/AUTHORS @@ -7,23 +7,50 @@ Federico Frenguelli Contributors ============ -Alessandro De Angelis +Abhishek Patel Alan Crosswell -Asif Saif Uddin -Ash Christopher +Aleksander Vaskevich +Alessandro De Angelis +Allisson Azevedo +Anvesh Agarwal Aristóbulo Meneses +Aryan Iyappan +Ash Christopher +Asif Saif Uddin Bart Merenda Bas van Oostveen +Dave Burkholder David Fischer +David Smith Diego Garcia +Dulmandakh Sukhbaatar +Dylan Giesler Emanuele Palazzetti Federico Dolce +Frederico Vieira +Hasan Ramezani Hiroki Kiyohara Jens Timmerman Jerome Leclanche Jim Graham +Jonas Nygaard Pedersen +Jonathan Steffan +Jun Zhou +Kristian Rune Larsen +Paul Dekkers Paul Oswald -pySilver +Pavel Tvrdík Rodney Richardson +Rustem Saiargaliev +Sandro Rodrigues +Shaun Stanworth Silvano Cerza +Spencer Carroll Stéphane Raimbault +Tom Evans +Will Beaufoy +Rustem Saiargaliev +Jadiel Teófilo +pySilver +Łukasz Skarżyński +Shaheed Haque diff --git a/CHANGELOG.md b/CHANGELOG.md index 400bc13..52876d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - -## [1.3.1] unreleased +## [unreleased] +* Remove support for Django 3.0 +* Add support for Django 3.2 + +### Added +* #712, #636, #808. Calls to `django.contrib.auth.authenticate()` now pass a `request` + to provide compatibility with backends that need one. + +### Fixed +* #524 Restrict usage of timezone aware expire dates to Django projects with USE_TZ set to True. +* #955 Avoid doubling of `oauth2_provider` urls mountpath in json response for OIDC view `ConnectDiscoveryInfoView`. + Breaks existing OIDC discovery output +* #953 Allow loopback redirect URIs with random ports using http scheme, localhost address and no explicit port + configuration in the allowed redirect_uris for Oauth2 Applications (RFC8252) + +## [2.2.0] 2021-05-10 +Aligned to [django-oauth-toolkit 1.5.0](https://github.com/jazzband/django-oauth-toolkit/pull/947) + +### Added +* #915 Add optional OpenID Connect support. + +### Changed +* #942 Help via defunct Google group replaced with using GitHub issues + +## [2.1.1] 2021-03-12 + +### Changed +* #925 OAuth2TokenMiddleware converted to new style middleware, and no longer extends MiddlewareMixin. + +### Removed +* #936 Remove support for Python 3.5 + +## [2.1.0] 2021-02-08 + +### Added +* #917 Documentation improvement for Access Token expiration. +* #916 (for DOT contributors) Added `tox -e livedocs` which launches a local web server on `locahost:8000` + to display Sphinx documentation with live updates as you edit. +* #891 (for DOT contributors) Added [details](https://django-oauth-toolkit.readthedocs.io/en/latest/contributing.html) + on how best to contribute to this project. +* #884 Added support for Python 3.9 +* #898 Added the ability to customize classes for django admin +* #690 Added pt-PT translations to HTML templates. This enables adding additional translations. + +### Fixed +* #906 Made token revocation not apply a limit to the `select_for_update` statement (impacts Oracle 12c database). +* #903 Disable `redirect_uri` field length limit for `AbstractGrant` + +## [1.3.3] 2020-10-16 + +### Added +* added `select_related` in intospect view for better query performance +* #831 Authorization token creation now can receive an expire date +* #831 Added a method to override Grant creation +* #825 Bump oauthlib to 3.1.0 to introduce PKCE +* Support for Django 3.1 + +### Fixed +* #847: Fix inappropriate message when response from authentication server is not OK. + +### Changed +* few smaller improvements to remove older django version compatibility #830, #861, #862, #863 + +## [1.3.2] 2020-03-24 + +### Fixed +* Fixes: 1.3.1 inadvertently uploaded to pypi with an extra migration (0003...) from a dev branch. + +## [1.3.1] 2020-03-23 + +### Added +* #725: HTTP Basic Auth support for introspection (Fix issue #709) + ### Fixed * #812: Reverts #643 pass wrong request object to authenticate function. * Fix concurrency issue with refresh token requests (#[810](https://github.com/jazzband/django-oauth-toolkit/pull/810)) * #817: Reverts #734 tutorial documentation error. + ## [1.3.0] 2020-03-02 ### Added diff --git a/README.rst b/README.rst index c547364..fe43598 100644 --- a/README.rst +++ b/README.rst @@ -7,17 +7,24 @@ Django OAuth Toolkit *OAuth2 goodies for the Djangonauts!* -.. image:: https://badge.fury.io/py/django-oauth-toolkit.png +.. image:: https://badge.fury.io/py/django-oauth-toolkit.svg :target: http://badge.fury.io/py/django-oauth-toolkit -.. image:: https://travis-ci.org/jazzband/django-oauth-toolkit.png - :alt: Build Status - :target: https://travis-ci.org/jazzband/django-oauth-toolkit +.. image:: https://github.com/jazzband/django-oauth-toolkit/workflows/Test/badge.svg + :target: https://github.com/jazzband/django-oauth-toolkit/actions + :alt: GitHub Actions -.. image:: https://coveralls.io/repos/github/jazzband/django-oauth-toolkit/badge.svg?branch=master - :alt: Coverage Status - :target: https://coveralls.io/github/jazzband/django-oauth-toolkit?branch=master +.. image:: https://codecov.io/gh/jazzband/django-oauth-toolkit/branch/master/graph/badge.svg + :target: https://codecov.io/gh/jazzband/django-oauth-toolkit + :alt: Coverage +.. image:: https://img.shields.io/pypi/pyversions/django-oauth-toolkit.svg + :target: https://pypi.org/project/django-oauth-toolkit/ + :alt: Supported Python versions + +.. image:: https://img.shields.io/pypi/djversions/django-oauth-toolkit.svg + :target: https://pypi.org/project/django-oauth-toolkit/ + :alt: Supported Django versions If you are facing one or more of the following: * Your Django app exposes a web API you want to protect with OAuth2 authentication, @@ -42,9 +49,9 @@ Please report any security issues to the JazzBand security team at v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -#html_static_path = ['_static'] +# html_static_path = ['_static'] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = 'DjangoOAuthToolkitdoc' +htmlhelp_basename = "DjangoOAuthToolkitdoc" # -- Options for LaTeX output -------------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', + # The paper size ('letterpaper' or 'a4paper'). + # 'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + # 'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + # 'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ - ('index', 'DjangoOAuthToolkit.tex', u'Django OAuth Toolkit Documentation', - u'Evonove', 'manual'), + ("index", "DjangoOAuthToolkit.tex", "Django OAuth Toolkit Documentation", "Evonove", "manual"), ] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output -------------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - ('index', 'djangooauthtoolkit', u'Django OAuth Toolkit Documentation', - [u'Evonove'], 1) -] +man_pages = [("index", "djangooauthtoolkit", "Django OAuth Toolkit Documentation", ["Evonove"], 1)] # If true, show URL addresses after external links. -#man_show_urls = False +# man_show_urls = False # -- Options for Texinfo output ------------------------------------------------ @@ -249,19 +251,25 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'DjangoOAuthToolkit', u'Django OAuth Toolkit Documentation', - u'Evonove', 'DjangoOAuthToolkit', 'One line description of project.', - 'Miscellaneous'), + ( + "index", + "DjangoOAuthToolkit", + "Django OAuth Toolkit Documentation", + "Evonove", + "DjangoOAuthToolkit", + "One line description of project.", + "Miscellaneous", + ), ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. -#texinfo_no_detailmenu = False +# texinfo_no_detailmenu = False diff --git a/docs/contributing.rst b/docs/contributing.rst index 021895e..c336d04 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -2,6 +2,13 @@ Contributing ============ +.. image:: https://jazzband.co/static/img/jazzband.svg + :target: https://jazzband.co/ + :alt: Jazzband + +This is a `Jazzband `_ project. By contributing you agree to abide by the `Contributor Code of Conduct `_ and follow the `guidelines `_. + + Setup ===== @@ -17,6 +24,78 @@ You can find the list of bugs, enhancements and feature requests on the `issue tracker `_. If you want to fix an issue, pick up one and add a comment stating you're working on it. +Code Style +========== + +The project uses `flake8 `_ for linting, +`black `_ for formatting the code, +`isort `_ for formatting and sorting imports, +and `pre-commit `_ for checking/fixing commits for +correctness before they are made. + +You will need to install ``pre-commit`` yourself, and then ``pre-commit`` will +take care of installing ``flake8``, ``black`` and ``isort``. + +After cloning your repository, go into it and run:: + + pre-commit install + +to install the hooks. On the next commit that you make, ``pre-commit`` will +download and install the necessary hooks (a one off task). If anything in the +commit would fail the hooks, the commit will be abandoned. For ``black`` and +``isort``, any necessary changes will be made automatically, but not staged. +Review the changes, and then re-stage and commit again. + +Using ``pre-commit`` ensures that code that would fail in QA does not make it +into a commit in the first place, and will save you time in the long run. You +can also (largely) stop worrying about code style, although you should always +check how the code looks after ``black`` has formatted it, and think if there +is a better way to structure the code so that it is more readable. + +Documentation +============= + +You can edit the documentation by editing files in ``docs/``. This project +uses sphinx to turn ``ReStructuredText`` into the HTML docs you are reading. + +In order to build the docs in to HTML, you can run:: + + tox -e docs + +This will build the docs, and place the result in ``docs/_build/html``. +Alternatively, you can run:: + + tox -e livedocs + +This will run ``sphinx`` in a live reload mode, so any changes that you make to +the ``RST`` files will be automatically detected and the HTML files rebuilt. +It will also run a simple HTTP server available at ``_ +serving the HTML files, and auto-reload the page when changes are made. + +This allows you to edit the docs and see your changes instantly reflected in +the browser. + +* `ReStructuredText primer + `_ + +Translations +============ + +You can contribute international language translations using +`django-admin makemessages `_. + +For example, to add Deutsch:: + + cd oauth2_provider + django-admin makemessages --locale de + +Then edit ``locale/de/LC_MESSAGES/django.po`` to add your translations. + +When deploying your app, don't forget to compile the messages with:: + + django-admin compilemessages + + Pull requests ============= @@ -49,7 +128,7 @@ When you begin your PR, you'll be asked to provide the following: * Any new or changed code requires that a unit test be added or updated. Make sure your tests check for correct error behavior as well as normal expected behavior. Strive for 100% code coverage of any new code you contribute! Improving unit tests is always a welcome contribution. - If your change reduces coverage, you'll be warned by `coveralls `_. + If your change reduces coverage, you'll be warned by `Codecov `_. * Update the documentation (in `docs/`) to describe the new or changed functionality. @@ -70,7 +149,7 @@ When you begin your PR, you'll be asked to provide the following: JazzBand security team ``. Do not file an issue on the tracker or submit a PR until directed to do so.) -* Make sure your name is in `AUTHORS`. +* Make sure your name is in `AUTHORS`. We want to give credit to all contrbutors! If your PR is not yet ready to be merged mark it as a Work-in-Progress By prepending `WIP:` to the PR title so that it doesn't get inadvertently approved and merged. @@ -106,6 +185,29 @@ How to get your pull request accepted We really want your code, so please follow these simple guidelines to make the process as smooth as possible. +The Checklist +------------- + +A checklist template is automatically added to your PR when you create it. Make sure you've done all the +applicable steps and check them off to indicate you have done so. This is +what you'll see when creating your PR: + + Fixes # + + ## Description of the Change + + ## Checklist + + - [ ] PR only contains one change (considered splitting up PR) + - [ ] unit-test added + - [ ] documentation updated + - [ ] `CHANGELOG.md` updated (only for user relevant changes) + - [ ] author name in `AUTHORS` + +Any PRs that are missing checklist items will not be merged and may be reverted if they are merged by +mistake. + + Run the tests! -------------- @@ -132,7 +234,7 @@ You can check your coverage locally with the `coverage `_ + +Maintainer Checklist +==================== +The following notes are to remind the project maintainers and leads of the steps required to +review and merge PRs and to publish a new release. + +Reviewing and Merging PRs +------------------------ + +- Make sure the PR description includes the `pull request template + `_ +- Confirm that all required checklist items from the PR template are both indicated as done in the + PR description and are actually done. +- Perform a careful review and ask for any needed changes. +- Make sure any PRs only ever improve code coverage percentage. +- All PRs should be be reviewed by one individual (not the submitter) and merged by another. + +PRs that are incorrectly merged may (reluctantly) be reverted by the Project Leads. + + +Publishing a Release +-------------------- + +Only Project Leads can publish a release to pypi.org and rtfd.io. This checklist is a reminder +of steps. + +- When planning a new release, create a `milestone + `_ + and assign issues, PRs, etc. to that milestone. +- Review all commits since the last release and confirm that they are properly + documented in the CHANGELOG. (Unfortunately, this has not always been the case + so you may be stuck documenting things that should have been documented as part of their PRs.) +- Make a final PR for the release that updates: + + - CHANGELOG to show the release date. + - setup.cfg to set `version = ...` + +- Once the final PR is committed push the new release to pypi and rtfd.io. diff --git a/docs/getting_started.rst b/docs/getting_started.rst new file mode 100644 index 0000000..427195a --- /dev/null +++ b/docs/getting_started.rst @@ -0,0 +1,394 @@ +Getting started +=============== + +Build a OAuth2 provider using Django, Django OAuth Toolkit, and OAuthLib. + +What we will build? +------------------- + +The plan is to build an OAuth2 provider from ground up. + +On this getting started we will: + +* Create the Django project. +* Install and configure Django OAuth Toolkit. +* Create two OAuth2 applications. +* Use Authorization code grant flow. +* Use Client Credential grant flow. + +What is OAuth? +---------------- + +OAuth is an open standard for access delegation, commonly used as a way for Internet users to grant websites or applications access to their information on other websites but without giving them the passwords. +-- `Whitson Gordon`_ + +Django +------ + +Django is a high-level Python Web framework that encourages rapid development and clean, pragmatic design. Built by experienced developers, it takes care of much of the hassle of Web development, so you can focus on writing your app without needing to reinvent the wheel. +-- `Django website`_ + +Let's get start by creating a virtual environment:: + + mkproject iam + +This will create, activate and change directory to the new Python virtual environment. + +Install Django:: + + pip install Django + +Create a Django project:: + + django-admin startproject iam + +This will create a mysite directory in your current directory. With the following estructure:: + + . + └── iam + ├── iam + │   ├── asgi.py + │   ├── __init__.py + │   ├── settings.py + │   ├── urls.py + │   └── wsgi.py + └── manage.py + +Create a Django application:: + + cd iam/ + python manage.py startapp users + +That’ll create a directory :file:`users`, which is laid out like this:: + + . + ├── iam + │   ├── asgi.py + │   ├── __init__.py + │   ├── settings.py + │   ├── urls.py + │   └── wsgi.py + ├── manage.py + └── users + ├── admin.py + ├── apps.py + ├── __init__.py + ├── migrations + │   └── __init__.py + ├── models.py + ├── tests.py + └── views.py + +If you’re starting a new project, it’s highly recommended to set up a custom user model, even if the default `User`_ model is sufficient for you. This model behaves identically to the default user model, but you’ll be able to customize it in the future if the need arises. +-- `Django documentation`_ + +Edit :file:`users/models.py` adding the code below: + +.. code-block:: python + + from django.contrib.auth.models import AbstractUser + + class User(AbstractUser): + pass + +Change :file:`iam/settings.py` to add ``users`` application to ``INSTALLED_APPS``: + +.. code-block:: python + + INSTALLED_APPS = [ + 'django.contrib.admin', + 'django.contrib.auth', + 'django.contrib.contenttypes', + 'django.contrib.sessions', + 'django.contrib.messages', + 'django.contrib.staticfiles', + 'users', + ] + +Configure ``users.User`` to be the model used for the ``auth`` application by adding ``AUTH_USER_MODEL`` to :file:`iam/settings.py`: + +.. code-block:: python + + AUTH_USER_MODEL='users.User' + +Create inital migration for ``users`` application ``User`` model:: + + python manage.py makemigrations + +The command above will create the migration:: + + Migrations for 'users': + users/migrations/0001_initial.py + - Create model User + +Finally execute the migration:: + + python manage.py migrate + +The ``migrate`` output:: + + Operations to perform: + Apply all migrations: admin, auth, contenttypes, sessions, users + Running migrations: + Applying contenttypes.0001_initial... OK + Applying contenttypes.0002_remove_content_type_name... OK + Applying auth.0001_initial... OK + Applying auth.0002_alter_permission_name_max_length... OK + Applying auth.0003_alter_user_email_max_length... OK + Applying auth.0004_alter_user_username_opts... OK + Applying auth.0005_alter_user_last_login_null... OK + Applying auth.0006_require_contenttypes_0002... OK + Applying auth.0007_alter_validators_add_error_messages... OK + Applying auth.0008_alter_user_username_max_length... OK + Applying auth.0009_alter_user_last_name_max_length... OK + Applying auth.0010_alter_group_name_max_length... OK + Applying auth.0011_update_proxy_permissions... OK + Applying users.0001_initial... OK + Applying admin.0001_initial... OK + Applying admin.0002_logentry_remove_auto_add... OK + Applying admin.0003_logentry_add_action_flag_choices... OK + Applying sessions.0001_initial... OK + +Django OAuth Toolkit +-------------------- + +Django OAuth Toolkit can help you by providing, out of the box, all the endpoints, data, and logic needed to add OAuth2 capabilities to your Django projects. + +Install Django OAuth Toolkit:: + + pip install django-oauth-toolkit + +Add ``oauth2_provider`` to ``INSTALLED_APPS`` in :file:`iam/settings.py`: + +.. code-block:: python + + INSTALLED_APPS = [ + 'django.contrib.admin', + 'django.contrib.auth', + 'django.contrib.contenttypes', + 'django.contrib.sessions', + 'django.contrib.messages', + 'django.contrib.staticfiles', + 'users', + 'oauth2_provider', + ] + +Execute the migration:: + + python manage.py migrate + +The ``migrate`` command output:: + + Operations to perform: + Apply all migrations: admin, auth, contenttypes, oauth2_provider, sessions, users + Running migrations: + Applying oauth2_provider.0001_initial... OK + Applying oauth2_provider.0002_auto_20190406_1805... OK + +Include ``oauth2_provider.urls`` to :file:`iam/urls.py` as follows: + +.. code-block:: python + + from django.contrib import admin + from django.urls import include, path + + urlpatterns = [ + path('admin/', admin.site.urls), + path('o/', include('oauth2_provider.urls', namespace='oauth2_provider')), + ] + +This will make available endpoints to authorize, generate token and create OAuth applications. + +Last change, add ``LOGIN_URL`` to :file:`iam/settings.py`: + +.. code-block:: python + + LOGIN_URL='/admin/login/' + +We will use Django Admin login to make our life easy. + +Create a user:: + + python manage.py createsuperuser + + Username: wiliam + Email address: me@wiliam.dev + Password: + Password (again): + Superuser created successfully. + +OAuth2 Authorization Grants +--------------------------- + +An authorization grant is a credential representing the resource owner's authorization (to access its protected resources) used by the client to obtain an access token. +-- `RFC6749`_ + +The OAuth framework specifies several grant types for different use cases. +-- `Grant types`_ + +We will start by given a try to the grant types listed below: + +* Authorization code +* Client credential + +These two grant types cover the most initially used use cases. + +Authorization Code +------------------ + +The Authorization Code flow is best used in web and mobile apps. This is the flow used for third party integration, the user authorizes your partner to access its products in your APIs. + +Start the development server:: + + python manage.py runserver + +Point your browser to http://127.0.0.1:8000/o/applications/register/ lets create an application. + +Fill the form as show in the screenshot bellow and before save take note of ``Client id`` and ``Client secret`` we will use it in a minute. + +.. image:: _images/application-register-auth-code.png + :alt: Authorization code application registration + +Export ``Client id`` and ``Client secret`` values as environment variable: + +.. sourcecode:: sh + + export ID=vW1RcAl7Mb0d5gyHNQIAcH110lWoOW2BmWJIero8 + export SECRET=DZFpuNjRdt5xUEzxXovAp40bU3lQvoMvF3awEStn61RXWE0Ses4RgzHWKJKTvUCHfRkhcBi3ebsEfSjfEO96vo2Sh6pZlxJ6f7KcUbhvqMMPoVxRwv4vfdWEoWMGPeIO + +To start the Authorization code flow go to this `URL`_ which is the same as shown below:: + + http://127.0.0.1:8000/o/authorize/?response_type=code&client_id=vW1RcAl7Mb0d5gyHNQIAcH110lWoOW2BmWJIero8&redirect_uri=http://127.0.0.1:8000/noexist/callback + +Note the parameters we pass: + +* **response_type**: ``code`` +* **client_id**: ``vW1RcAl7Mb0d5gyHNQIAcH110lWoOW2BmWJIero8`` +* **redirect_uri**: ``http://127.0.0.1:8000/noexist/callback`` + +This identifies your application, the user is asked to authorize your application to access its resources. + +Go ahead and authorize the ``web-app`` + +.. image:: _images/application-authorize-web-app.png + :alt: Authorization code authorize web-app + +Remember we used ``http://127.0.0.1:8000/noexist/callback`` as ``redirect_uri`` you will get a **Page not found (404)** but it worked if you get a url like:: + + http://127.0.0.1:8000/noexist/callback?code=uVqLxiHDKIirldDZQfSnDsmYW1Abj2 + +This is the OAuth2 provider trying to give you a ``code``. in this case ``uVqLxiHDKIirldDZQfSnDsmYW1Abj2``. + +Export it as an environment variable: + +.. code-block:: sh + + export CODE=uVqLxiHDKIirldDZQfSnDsmYW1Abj2 + +Now that you have the user authorization is time to get an access token:: + + curl -X POST -H "Cache-Control: no-cache" -H "Content-Type: application/x-www-form-urlencoded" "http://127.0.0.1:8000/o/token/" -d "client_id=${ID}" -d "client_secret=${SECRET}" -d "code=${CODE}" -d "redirect_uri=http://127.0.0.1:8000/noexist/callback" -d "grant_type=authorization_code" + +To be more easy to visualize:: + + curl -X POST \ + -H "Cache-Control: no-cache" \ + -H "Content-Type: application/x-www-form-urlencoded" \ + "http://127.0.0.1:8000/o/token/" \ + -d "client_id=${ID}" \ + -d "client_secret=${SECRET}" \ + -d "code=${CODE}" \ + -d "redirect_uri=http://127.0.0.1:8000/noexist/callback" \ + -d "grant_type=authorization_code" + +The OAuth2 provider will return the follow response: + +.. code-block:: javascript + + { + "access_token": "jooqrnOrNa0BrNWlg68u9sl6SkdFZg", + "expires_in": 36000, + "token_type": "Bearer", + "scope": "read write", + "refresh_token": "HNvDQjjsnvDySaK0miwG4lttJEl9yD" + } + +To access the user resources we just use the ``access_token``:: + + curl \ + -H "Authorization: Bearer jooqrnOrNa0BrNWlg68u9sl6SkdFZg" \ + -X GET http://localhost:8000/resource + +Client Credential +----------------- + +The Client Credential grant is suitable for machine-to-machine authentication. You authorize your own service or worker to change a bank account transaction status to accepted. + +Point your browser to http://127.0.0.1:8000/o/applications/register/ lets create an application. + +Fill the form as show in the screenshot below, and before saving take note of ``Client id`` and ``Client secret`` we will use it in a minute. + +.. image:: _images/application-register-client-credential.png + :alt: Client credential application registration + +Export ``Client id`` and ``Client secret`` values as environment variable: + +.. code-block:: sh + + export ID=axXSSBVuvOyGVzh4PurvKaq5MHXMm7FtrHgDMi4u + export SECRET=1fuv5WVfR7A5BlF0o155H7s5bLgXlwWLhi3Y7pdJ9aJuCdl0XV5Cxgd0tri7nSzC80qyrovh8qFXFHgFAAc0ldPNn5ZYLanxSm1SI1rxlRrWUP591wpHDGa3pSpB6dCZ + +The Client Credential flow is simpler than the Authorization Code flow. + +We need to encode ``client_id`` and ``client_secret`` as HTTP base authentication encoded in ``base64`` I use the following code to do that. + +.. code-block:: python + + >>> import base64 + >>> client_id = "axXSSBVuvOyGVzh4PurvKaq5MHXMm7FtrHgDMi4u" + >>> secret = "1fuv5WVfR7A5BlF0o155H7s5bLgXlwWLhi3Y7pdJ9aJuCdl0XV5Cxgd0tri7nSzC80qyrovh8qFXFHgFAAc0ldPNn5ZYLanxSm1SI1rxlRrWUP591wpHDGa3pSpB6dCZ" + >>> credential = "{0}:{1}".format(client_id, secret) + >>> base64.b64encode(credential.encode("utf-8")) + b'YXhYU1NCVnV2T3lHVnpoNFB1cnZLYXE1TUhYTW03RnRySGdETWk0dToxZnV2NVdWZlI3QTVCbEYwbzE1NUg3czViTGdYbHdXTGhpM1k3cGRKOWFKdUNkbDBYVjVDeGdkMHRyaTduU3pDODBxeXJvdmg4cUZYRkhnRkFBYzBsZFBObjVaWUxhbnhTbTFTSTFyeGxScldVUDU5MXdwSERHYTNwU3BCNmRDWg==' + >>> + +Export the credential as an environment variable + +.. code-block:: sh + + export CREDENTIAL=YXhYU1NCVnV2T3lHVnpoNFB1cnZLYXE1TUhYTW03RnRySGdETWk0dToxZnV2NVdWZlI3QTVCbEYwbzE1NUg3czViTGdYbHdXTGhpM1k3cGRKOWFKdUNkbDBYVjVDeGdkMHRyaTduU3pDODBxeXJvdmg4cUZYRkhnRkFBYzBsZFBObjVaWUxhbnhTbTFTSTFyeGxScldVUDU5MXdwSERHYTNwU3BCNmRDWg== + +To start the Client Credential flow you call ``/token/`` endpoint direct:: + + curl -X POST -H "Authorization: Basic ${CREDENTIAL}" -H "Cache-Control: no-cache" -H "Content-Type: application/x-www-form-urlencoded" "http://127.0.0.1:8000/o/token/" -d "grant_type=client_credentials" + +To be easier to visualize:: + + curl -X POST \ + -H "Authorization: Basic ${CREDENTIAL}" \ + -H "Cache-Control: no-cache" \ + -H "Content-Type: application/x-www-form-urlencoded" \ + "http://127.0.0.1:8000/o/token/" \ + -d "grant_type=client_credentials" + +The OAuth2 provider will return the following response: + +.. code-block:: javascript + + { + "access_token": "PaZDOD5UwzbGOFsQr34LQ7JUYOj3yK", + "expires_in": 36000, + "token_type": "Bearer", + "scope": "read write" + } + +Next step is :doc:`first tutorial `. + +.. _Django website: https://www.djangoproject.com/ +.. _Whitson Gordon: https://en.wikipedia.org/wiki/OAuth#cite_note-1 +.. _User: https://docs.djangoproject.com/en/3.0/ref/contrib/auth/#django.contrib.auth.models.User +.. _Django documentation: https://docs.djangoproject.com/en/3.0/topics/auth/customizing/#using-a-custom-user-model-when-starting-a-project +.. _RFC6749: https://tools.ietf.org/html/rfc6749#section-1.3 +.. _Grant Types: https://oauth.net/2/grant-types/ +.. _URL: http://127.0.0.1:8000/o/authorize/?response_type=code&client_id=vW1RcAl7Mb0d5gyHNQIAcH110lWoOW2BmWJIero8&redirect_uri=http://127.0.0.1:8000/noexist/callback + diff --git a/docs/index.rst b/docs/index.rst index 8716eb9..d2d4e8c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -6,7 +6,7 @@ Welcome to Django OAuth Toolkit Documentation ============================================= -Django OAuth Toolkit can help you providing out of the box all the endpoints, data and logic needed to add OAuth2 +Django OAuth Toolkit can help you by providing, out of the box, all the endpoints, data, and logic needed to add OAuth2 capabilities to your Django projects. Django OAuth Toolkit makes extensive use of the excellent `OAuthLib `_, so that everything is `rfc-compliant `_. @@ -16,14 +16,14 @@ See our :doc:`Changelog ` for information on updates. Support ------- -If you need support please send a message to the `Django OAuth Toolkit Google Group `_ +If you need help please submit a `question `_. Requirements ------------ -* Python 3.5+ -* Django 2.1+ -* oauthlib 3.0+ +* Python 3.6+ +* Django 2.2+ +* oauthlib 3.1+ Index ===== @@ -32,6 +32,7 @@ Index :maxdepth: 2 install + getting_started tutorial/tutorial rest-framework/rest-framework views/views @@ -39,6 +40,7 @@ Index views/details models advanced_topics + oidc signals settings resource_server diff --git a/docs/install.rst b/docs/install.rst index ccff177..65dcb1d 100644 --- a/docs/install.rst +++ b/docs/install.rst @@ -19,11 +19,22 @@ If you need an OAuth2 provider you'll want to add the following to your urls.py .. code-block:: python + from django.urls import include, path + urlpatterns = [ ... path('o/', include('oauth2_provider.urls', namespace='oauth2_provider')), + ] + +Or using `re_path()` + +.. code-block:: python + + from django.urls import include, re_path + + urlpatterns = [ + ... - # using re_path re_path(r'^o/', include('oauth2_provider.urls', namespace='oauth2_provider')), ] @@ -34,4 +45,5 @@ Sync your database $ python manage.py migrate oauth2_provider -Next step is our :doc:`first tutorial `. +Next step is :doc:`getting started ` or :doc:`first tutorial `. + diff --git a/docs/models.rst b/docs/models.rst index 8fcbdc5..1e2657c 100644 --- a/docs/models.rst +++ b/docs/models.rst @@ -1,5 +1,5 @@ -`Models` -======== +Models +====== .. automodule:: oauth2_provider.models :members: diff --git a/docs/oidc.rst b/docs/oidc.rst new file mode 100644 index 0000000..87fadce --- /dev/null +++ b/docs/oidc.rst @@ -0,0 +1,308 @@ +OpenID Connect +++++++++++++++ + +OpenID Connect support +====================== + +``django-oauth-toolkit`` supports OpenID Connect (OIDC), which standardizes +authentication flows and provides a plug and play integration with other +systems. OIDC is built on top of OAuth 2.0 to provide: + +* Generating ID tokens as part of the login process. These are JWT that + describe the user, and can be used to authenticate them to your application. +* Metadata based auto-configuration for providers +* A user info endpoint, which applications can query to get more information + about a user. + +Enabling OIDC doesn't affect your existing OAuth 2.0 flows, these will +continue to work alongside OIDC. + +We support: + +* OpenID Connect Authorization Code Flow +* OpenID Connect Implicit Flow +* OpenID Connect Hybrid Flow + + +Configuration +============= + +OIDC is not enabled by default because it requires additional configuration +that must be provided. ``django-oauth-toolkit`` supports two different +algorithms for signing JWT tokens, ``RS256``, which uses asymmetric RSA keys (a +public key and a private key), and ``HS256``, which uses a symmetric key. + +It is preferrable to use ``RS256``, because this produces a token that can be +verified by anyone using the public key (which is made available and +discoverable by OIDC service auto-discovery, included with +``django-oauth-toolkit``). ``HS256`` on the other hand uses the +``client_secret`` in order to verify keys. This is simpler to implement, but +makes it harder to safely verify tokens. + +Using ``HS256`` also means that you cannot use the Implicit or Hybrid flows, +or verify the tokens in public clients, because you cannot disclose the +``client_secret`` to a public client. If you are using a public client, you +must use ``RS256``. + + +Creating RSA private key +~~~~~~~~~~~~~~~~~~~~~~~~ + +To use ``RS256`` requires an RSA private key, which is used for signing JWT. You +can generate this using the `openssl`_ tool:: + + openssl genrsa -out oidc.key 4096 + +This will generate a 4096-bit RSA key, which will be sufficient for our needs. + +.. _openssl: https://www.openssl.org + +.. warning:: + The contents of this key *must* be kept a secret. Don't put it in your + settings and commit it to version control! + + If the key is ever accidentally disclosed, an attacker could use it to + forge JWT tokens that verify as issued by your OAuth provider, which is + very bad! + + If it is ever disclosed, you should immediately replace the key. + + Safe ways to handle it would be: + + * Store it in a secure system like `Hashicorp Vault`_, and inject it in to + your environment when running your server. + * Store it in a secure file on your server, and use your initialization + scripts to inject it in to your environment. + +.. _Hashicorp Vault: https://www.hashicorp.com/products/vault + +Now we need to add this key to our settings and allow the ``openid`` scope to +be used. Assuming we have set an environment variable called +``OIDC_RSA_PRIVATE_KEY``, we can make changes to our ``settings.py``:: + + import os.environ + + OAUTH2_PROVIDER = { + "OIDC_ENABLED": True, + "OIDC_RSA_PRIVATE_KEY": os.environ.get("OIDC_RSA_PRIVATE_KEY"), + "SCOPES": { + "openid": "OpenID Connect scope", + # ... any other scopes that you use + }, + # ... any other settings you want + } + +If you are adding OIDC support to an existing OAuth 2.0 provider site, and you +are currently using a custom class for ``OAUTH2_SERVER_CLASS``, you must +change this class to derive from ``oauthlib.openid.Server`` instead of +``oauthlib.oauth2.Server``. + +With ``RSA`` key-pairs, the public key can be generated from the private key, +so there is no need to add a setting for the public key. + +Using ``HS256`` keys +~~~~~~~~~~~~~~~~~~~~ + +If you would prefer to use just ``HS256`` keys, you don't need to create any +additional keys, ``django-oauth-toolkit`` will just use the application's +``client_secret`` to sign the JWT token. + +In this case, you just need to enable OIDC and add ``openid`` to your list of +scopes in your ``settings.py``:: + + OAUTH2_PROVIDER = { + "OIDC_ENABLED": True, + "SCOPES": { + "openid": "OpenID Connect scope", + # ... any other scopes that you use + }, + # ... any other settings you want + } + +.. info:: + If you want to enable ``RS256`` at a later date, you can do so - just add + the private key as described above. + +Setting up OIDC enabled clients +=============================== + +Setting up an OIDC client in ``django-oauth-toolkit`` is simple - in fact, all +existing OAuth 2.0 Authorization Code Flow and Implicit Flow applications that +are already configured can be easily updated to use OIDC by setting the +appropriate algorithm for them to use. + +You can also switch existing apps to use OIDC Hybrid Flow by changing their +Authorization Grant Type and selecting a signing algorithm to use. + +You can read about the pros and cons of the different flows in `this excellent +article`_ from Robert Broeckelmann. + +.. _this excellent article: https://medium.com/@robert.broeckelmann/when-to-use-which-oauth2-grants-and-oidc-flows-ec6a5c00d864 + +OIDC Authorization Code Flow +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To create an OIDC Authorization Code Flow client, create an ``Application`` +with the grant type ``Authorization code`` and select your desired signing +algorithm. + +When making an authorization request, be sure to include ``openid`` as a +scope. When the code is exchanged for the access token, the response will +also contain an ID token JWT. + +If the ``openid`` scope is not requested, authorization requests will be +treated as standard OAuth 2.0 Authorization Code Grant requests. + +With ``PKCE`` enabled, even public clients can use this flow, and it is the most +secure and recommended flow. + +OIDC Implicit Flow +~~~~~~~~~~~~~~~~~~ + +OIDC Implicit Flow is very similar to OAuth 2.0 Implicit Grant, except that +the client can request a ``response_type`` of ``id_token`` or ``id_token +token``. Requesting just ``token`` is also possible, but it would make it not +an OIDC flow and would fall back to being the same as OAuth 2.0 Implicit +Grant. + +To setup an OIDC Implicit Flow client, simply create an ``Application`` with +the a grant type of ``Implicit`` and select your desired signing algorithm, +and configure the client to request the ``openid`` scope and an OIDC +``response_type`` (``id_token`` or ``id_token token``). + + +OIDC Hybrid Flow +~~~~~~~~~~~~~~~~ + +OIDC Hybrid Flow is a mixture of the previous two flows. It allows the ID +token and an access token to be returned to the frontend, whilst also +allowing the backend to retrieve the ID token and an access token (not +necessarily the same access token) on the backend. + +To setup an OIDC Hybrid Flow application, create an ``Application`` with a +grant type of ``OpenID connect hybrid`` and select your desired signing +algorithm. + + +Customizing the OIDC responses +============================== + +This basic configuration will give you a basic working OIDC setup, but your +ID tokens will have very few claims in them, and the ``UserInfo`` service will +just return the same claims as the ID token. + +To configure all of these things we need to customize the +``OAUTH2_VALIDATOR_CLASS`` in ``django-oauth-toolkit``. Create a new file in +our project, eg ``my_project/oauth_validator.py``:: + + from oauth2_provider.oauth2_validators import OAuth2Validator + + + class CustomOAuth2Validator(OAuth2Validator): + pass + + +and then configure our site to use this in our ``settings.py``:: + + OAUTH2_PROVIDER = { + "OAUTH2_VALIDATOR_CLASS": "my_project.oauth_validators.CustomOAuth2Validator", + # ... other settings + } + +Now we can customize the tokens and the responses that are produced by adding +methods to our custom validator. + + +Adding claims to the ID token +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +By default the ID token will just have a ``sub`` claim (in addition to the +required claims, eg ``iss``, ``aud``, ``exp``, ``iat``, ``auth_time`` etc), +and the ``sub`` claim will use the primary key of the user as the value. +You'll probably want to customize this and add additional claims or change +what is sent for the ``sub`` claim. To do so, you will need to add a method to +our custom validator:: + + class CustomOAuth2Validator(OAuth2Validator): + + def get_additional_claims(self, request): + return { + "sub": request.user.email, + "first_name": request.user.first_name, + "last_name": request.user.last_name, + } + +.. note:: + This ``request`` object is not a ``django.http.Request`` object, but an + ``oauthlib.common.Request`` object. This has a number of attributes that + you can use to decide what claims to put in to the ID token: + + * ``request.scopes`` - a list of the scopes requested by the client when + making an authorization request. + * ``request.claims`` - a dictionary of the requested claims, using the + `OIDC claims requesting system`_. These must be requested by the client + when making an authorization request. + * ``request.user`` - the django user object. + +.. _OIDC claims requesting system: https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter + +What claims you decide to put in to the token is up to you to determine based +upon what the scopes and / or claims means to your provider. + + +Adding information to the ``UserInfo`` service +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The ``UserInfo`` service is supplied as part of the OIDC service, and is used +to retrieve more information about the user than was supplied in the ID token +when the user logged in to the OIDC client application. It is optional to use +the service. The service is accessed by making a request to the +``UserInfo`` endpoint, eg ``/o/userinfo/`` and supplying the access token +retrieved at login as a ``Bearer`` token. + +Again, to modify the content delivered, we need to add a function to our +custom validator. The default implementation adds the claims from the ID +token, so you will probably want to re-use that:: + + class CustomOAuth2Validator(OAuth2Validator): + + def get_userinfo_claims(self, request): + claims = super().get_userinfo_claims(request) + claims["color_scheme"] = get_color_scheme(request.user) + return claims + + +OIDC Views +========== + +Enabling OIDC support adds three views to ``django-oauth-toolkit``. When OIDC +is not enabled, these views will log that OIDC support is not enabled, and +return a ``404`` response, or if ``DEBUG`` is enabled, raise an +``ImproperlyConfigured`` exception. + +In the docs below, it assumes that you have mounted the +``django-oauth-toolkit`` at ``/o/``. If you have mounted it elsewhere, adjust +the URLs accordingly. + + +ConnectDiscoveryInfoView +~~~~~~~~~~~~~~~~~~~~~~~~ + +Available at ``/o/.well-known/openid-configuration/``, this view provides auto +discovery information to OIDC clients, telling them the JWT issuer to use, the +location of the JWKs to verify JWTs with, the token and userinfo endpoints to +query, and other details. + + +JwksInfoView +~~~~~~~~~~~~ + +Available at ``/o/.well-known/jwks.json``, this view provides details of the key used to sign +the JWTs generated for ID tokens, so that clients are able to verify them. + + +UserInfoView +~~~~~~~~~~~~ + +Available at ``/o/userinfo/``, this view provides extra user details. You can +customize the details included in the response as described above. diff --git a/docs/requirements.txt b/docs/requirements.txt index 63d8276..c1f7269 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,4 @@ Django>=3.0,<3.1 -oauthlib>=3.0.1 +oauthlib>=3.1.0 m2r>=0.2.1 . diff --git a/docs/settings.rst b/docs/settings.rst index d0bc62e..de7bcf8 100644 --- a/docs/settings.rst +++ b/docs/settings.rst @@ -31,7 +31,7 @@ ACCESS_TOKEN_EXPIRE_SECONDS ~~~~~~~~~~~~~~~~~~~~~~~~~~~ The number of seconds an access token remains valid. Requesting a protected resource after this duration will fail. Keep this value high enough so clients -can cache the token for a reasonable amount of time. +can cache the token for a reasonable amount of time. (default: 36000) ACCESS_TOKEN_MODEL ~~~~~~~~~~~~~~~~~~ @@ -52,6 +52,11 @@ Default: ``["http", "https"]`` A list of schemes that the ``redirect_uri`` field will be validated against. Setting this to ``["https"]`` only in production is strongly recommended. +For Native Apps the ``http`` scheme can be safely used with loopback addresses in the +Application (``[::1]`` or ``127.0.0.1``). In this case the ``redirect_uri`` can be +configured without explicit port specification, so that the Application accepts randomly +assigned ports. + Note that you may override ``Application.get_allowed_schemes()`` to set this on a per-application basis. @@ -97,10 +102,36 @@ The import string of the class (model) representing your grants. Overwrite this value if you wrote your own implementation (subclass of ``oauth2_provider.models.Grant``). +APPLICATION_ADMIN_CLASS +~~~~~~~~~~~~~~~~~ +The import string of the class (model) representing your application admin class. +Overwrite this value if you wrote your own implementation (subclass of +``oauth2_provider.admin.ApplicationAdmin``). + +ACCESS_TOKEN_ADMIN_CLASS +~~~~~~~~~~~~~~~~~ +The import string of the class (model) representing your access token admin class. +Overwrite this value if you wrote your own implementation (subclass of +``oauth2_provider.admin.AccessTokenAdmin``). + +GRANT_ADMIN_CLASS +~~~~~~~~~~~~~~~~~ +The import string of the class (model) representing your grant admin class. +Overwrite this value if you wrote your own implementation (subclass of +``oauth2_provider.admin.GrantAdmin``). + +REFRESH_TOKEN_ADMIN_CLASS +~~~~~~~~~~~~~~~~~ +The import string of the class (model) representing your refresh token admin class. +Overwrite this value if you wrote your own implementation (subclass of +``oauth2_provider.admin.RefreshTokenAdmin``). + OAUTH2_SERVER_CLASS ~~~~~~~~~~~~~~~~~~~ The import string for the ``server_class`` (or ``oauthlib.oauth2.Server`` subclass) -used in the ``OAuthLibMixin`` that implements OAuth2 grant types. +used in the ``OAuthLibMixin`` that implements OAuth2 grant types. It defaults +to ``oauthlib.oauth2.Server``, except when OIDC support is enabled, when the +default is ``oauthlib.openid.Server``. OAUTH2_VALIDATOR_CLASS ~~~~~~~~~~~~~~~~~~~~~~ @@ -118,7 +149,7 @@ The number of seconds before a refresh token gets removed from the database by the ``cleartokens`` management command. Check :ref:`cleartokens` management command for further info. NOTE: This value is completely ignored when validating refresh tokens. If you don't change the validator code and don't run cleartokens all refresh -tokens will last until revoked or the end of time. +tokens will last until revoked or the end of time. You should change this. REFRESH_TOKEN_GRACE_PERIOD_SECONDS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -198,12 +229,18 @@ Only applicable when used with `Django REST Framework /o/userinfo/``. + +OIDC_ISS_ENDPOINT +~~~~~~~~~~~~~~~~~ +Default: ``""`` + +The URL of the issuer that is used in the ID token JWT and advertised in the +OIDC discovery metadata. Clients use this location to retrieve the OIDC +discovery metadata from ``OIDC_ISS_ENDPOINT`` + +``/.well-known/openid-configuration/``. + +If unset, the default location is used, eg if ``django-oauth-toolkit`` is +mounted at ``/o``, it will be ``/o``. + +OIDC_RESPONSE_TYPES_SUPPORTED +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Default:: + + [ + "code", + "token", + "id_token", + "id_token token", + "code token", + "code id_token", + "code id_token token", + ] + + +The response types that are advertised to be supported by this server. + +OIDC_SUBJECT_TYPES_SUPPORTED +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Default: ``["public"]`` + +The subject types that are advertised to be supported by this server. + +OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Default: ``["client_secret_post", "client_secret_basic"]`` + +The authentication methods that are advertised to be supported by this server. + + +Settings imported from Django project +-------------------------- + +USE_TZ +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Used to determine whether or not to make token expire dates timezone aware. diff --git a/docs/templates.rst b/docs/templates.rst index 4b7e103..4f6320b 100644 --- a/docs/templates.rst +++ b/docs/templates.rst @@ -100,7 +100,7 @@ Example (this is the default page you may find on ``templates/oauth2_provider/au {% endif %} {% endfor %} -

{% trans "Application requires following permissions" %}

+

{% trans "Application requires the following permissions" %}

    {% for scope in scopes_descriptions %}
  • {{ scope }}
  • diff --git a/docs/tutorial/tutorial_01.rst b/docs/tutorial/tutorial_01.rst index 4c31d7b..6b605c1 100644 --- a/docs/tutorial/tutorial_01.rst +++ b/docs/tutorial/tutorial_01.rst @@ -51,13 +51,6 @@ CorsMiddleware should be placed as high as possible, especially before any middl # ... ) - # Or on Django < 1.10: - MIDDLEWARE_CLASSES = ( - # ... - 'corsheaders.middleware.CorsMiddleware', - # ... - ) - Allow CORS requests from all domains (just for the scope of this tutorial): .. code-block:: python diff --git a/docs/tutorial/tutorial_02.rst b/docs/tutorial/tutorial_02.rst index 7beb606..cdc9454 100644 --- a/docs/tutorial/tutorial_02.rst +++ b/docs/tutorial/tutorial_02.rst @@ -34,7 +34,7 @@ URL this view will respond to: .. code-block:: python - from django.conf.urls import url, include + from django.urls import path, include import oauth2_provider.views as oauth2_views from django.conf import settings from .views import ApiEndpoint @@ -65,7 +65,9 @@ URL this view will respond to: urlpatterns = [ # OAuth 2 endpoints: - path('o/', include(oauth2_endpoint_views, namespace="oauth2_provider")), + # need to pass in a tuple of the endpoints as well as the app's name + # because the app_name attribute is not set in the included module + path('o/', include((oauth2_endpoint_views, 'oauth2_provider'), namespace="oauth2_provider")), path('api/hello', ApiEndpoint.as_view()), # an example resource endpoint ] diff --git a/docs/tutorial/tutorial_03.rst b/docs/tutorial/tutorial_03.rst index d79be99..ad56e31 100644 --- a/docs/tutorial/tutorial_03.rst +++ b/docs/tutorial/tutorial_03.rst @@ -31,14 +31,6 @@ which takes care of token verification. In your settings.py: '...', ) - # Or on Django<1.10: - MIDDLEWARE_CLASSES = ( - '...', - 'django.contrib.auth.middleware.SessionAuthenticationMiddleware', - 'oauth2_provider.middleware.OAuth2TokenMiddleware', - '...', - ) - You will likely use the `django.contrib.auth.backends.ModelBackend` along with the OAuth2 backend (or you might not be able to log in into the admin), only pay attention to the order in which Django processes authentication backends. diff --git a/oauth2_provider/admin.py b/oauth2_provider/admin.py index a8d69e6..79bcf77 100644 --- a/oauth2_provider/admin.py +++ b/oauth2_provider/admin.py @@ -1,8 +1,16 @@ from django.contrib import admin -from .models import ( - get_access_token_model, get_application_model, - get_grant_model, get_id_token_model, get_refresh_token_model +from oauth2_provider.models import ( + get_access_token_admin_class, + get_access_token_model, + get_application_admin_class, + get_application_model, + get_grant_admin_class, + get_grant_model, + get_id_token_admin_class, + get_id_token_model, + get_refresh_token_admin_class, + get_refresh_token_model, ) @@ -13,12 +21,7 @@ class ApplicationAdmin(admin.ModelAdmin): "client_type": admin.HORIZONTAL, "authorization_grant_type": admin.VERTICAL, } - raw_id_fields = ("user", ) - - -class GrantAdmin(admin.ModelAdmin): - list_display = ("code", "application", "user", "expires") - raw_id_fields = ("user", ) + raw_id_fields = ("user",) class AccessTokenAdmin(admin.ModelAdmin): @@ -26,9 +29,14 @@ class AccessTokenAdmin(admin.ModelAdmin): raw_id_fields = ("user", "source_refresh_token") +class GrantAdmin(admin.ModelAdmin): + list_display = ("code", "application", "user", "expires") + raw_id_fields = ("user",) + + class IDTokenAdmin(admin.ModelAdmin): - list_display = ("token", "user", "application", "expires") - raw_id_fields = ("user", ) + list_display = ("jti", "user", "application", "expires") + raw_id_fields = ("user",) class RefreshTokenAdmin(admin.ModelAdmin): @@ -36,14 +44,20 @@ class RefreshTokenAdmin(admin.ModelAdmin): raw_id_fields = ("user", "access_token") -Application = get_application_model() -Grant = get_grant_model() -AccessToken = get_access_token_model() -IDToken = get_id_token_model() -RefreshToken = get_refresh_token_model() +application_model = get_application_model() +access_token_model = get_access_token_model() +grant_model = get_grant_model() +id_token_model = get_id_token_model() +refresh_token_model = get_refresh_token_model() + +application_admin_class = get_application_admin_class() +access_token_admin_class = get_access_token_admin_class() +grant_admin_class = get_grant_admin_class() +id_token_admin_class = get_id_token_admin_class() +refresh_token_admin_class = get_refresh_token_admin_class() -admin.site.register(Application, ApplicationAdmin) -admin.site.register(Grant, GrantAdmin) -admin.site.register(AccessToken, AccessTokenAdmin) -admin.site.register(IDToken, IDTokenAdmin) -admin.site.register(RefreshToken, RefreshTokenAdmin) +admin.site.register(application_model, application_admin_class) +admin.site.register(access_token_model, access_token_admin_class) +admin.site.register(grant_model, grant_admin_class) +admin.site.register(id_token_model, id_token_admin_class) +admin.site.register(refresh_token_model, refresh_token_admin_class) diff --git a/oauth2_provider/backends.py b/oauth2_provider/backends.py index aa7e1ec..3f6fab9 100644 --- a/oauth2_provider/backends.py +++ b/oauth2_provider/backends.py @@ -7,7 +7,7 @@ OAuthLibCore = get_oauthlib_core() -class OAuth2Backend(object): +class OAuth2Backend: """ Authenticate against an OAuth2 access token """ diff --git a/oauth2_provider/contrib/rest_framework/__init__.py b/oauth2_provider/contrib/rest_framework/__init__.py index a004c18..b54f422 100644 --- a/oauth2_provider/contrib/rest_framework/__init__.py +++ b/oauth2_provider/contrib/rest_framework/__init__.py @@ -1,6 +1,9 @@ # flake8: noqa from .authentication import OAuth2Authentication from .permissions import ( - TokenHasScope, TokenHasReadWriteScope, TokenMatchesOASRequirements, - TokenHasResourceScope, IsAuthenticatedOrTokenHasScope + IsAuthenticatedOrTokenHasScope, + TokenHasReadWriteScope, + TokenHasResourceScope, + TokenHasScope, + TokenMatchesOASRequirements, ) diff --git a/oauth2_provider/contrib/rest_framework/authentication.py b/oauth2_provider/contrib/rest_framework/authentication.py index 2283619..53087f7 100644 --- a/oauth2_provider/contrib/rest_framework/authentication.py +++ b/oauth2_provider/contrib/rest_framework/authentication.py @@ -9,16 +9,14 @@ class OAuth2Authentication(BaseAuthentication): """ OAuth 2 authentication backend using `django-oauth-toolkit` """ + www_authenticate_realm = "api" def _dict_to_string(self, my_dict): """ Return a string of comma-separated key-value pairs (e.g. k="v",k2="v2"). """ - return ",".join([ - '{k}="{v}"'.format(k=k, v=v) - for k, v in my_dict.items() - ]) + return ",".join(['{k}="{v}"'.format(k=k, v=v) for k, v in my_dict.items()]) def authenticate(self, request): """ @@ -36,9 +34,11 @@ def authenticate_header(self, request): """ Bearer is the only finalized type currently """ - www_authenticate_attributes = OrderedDict([ - ("realm", self.www_authenticate_realm,), - ]) + www_authenticate_attributes = OrderedDict( + [ + ("realm", self.www_authenticate_realm), + ] + ) oauth2_error = getattr(request, "oauth2_error", {}) www_authenticate_attributes.update(oauth2_error) return "Bearer {attributes}".format( diff --git a/oauth2_provider/contrib/rest_framework/permissions.py b/oauth2_provider/contrib/rest_framework/permissions.py index 7ba1c5c..1050bf7 100644 --- a/oauth2_provider/contrib/rest_framework/permissions.py +++ b/oauth2_provider/contrib/rest_framework/permissions.py @@ -2,9 +2,7 @@ from django.core.exceptions import ImproperlyConfigured from rest_framework.exceptions import PermissionDenied -from rest_framework.permissions import ( - SAFE_METHODS, BasePermission, IsAuthenticated -) +from rest_framework.permissions import SAFE_METHODS, BasePermission, IsAuthenticated from ...settings import oauth2_settings from .authentication import OAuth2Authentication @@ -33,10 +31,10 @@ def has_permission(self, request, view): # Provide information about required scope? include_required_scope = ( - oauth2_settings.ERROR_RESPONSE_WITH_SCOPES and - required_scopes and - not token.is_expired() and - not token.allow_scopes(required_scopes) + oauth2_settings.ERROR_RESPONSE_WITH_SCOPES + and required_scopes + and not token.is_expired() + and not token.allow_scopes(required_scopes) ) if include_required_scope: @@ -47,9 +45,11 @@ def has_permission(self, request, view): return False - assert False, ("TokenHasScope requires the" - "`oauth2_provider.rest_framework.OAuth2Authentication` authentication " - "class to be used.") + assert False, ( + "TokenHasScope requires the" + "`oauth2_provider.rest_framework.OAuth2Authentication` authentication " + "class to be used." + ) def get_scopes(self, request, view): try: @@ -96,9 +96,7 @@ def get_scopes(self, request, view): else: scope_type = oauth2_settings.WRITE_SCOPE - required_scopes = [ - "{}:{}".format(scope, scope_type) for scope in view_scopes - ] + required_scopes = ["{}:{}".format(scope, scope_type) for scope in view_scopes] return required_scopes @@ -113,6 +111,7 @@ class IsAuthenticatedOrTokenHasScope(BasePermission): the browsable api's if they log in using the a non token bassed middleware, and let them access the api's using a rest client with a token """ + def has_permission(self, request, view): is_authenticated = IsAuthenticated().has_permission(request, view) oauth2authenticated = False @@ -155,8 +154,11 @@ def has_permission(self, request, view): m = request.method.upper() if m in required_alternate_scopes: - log.debug("Required scopes alternatives to access resource: {0}" - .format(required_alternate_scopes[m])) + log.debug( + "Required scopes alternatives to access resource: {0}".format( + required_alternate_scopes[m] + ) + ) for alt in required_alternate_scopes[m]: if token.is_valid(alt): return True @@ -165,9 +167,11 @@ def has_permission(self, request, view): log.warning("no scope alternates defined for method {0}".format(m)) return False - assert False, ("TokenMatchesOASRequirements requires the" - "`oauth2_provider.rest_framework.OAuth2Authentication` authentication " - "class to be used.") + assert False, ( + "TokenMatchesOASRequirements requires the" + "`oauth2_provider.rest_framework.OAuth2Authentication` authentication " + "class to be used." + ) def get_required_alternate_scopes(self, request, view): try: @@ -175,4 +179,5 @@ def get_required_alternate_scopes(self, request, view): except AttributeError: raise ImproperlyConfigured( "TokenMatchesOASRequirements requires the view to" - " define the required_alternate_scopes attribute") + " define the required_alternate_scopes attribute" + ) diff --git a/oauth2_provider/decorators.py b/oauth2_provider/decorators.py index d4b7085..0ab26dd 100644 --- a/oauth2_provider/decorators.py +++ b/oauth2_provider/decorators.py @@ -33,7 +33,9 @@ def _validate(request, *args, **kwargs): request.resource_owner = oauthlib_req.user return view_func(request, *args, **kwargs) return HttpResponseForbidden() + return _validate + return decorator @@ -62,8 +64,7 @@ def _validate(request, *args, **kwargs): if not set(read_write_scopes).issubset(set(provided_scopes)): raise ImproperlyConfigured( "rw_protected_resource decorator requires following scopes {0}" - " to be in OAUTH2_PROVIDER['SCOPES'] list in settings".format( - read_write_scopes) + " to be in OAUTH2_PROVIDER['SCOPES'] list in settings".format(read_write_scopes) ) # Check if method is safe @@ -80,5 +81,7 @@ def _validate(request, *args, **kwargs): request.resource_owner = oauthlib_req.user return view_func(request, *args, **kwargs) return HttpResponseForbidden() + return _validate + return decorator diff --git a/oauth2_provider/exceptions.py b/oauth2_provider/exceptions.py index 2155155..c420848 100644 --- a/oauth2_provider/exceptions.py +++ b/oauth2_provider/exceptions.py @@ -2,6 +2,7 @@ class OAuthToolkitError(Exception): """ Base class for exceptions """ + def __init__(self, error=None, redirect_uri=None, *args, **kwargs): super().__init__(*args, **kwargs) self.oauthlib_error = error @@ -14,4 +15,5 @@ class FatalClientError(OAuthToolkitError): """ Class for critical errors """ + pass diff --git a/oauth2_provider/forms.py b/oauth2_provider/forms.py index 41129c4..8762136 100644 --- a/oauth2_provider/forms.py +++ b/oauth2_provider/forms.py @@ -11,3 +11,4 @@ class AllowForm(forms.Form): response_type = forms.CharField(widget=forms.HiddenInput()) code_challenge = forms.CharField(required=False, widget=forms.HiddenInput()) code_challenge_method = forms.CharField(required=False, widget=forms.HiddenInput()) + claims = forms.CharField(required=False, widget=forms.HiddenInput()) diff --git a/oauth2_provider/generators.py b/oauth2_provider/generators.py index a548088..f72bc6e 100644 --- a/oauth2_provider/generators.py +++ b/oauth2_provider/generators.py @@ -4,10 +4,11 @@ from .settings import oauth2_settings -class BaseHashGenerator(object): +class BaseHashGenerator: """ All generators should extend this class overriding `.hash()` method. """ + def hash(self): raise NotImplementedError() diff --git a/oauth2_provider/http.py b/oauth2_provider/http.py index 980cb7b..274ed81 100644 --- a/oauth2_provider/http.py +++ b/oauth2_provider/http.py @@ -11,6 +11,7 @@ class OAuth2ResponseRedirect(HttpResponse): Works like django.http.HttpResponseRedirect but we customize it to give us more flexibility on allowed scheme validation. """ + status_code = 302 def __init__(self, redirect_to, allowed_schemes, *args, **kwargs): @@ -28,6 +29,4 @@ def validate_redirect(self, redirect_to): if not parsed.scheme: raise DisallowedRedirect("OAuth2 redirects require a URI scheme.") if parsed.scheme not in self.allowed_schemes: - raise DisallowedRedirect( - "Redirect to scheme {!r} is not permitted".format(parsed.scheme) - ) + raise DisallowedRedirect("Redirect to scheme {!r} is not permitted".format(parsed.scheme)) diff --git a/oauth2_provider/locale/pt/LC_MESSAGES/django.po b/oauth2_provider/locale/pt/LC_MESSAGES/django.po new file mode 100644 index 0000000..63c4708 --- /dev/null +++ b/oauth2_provider/locale/pt/LC_MESSAGES/django.po @@ -0,0 +1,167 @@ +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: PACKAGE VERSION\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2019-01-25 11:45+0000\n" +"PO-Revision-Date: 2019-01-25 11:45+0000\n" +"Last-Translator: Sandro Rodrigues \n" +"Language-Team: LANGUAGE \n" +"Language: pt-PT\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=2; plural=(n != 1);\n" + +#: docs/_build/html/_sources/templates.rst.txt:94 +#: oauth2_provider/templates/oauth2_provider/authorize.html:8 +#: oauth2_provider/templates/oauth2_provider/authorize.html:30 +msgid "Authorize" +msgstr "Autorizar" + +#: docs/_build/html/_sources/templates.rst.txt:103 +#: oauth2_provider/templates/oauth2_provider/authorize.html:17 +msgid "Application requires the following permissions" +msgstr "A aplicação requer as seguintes permissões" + +#: oauth2_provider/models.py:41 +msgid "Confidential" +msgstr "Confidencial" + +#: oauth2_provider/models.py:42 +msgid "Public" +msgstr "Público" + +#: oauth2_provider/models.py:50 +msgid "Authorization code" +msgstr "Código de autorização" + +#: oauth2_provider/models.py:51 +msgid "Implicit" +msgstr "Implícito" + +#: oauth2_provider/models.py:52 +msgid "Resource owner password-based" +msgstr "Palavra-passe do proprietário de dados" + +#: oauth2_provider/models.py:53 +msgid "Client credentials" +msgstr "Credenciais do cliente" + +#: oauth2_provider/models.py:67 +msgid "Allowed URIs list, space separated" +msgstr "Lista de URIs permitidos, separados por espaço" + +#: oauth2_provider/models.py:143 +#, python-brace-format +msgid "Unauthorized redirect scheme: {scheme}" +msgstr "Esquema de redirecionamento não autorizado: {scheme}" + +#: oauth2_provider/models.py:148 +#, python-brace-format +msgid "redirect_uris cannot be empty with grant_type {grant_type}" +msgstr "redirect_uris não pode estar vazio com o grant_type {grant_type}" + +#: oauth2_provider/oauth2_validators.py:166 +msgid "The access token is invalid." +msgstr "O token de acesso é inválido." + +#: oauth2_provider/oauth2_validators.py:171 +msgid "The access token has expired." +msgstr "O token de acesso expirou." + +#: oauth2_provider/oauth2_validators.py:176 +msgid "The access token is valid but does not have enough scope." +msgstr "O token de acesso é válido, mas não tem permissões suficientes." + +#: oauth2_provider/templates/oauth2_provider/application_confirm_delete.html:6 +msgid "Are you sure to delete the application" +msgstr "Tem a certeza que pretende apagar a aplicação" + +#: oauth2_provider/templates/oauth2_provider/application_confirm_delete.html:12 +#: oauth2_provider/templates/oauth2_provider/authorize.html:29 +msgid "Cancel" +msgstr "Cancelar" + +#: oauth2_provider/templates/oauth2_provider/application_confirm_delete.html:13 +#: oauth2_provider/templates/oauth2_provider/application_detail.html:38 +#: oauth2_provider/templates/oauth2_provider/authorized-token-delete.html:7 +msgid "Delete" +msgstr "Apagar" + +#: oauth2_provider/templates/oauth2_provider/application_detail.html:10 +msgid "Client id" +msgstr "ID do Cliente" + +#: oauth2_provider/templates/oauth2_provider/application_detail.html:15 +msgid "Client secret" +msgstr "Segredo do cliente" + +#: oauth2_provider/templates/oauth2_provider/application_detail.html:20 +msgid "Client type" +msgstr "Tipo de cliente" + +#: oauth2_provider/templates/oauth2_provider/application_detail.html:25 +msgid "Authorization Grant Type" +msgstr "Tipo de concessão de autorização" + +#: oauth2_provider/templates/oauth2_provider/application_detail.html:30 +msgid "Redirect Uris" +msgstr "URI's de redirecionamento" + +#: oauth2_provider/templates/oauth2_provider/application_detail.html:36 +#: oauth2_provider/templates/oauth2_provider/application_form.html:35 +msgid "Go Back" +msgstr "Voltar" + +#: oauth2_provider/templates/oauth2_provider/application_detail.html:37 +msgid "Edit" +msgstr "Editar" + +#: oauth2_provider/templates/oauth2_provider/application_form.html:9 +msgid "Edit application" +msgstr "Editar aplicação" + +#: oauth2_provider/templates/oauth2_provider/application_form.html:37 +msgid "Save" +msgstr "Guardar" + +#: oauth2_provider/templates/oauth2_provider/application_list.html:6 +msgid "Your applications" +msgstr "As tuas aplicações" + +#: oauth2_provider/templates/oauth2_provider/application_list.html:14 +msgid "New Application" +msgstr "Nova Aplicação" + +#: oauth2_provider/templates/oauth2_provider/application_list.html:17 +msgid "No applications defined" +msgstr "Sem aplicações definidas" + +#: oauth2_provider/templates/oauth2_provider/application_list.html:17 +msgid "Click here" +msgstr "Clica aqui" + +#: oauth2_provider/templates/oauth2_provider/application_list.html:17 +msgid "if you want to register a new one" +msgstr "se pretender registar uma nova" + +#: oauth2_provider/templates/oauth2_provider/application_registration_form.html:5 +msgid "Register a new application" +msgstr "Registar nova aplicação" + +#: oauth2_provider/templates/oauth2_provider/authorized-token-delete.html:6 +msgid "Are you sure you want to delete this token?" +msgstr "Tem a certeza que pretende apagar o token?" + +#: oauth2_provider/templates/oauth2_provider/authorized-tokens.html:6 +msgid "Tokens" +msgstr "Tokens" + +#: oauth2_provider/templates/oauth2_provider/authorized-tokens.html:11 +msgid "revoke" +msgstr "revogar" + +#: oauth2_provider/templates/oauth2_provider/authorized-tokens.html:19 +msgid "There are no authorized tokens yet." +msgstr "De momento, não tem tokens autorizados." diff --git a/oauth2_provider/management/commands/createapplication.py b/oauth2_provider/management/commands/createapplication.py index 95cb2d8..92c4ae4 100644 --- a/oauth2_provider/management/commands/createapplication.py +++ b/oauth2_provider/management/commands/createapplication.py @@ -72,15 +72,10 @@ def handle(self, *args, **options): try: new_application.full_clean() except ValidationError as exc: - errors = "\n ".join(["- " + err_key + ": " + str(err_value) for err_key, - err_value in exc.message_dict.items()]) - self.stdout.write( - self.style.ERROR( - "Please correct the following errors:\n %s" % errors - ) + errors = "\n ".join( + ["- " + err_key + ": " + str(err_value) for err_key, err_value in exc.message_dict.items()] ) + self.stdout.write(self.style.ERROR("Please correct the following errors:\n %s" % errors)) else: new_application.save() - self.stdout.write( - self.style.SUCCESS("New application created successfully") - ) + self.stdout.write(self.style.SUCCESS("New application created successfully")) diff --git a/oauth2_provider/middleware.py b/oauth2_provider/middleware.py index b94cb71..17ba6c3 100644 --- a/oauth2_provider/middleware.py +++ b/oauth2_provider/middleware.py @@ -1,9 +1,8 @@ from django.contrib.auth import authenticate from django.utils.cache import patch_vary_headers -from django.utils.deprecation import MiddlewareMixin -class OAuth2TokenMiddleware(MiddlewareMixin): +class OAuth2TokenMiddleware: """ Middleware for OAuth2 user authentication @@ -23,7 +22,10 @@ class OAuth2TokenMiddleware(MiddlewareMixin): reverse proxy can create proper cache keys. """ - def process_request(self, request): + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): # do something only if request contains a Bearer token if request.META.get("HTTP_AUTHORIZATION", "").startswith("Bearer"): if not hasattr(request, "user") or request.user.is_anonymous: @@ -31,6 +33,6 @@ def process_request(self, request): if user: request.user = request._cached_user = user - def process_response(self, request, response): + response = self.get_response(request) patch_vary_headers(response, ("Authorization",)) return response diff --git a/oauth2_provider/migrations/0001_initial.py b/oauth2_provider/migrations/0001_initial.py index f415cb6..b281dbd 100644 --- a/oauth2_provider/migrations/0001_initial.py +++ b/oauth2_provider/migrations/0001_initial.py @@ -20,7 +20,7 @@ class Migration(migrations.Migration): fields=[ ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), ('client_id', models.CharField(default=oauth2_provider.generators.generate_client_id, unique=True, max_length=100, db_index=True)), - ('redirect_uris', models.TextField(help_text='Allowed URIs list, space separated', blank=True, validators=[oauth2_provider.validators.validate_uris])), + ('redirect_uris', models.TextField(help_text='Allowed URIs list, space separated', blank=True)), ('client_type', models.CharField(max_length=32, choices=[('confidential', 'Confidential'), ('public', 'Public')])), ('authorization_grant_type', models.CharField(max_length=32, choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials')])), ('client_secret', models.CharField(default=oauth2_provider.generators.generate_client_secret, max_length=255, db_index=True, blank=True)), diff --git a/oauth2_provider/migrations/0014_auto_20210510_0935.py b/oauth2_provider/migrations/0014_auto_20210510_0935.py new file mode 100644 index 0000000..3fcd0c4 --- /dev/null +++ b/oauth2_provider/migrations/0014_auto_20210510_0935.py @@ -0,0 +1,43 @@ +# Generated by Django 2.2.16 on 2021-05-10 09:35 + +from django.db import migrations, models +import uuid + + +class Migration(migrations.Migration): + + dependencies = [ + ('oauth2_provider', '0013_auto_20190816_1714'), + ] + + operations = [ + migrations.RemoveField( + model_name='idtoken', + name='token', + ), + migrations.AddField( + model_name='grant', + name='claims', + field=models.TextField(blank=True), + ), + migrations.AddField( + model_name='grant', + name='nonce', + field=models.CharField(blank=True, default='', max_length=255), + ), + migrations.AddField( + model_name='idtoken', + name='jti', + field=models.UUIDField(default=uuid.uuid4, editable=False, unique=True, verbose_name='JWT Token ID'), + ), + migrations.AlterField( + model_name='application', + name='algorithm', + field=models.CharField(blank=True, choices=[('', 'No OIDC support'), ('RS256', 'RSA with SHA-2 256'), ('HS256', 'HMAC with SHA-2 256')], default='', max_length=5), + ), + migrations.AlterField( + model_name='grant', + name='redirect_uri', + field=models.TextField(), + ), + ] diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index 1421c89..aa10eca 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -1,5 +1,5 @@ -import json import logging +import uuid from datetime import timedelta from urllib.parse import parse_qsl, urlparse @@ -10,7 +10,8 @@ from django.urls import reverse from django.utils import timezone from django.utils.translation import gettext_lazy as _ -from jwcrypto import jwk, jwt +from jwcrypto import jwk +from jwcrypto.common import base64url_encode from .generators import generate_client_id, generate_client_secret from .scopes import get_scopes_backend @@ -41,6 +42,7 @@ class AbstractApplication(models.Model): the registration process as described in :rfc:`2.2` * :attr:`name` Friendly name for the Application """ + CLIENT_CONFIDENTIAL = "confidential" CLIENT_PUBLIC = "public" CLIENT_TYPES = ( @@ -61,30 +63,31 @@ class AbstractApplication(models.Model): (GRANT_OPENID_HYBRID, _("OpenID connect hybrid")), ) + NO_ALGORITHM = "" RS256_ALGORITHM = "RS256" HS256_ALGORITHM = "HS256" ALGORITHM_TYPES = ( + (NO_ALGORITHM, _("No OIDC support")), (RS256_ALGORITHM, _("RSA with SHA-2 256")), (HS256_ALGORITHM, _("HMAC with SHA-2 256")), ) id = models.BigAutoField(primary_key=True) - client_id = models.CharField( - max_length=100, unique=True, default=generate_client_id, db_index=True - ) + client_id = models.CharField(max_length=100, unique=True, default=generate_client_id, db_index=True) user = models.ForeignKey( settings.AUTH_USER_MODEL, related_name="%(app_label)s_%(class)s", - null=True, blank=True, on_delete=models.CASCADE + null=True, + blank=True, + on_delete=models.CASCADE, ) redirect_uris = models.TextField( - blank=True, help_text=_("Allowed URIs list, space separated"), + blank=True, + help_text=_("Allowed URIs list, space separated"), ) client_type = models.CharField(max_length=32, choices=CLIENT_TYPES) - authorization_grant_type = models.CharField( - max_length=32, choices=GRANT_TYPES - ) + authorization_grant_type = models.CharField(max_length=32, choices=GRANT_TYPES) client_secret = models.CharField( max_length=255, blank=True, default=generate_client_secret, db_index=True ) @@ -93,7 +96,7 @@ class AbstractApplication(models.Model): created = models.DateTimeField(auto_now_add=True) updated = models.DateTimeField(auto_now=True) - algorithm = models.CharField(max_length=5, choices=ALGORITHM_TYPES, default=RS256_ALGORITHM) + algorithm = models.CharField(max_length=5, choices=ALGORITHM_TYPES, default=NO_ALGORITHM, blank=True) class Meta: abstract = True @@ -122,21 +125,7 @@ def redirect_uri_allowed(self, uri): :param uri: Url to check """ - parsed_uri = urlparse(uri) - uqs_set = set(parse_qsl(parsed_uri.query)) - for allowed_uri in self.redirect_uris.split(): - parsed_allowed_uri = urlparse(allowed_uri) - - if (parsed_allowed_uri.scheme == parsed_uri.scheme and - parsed_allowed_uri.netloc == parsed_uri.netloc and - parsed_allowed_uri.path == parsed_uri.path): - - aqs_set = set(parse_qsl(parsed_allowed_uri.query)) - - if aqs_set.issubset(uqs_set): - return True - - return False + return redirect_to_uri_allowed(uri, self.redirect_uris.split()) def clean(self): from django.core.exceptions import ValidationError @@ -144,6 +133,11 @@ def clean(self): grant_types = ( AbstractApplication.GRANT_AUTHORIZATION_CODE, AbstractApplication.GRANT_IMPLICIT, + AbstractApplication.GRANT_OPENID_HYBRID, + ) + hs_forbidden_grant_types = ( + AbstractApplication.GRANT_IMPLICIT, + AbstractApplication.GRANT_OPENID_HYBRID, ) redirect_uris = self.redirect_uris.strip().split() @@ -155,14 +149,26 @@ def clean(self): validator(uri) scheme = urlparse(uri).scheme if scheme not in allowed_schemes: - raise ValidationError(_( - "Unauthorized redirect scheme: {scheme}" - ).format(scheme=scheme)) + raise ValidationError(_("Unauthorized redirect scheme: {scheme}").format(scheme=scheme)) elif self.authorization_grant_type in grant_types: - raise ValidationError(_( - "redirect_uris cannot be empty with grant_type {grant_type}" - ).format(grant_type=self.authorization_grant_type)) + raise ValidationError( + _("redirect_uris cannot be empty with grant_type {grant_type}").format( + grant_type=self.authorization_grant_type + ) + ) + if self.algorithm == AbstractApplication.RS256_ALGORITHM: + if not oauth2_settings.OIDC_RSA_PRIVATE_KEY: + raise ValidationError(_("You must set OIDC_RSA_PRIVATE_KEY to use RSA algorithm")) + + if self.algorithm == AbstractApplication.HS256_ALGORITHM: + if any( + ( + self.authorization_grant_type in hs_forbidden_grant_types, + self.client_type == Application.CLIENT_PUBLIC, + ) + ): + raise ValidationError(_("You cannot use HS256 with public grants or clients")) def get_absolute_url(self): return reverse("oauth2_provider:detail", args=[str(self.id)]) @@ -181,10 +187,20 @@ def is_usable(self, request): """ Determines whether the application can be used. - :param request: The HTTP request being processed. + :param request: The oauthlib.common.Request being processed. """ return True + @property + def jwk_key(self): + if self.algorithm == AbstractApplication.RS256_ALGORITHM: + if not oauth2_settings.OIDC_RSA_PRIVATE_KEY: + raise ImproperlyConfigured("You must set OIDC_RSA_PRIVATE_KEY to use RSA algorithm") + return jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) + elif self.algorithm == AbstractApplication.HS256_ALGORITHM: + return jwk.JWK(kty="oct", k=base64url_encode(self.client_secret)) + raise ImproperlyConfigured("This application does not support signed tokens") + class ApplicationManager(models.Manager): def get_by_natural_key(self, client_id): @@ -218,24 +234,19 @@ class AbstractGrant(models.Model): * :attr:`code_challenge` PKCE code challenge * :attr:`code_challenge_method` PKCE code challenge transform algorithm """ + CODE_CHALLENGE_PLAIN = "plain" CODE_CHALLENGE_S256 = "S256" - CODE_CHALLENGE_METHODS = ( - (CODE_CHALLENGE_PLAIN, "plain"), - (CODE_CHALLENGE_S256, "S256") - ) + CODE_CHALLENGE_METHODS = ((CODE_CHALLENGE_PLAIN, "plain"), (CODE_CHALLENGE_S256, "S256")) id = models.BigAutoField(primary_key=True) user = models.ForeignKey( - settings.AUTH_USER_MODEL, on_delete=models.CASCADE, - related_name="%(app_label)s_%(class)s" + settings.AUTH_USER_MODEL, on_delete=models.CASCADE, related_name="%(app_label)s_%(class)s" ) code = models.CharField(max_length=255, unique=True) # code comes from oauthlib - application = models.ForeignKey( - oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE - ) + application = models.ForeignKey(oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE) expires = models.DateTimeField() - redirect_uri = models.CharField(max_length=255) + redirect_uri = models.TextField() scope = models.TextField(blank=True) created = models.DateTimeField(auto_now_add=True) @@ -243,7 +254,11 @@ class AbstractGrant(models.Model): code_challenge = models.CharField(max_length=128, blank=True, default="") code_challenge_method = models.CharField( - max_length=10, blank=True, default="", choices=CODE_CHALLENGE_METHODS) + max_length=10, blank=True, default="", choices=CODE_CHALLENGE_METHODS + ) + + nonce = models.CharField(max_length=255, blank=True, default="") + claims = models.TextField(blank=True) def is_expired(self): """ @@ -283,23 +298,39 @@ class AbstractAccessToken(models.Model): * :attr:`expires` Date and time of token expiration, in DateTime format * :attr:`scope` Allowed scopes """ + id = models.BigAutoField(primary_key=True) user = models.ForeignKey( - settings.AUTH_USER_MODEL, on_delete=models.CASCADE, blank=True, null=True, - related_name="%(app_label)s_%(class)s" + settings.AUTH_USER_MODEL, + on_delete=models.CASCADE, + blank=True, + null=True, + related_name="%(app_label)s_%(class)s", ) source_refresh_token = models.OneToOneField( # unique=True implied by the OneToOneField - oauth2_settings.REFRESH_TOKEN_MODEL, on_delete=models.SET_NULL, blank=True, null=True, - related_name="refreshed_access_token" + oauth2_settings.REFRESH_TOKEN_MODEL, + on_delete=models.SET_NULL, + blank=True, + null=True, + related_name="refreshed_access_token", + ) + token = models.CharField( + max_length=255, + unique=True, ) - token = models.CharField(max_length=255, unique=True, ) id_token = models.OneToOneField( - oauth2_settings.ID_TOKEN_MODEL, on_delete=models.CASCADE, blank=True, null=True, - related_name="access_token" + oauth2_settings.ID_TOKEN_MODEL, + on_delete=models.CASCADE, + blank=True, + null=True, + related_name="access_token", ) application = models.ForeignKey( - oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE, blank=True, null=True, + oauth2_settings.APPLICATION_MODEL, + on_delete=models.CASCADE, + blank=True, + null=True, ) expires = models.DateTimeField() scope = models.TextField(blank=True) @@ -380,17 +411,19 @@ class AbstractRefreshToken(models.Model): bounded to * :attr:`revoked` Timestamp of when this refresh token was revoked """ + id = models.BigAutoField(primary_key=True) user = models.ForeignKey( - settings.AUTH_USER_MODEL, on_delete=models.CASCADE, - related_name="%(app_label)s_%(class)s" + settings.AUTH_USER_MODEL, on_delete=models.CASCADE, related_name="%(app_label)s_%(class)s" ) token = models.CharField(max_length=255) - application = models.ForeignKey( - oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE) + application = models.ForeignKey(oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE) access_token = models.OneToOneField( - oauth2_settings.ACCESS_TOKEN_MODEL, on_delete=models.SET_NULL, blank=True, null=True, - related_name="refresh_token" + oauth2_settings.ACCESS_TOKEN_MODEL, + on_delete=models.SET_NULL, + blank=True, + null=True, + related_name="refresh_token", ) created = models.DateTimeField(auto_now_add=True) @@ -404,11 +437,10 @@ def revoke(self): access_token_model = get_access_token_model() refresh_token_model = get_refresh_token_model() with transaction.atomic(): - self = refresh_token_model.objects.filter( - pk=self.pk, revoked__isnull=True - ).select_for_update().first() - if not self: + token = refresh_token_model.objects.select_for_update().filter(pk=self.pk, revoked__isnull=True) + if not token: return + self = list(token)[0] try: access_token_model.objects.get(id=self.access_token_id).revoke() @@ -423,7 +455,10 @@ def __str__(self): class Meta: abstract = True - unique_together = ("token", "revoked",) + unique_together = ( + "token", + "revoked", + ) class RefreshToken(AbstractRefreshToken): @@ -439,19 +474,28 @@ class AbstractIDToken(models.Model): Fields: * :attr:`user` The Django user representing resources' owner - * :attr:`token` ID token + * :attr:`jti` ID token JWT Token ID, to identify an individual token * :attr:`application` Application instance * :attr:`expires` Date and time of token expiration, in DateTime format * :attr:`scope` Allowed scopes + * :attr:`created` Date and time of token creation, in DateTime format + * :attr:`updated` Date and time of token update, in DateTime format """ + id = models.BigAutoField(primary_key=True) user = models.ForeignKey( - settings.AUTH_USER_MODEL, on_delete=models.CASCADE, blank=True, null=True, - related_name="%(app_label)s_%(class)s" + settings.AUTH_USER_MODEL, + on_delete=models.CASCADE, + blank=True, + null=True, + related_name="%(app_label)s_%(class)s", ) - token = models.TextField(unique=True) + jti = models.UUIDField(unique=True, default=uuid.uuid4, editable=False, verbose_name="JWT Token ID") application = models.ForeignKey( - oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE, blank=True, null=True, + oauth2_settings.APPLICATION_MODEL, + on_delete=models.CASCADE, + blank=True, + null=True, ) expires = models.DateTimeField() scope = models.TextField(blank=True) @@ -506,14 +550,8 @@ def scopes(self): token_scopes = self.scope.split() return {name: desc for name, desc in all_scopes.items() if name in token_scopes} - @property - def claims(self): - key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) - jwt_token = jwt.JWT(key=key, jwt=self.token) - return json.loads(jwt_token.claims) - def __str__(self): - return self.token + return "JTI: {self.jti} User: {self.user_id}".format(self=self) class Meta: abstract = True @@ -549,6 +587,36 @@ def get_refresh_token_model(): return apps.get_model(oauth2_settings.REFRESH_TOKEN_MODEL) +def get_application_admin_class(): + """ Return the Application admin class that is active in this project. """ + application_admin_class = oauth2_settings.APPLICATION_ADMIN_CLASS + return application_admin_class + + +def get_access_token_admin_class(): + """ Return the AccessToken admin class that is active in this project. """ + access_token_admin_class = oauth2_settings.ACCESS_TOKEN_ADMIN_CLASS + return access_token_admin_class + + +def get_grant_admin_class(): + """ Return the Grant admin class that is active in this project. """ + grant_admin_class = oauth2_settings.GRANT_ADMIN_CLASS + return grant_admin_class + + +def get_id_token_admin_class(): + """ Return the IDToken admin class that is active in this project. """ + id_token_admin_class = oauth2_settings.ID_TOKEN_ADMIN_CLASS + return id_token_admin_class + + +def get_refresh_token_admin_class(): + """ Return the RefreshToken admin class that is active in this project. """ + refresh_token_admin_class = oauth2_settings.REFRESH_TOKEN_ADMIN_CLASS + return refresh_token_admin_class + + def clear_expired(): now = timezone.now() refresh_expire_at = None @@ -580,13 +648,9 @@ def clear_expired(): revoked.delete() expired.delete() else: - logger.info("refresh_expire_at is %s. No refresh tokens deleted.", - refresh_expire_at) + logger.info("refresh_expire_at is %s. No refresh tokens deleted.", refresh_expire_at) - access_tokens = access_token_model.objects.filter( - refresh_token__isnull=True, - expires__lt=now - ) + access_tokens = access_token_model.objects.filter(refresh_token__isnull=True, expires__lt=now) grants = grant_model.objects.filter(expires__lt=now) logger.info("%s Expired access tokens to be deleted", access_tokens.count()) @@ -594,3 +658,50 @@ def clear_expired(): access_tokens.delete() grants.delete() + + +def redirect_to_uri_allowed(uri, allowed_uris): + """ + Checks if a given uri can be redirected to based on the provided allowed_uris configuration. + + On top of exact matches, this function also handles loopback IPs based on RFC 8252. + + :param uri: URI to check + :param allowed_uris: A list of URIs that are allowed + """ + + parsed_uri = urlparse(uri) + uqs_set = set(parse_qsl(parsed_uri.query)) + for allowed_uri in allowed_uris: + parsed_allowed_uri = urlparse(allowed_uri) + + # From RFC 8252 (Section 7.3) + # + # Loopback redirect URIs use the "http" scheme + # [...] + # The authorization server MUST allow any port to be specified at the + # time of the request for loopback IP redirect URIs, to accommodate + # clients that obtain an available ephemeral port from the operating + # system at the time of the request. + + allowed_uri_is_loopback = ( + parsed_allowed_uri.scheme == "http" + and parsed_allowed_uri.hostname in ["127.0.0.1", "::1"] + and parsed_allowed_uri.port is None + ) + if ( + allowed_uri_is_loopback + and parsed_allowed_uri.scheme == parsed_uri.scheme + and parsed_allowed_uri.hostname == parsed_uri.hostname + and parsed_allowed_uri.path == parsed_uri.path + ) or ( + parsed_allowed_uri.scheme == parsed_uri.scheme + and parsed_allowed_uri.netloc == parsed_uri.netloc + and parsed_allowed_uri.path == parsed_uri.path + ): + + aqs_set = set(parse_qsl(parsed_allowed_uri.query)) + if aqs_set.issubset(uqs_set): + return True + + return False diff --git a/oauth2_provider/oauth2_backends.py b/oauth2_provider/oauth2_backends.py index 8902f22..ee20585 100644 --- a/oauth2_provider/oauth2_backends.py +++ b/oauth2_provider/oauth2_backends.py @@ -4,19 +4,22 @@ from urllib.parse import urlencode as urllib_urlencode, SplitResult from oauthlib import oauth2 +from oauthlib.common import Request as OauthlibRequest from oauthlib.common import quote, urlencode, urlencoded +from oauthlib.oauth2 import OAuth2Error from .exceptions import FatalClientError, OAuthToolkitError from .settings import oauth2_settings -class OAuthLibCore(object): +class OAuthLibCore: """ Wrapper for oauth Server providing django-specific interfaces. Meant for things like extracting request data and converting everything to formats more palatable for oauthlib's Server. """ + def __init__(self, server=None): """ :params server: An instance of oauthlib.oauth2.Server class @@ -24,9 +27,7 @@ def __init__(self, server=None): validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS validator = validator_class() server_kwargs = oauth2_settings.server_kwargs - self.server = server or oauth2_settings.OAUTH2_SERVER_CLASS( - validator, **server_kwargs - ) + self.server = server or oauth2_settings.OAUTH2_SERVER_CLASS(validator, **server_kwargs) def _get_escaped_full_path(self, request): """ @@ -87,6 +88,10 @@ def extract_headers(self, request): del headers["wsgi.errors"] if "HTTP_AUTHORIZATION" in headers: headers["Authorization"] = headers["HTTP_AUTHORIZATION"] + if request.is_secure(): + headers["X_DJANGO_OAUTH_TOOLKIT_SECURE"] = "1" + elif "X_DJANGO_OAUTH_TOOLKIT_SECURE" in headers: + del headers["X_DJANGO_OAUTH_TOOLKIT_SECURE"] return headers @@ -107,7 +112,8 @@ def validate_authorization_request(self, request): try: uri, http_method, body, headers = self._extract_params(request) scopes, credentials = self.server.validate_authorization_request( - uri, http_method=http_method, body=body, headers=headers) + uri, http_method=http_method, body=body, headers=headers + ) return scopes, credentials except oauth2.FatalClientError as error: @@ -115,7 +121,7 @@ def validate_authorization_request(self, request): except oauth2.OAuth2Error as error: raise OAuthToolkitError(error=error) - def create_authorization_response(self, uri, request, scopes, credentials, body, allow): + def create_authorization_response(self, request, scopes, credentials, allow): """ A wrapper method that calls create_authorization_response on `server_class` instance. @@ -123,23 +129,27 @@ def create_authorization_response(self, uri, request, scopes, credentials, body, :param request: The current django.http.HttpRequest object :param scopes: A list of provided scopes :param credentials: Authorization credentials dictionary containing - `client_id`, `state`, `redirect_uri` and `response_type` - :param body: Other body parameters not used in credentials dictionary + `client_id`, `state`, `redirect_uri`, `response_type` :param allow: True if the user authorize the client, otherwise False """ try: if not allow: - raise oauth2.AccessDeniedError( - state=credentials.get("state", None)) + raise oauth2.AccessDeniedError(state=credentials.get("state", None)) # add current user to credentials. this will be used by OAUTH2_VALIDATOR_CLASS credentials["user"] = request.user + request_uri, http_method, _, request_headers = self._extract_params(request) headers, body, status = self.server.create_authorization_response( - uri=uri, scopes=scopes, credentials=credentials, body=body) - redirect_uri = headers.get("Location", None) + uri=request_uri, + http_method=http_method, + headers=request_headers, + scopes=scopes, + credentials=credentials, + ) + uri = headers.get("Location", None) - return redirect_uri, headers, body, status + return uri, headers, body, status except oauth2.FatalClientError as error: raise FatalClientError(error=error, redirect_uri=credentials["redirect_uri"]) @@ -155,8 +165,9 @@ def create_token_response(self, request): uri, http_method, body, headers = self._extract_params(request) extra_credentials = self._get_extra_credentials(request) - headers, body, status = self.server.create_token_response(uri, http_method, body, - headers, extra_credentials) + headers, body, status = self.server.create_token_response( + uri, http_method, body, headers, extra_credentials + ) uri = headers.get("Location", None) return uri, headers, body, status @@ -170,12 +181,26 @@ def create_revocation_response(self, request): """ uri, http_method, body, headers = self._extract_params(request) - headers, body, status = self.server.create_revocation_response( - uri, http_method, body, headers) + headers, body, status = self.server.create_revocation_response(uri, http_method, body, headers) uri = headers.get("Location", None) return uri, headers, body, status + def create_userinfo_response(self, request): + """ + A wrapper method that calls create_userinfo_response on a + `server_class` instance. + + :param request: The current django.http.HttpRequest object + """ + uri, http_method, body, headers = self._extract_params(request) + try: + headers, body, status = self.server.create_userinfo_response(uri, http_method, body, headers) + uri = headers.get("Location", None) + return uri, headers, body, status + except OAuth2Error as exc: + return None, exc.headers, exc.json, exc.status_code + def verify_request(self, request, scopes): """ A wrapper method that calls verify_request on `server_class` instance. @@ -185,18 +210,24 @@ def verify_request(self, request, scopes): """ uri, http_method, body, headers = self._extract_params(request) - if ('Authorization' in headers and headers.get('Authorization').startswith('Bearer')) or getattr(request, 'access_token', None): - valid, r = self.server.verify_request(uri, http_method, body, headers, scopes=scopes) - return valid, r - else: - # Fall back to other validators - return False, request + valid, r = self.server.verify_request(uri, http_method, body, headers, scopes=scopes) + return valid, r + + def authenticate_client(self, request): + """Wrapper to call `authenticate_client` on `server_class` instance. + + :param request: The current django.http.HttpRequest object + """ + uri, http_method, body, headers = self._extract_params(request) + oauth_request = OauthlibRequest(uri, http_method, body, headers) + return self.server.request_validator.authenticate_client(oauth_request) class JSONOAuthLibCore(OAuthLibCore): """ Extends the default OAuthLibCore to parse correctly application/json requests """ + def extract_body(self, request): """ Extracts the JSON body from the Django request object diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index f6ede19..f3a24e2 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -1,8 +1,9 @@ import base64 import binascii -import hashlib +import http.client import json import logging +import uuid from collections import OrderedDict from datetime import datetime, timedelta from urllib.parse import unquote_plus @@ -13,19 +14,24 @@ from django.core.exceptions import ObjectDoesNotExist from django.db import transaction from django.db.models import Q +from django.http import HttpRequest from django.utils import dateformat, timezone from django.utils.timezone import make_aware from django.utils.translation import gettext_lazy as _ -from jwcrypto import jwk, jwt +from jwcrypto import jws, jwt from jwcrypto.common import JWException from jwcrypto.jwt import JWTExpired -from oauthlib.oauth2 import RequestValidator from oauthlib.oauth2.rfc6749 import utils +from oauthlib.openid import RequestValidator from .exceptions import FatalClientError from .models import ( - AbstractApplication, get_access_token_model, get_application_model, - get_grant_model, get_id_token_model, get_refresh_token_model + AbstractApplication, + get_access_token_model, + get_application_model, + get_grant_model, + get_id_token_model, + get_refresh_token_model, ) from .scopes import get_scopes_backend from .settings import oauth2_settings @@ -44,6 +50,7 @@ AbstractApplication.GRANT_AUTHORIZATION_CODE, AbstractApplication.GRANT_PASSWORD, AbstractApplication.GRANT_CLIENT_CREDENTIALS, + AbstractApplication.GRANT_OPENID_HYBRID, ), } @@ -100,17 +107,11 @@ def _authenticate_basic_auth(self, request): try: auth_string_decoded = b64_decoded.decode(encoding) except UnicodeDecodeError: - log.debug( - "Failed basic auth: %r can't be decoded as unicode by %r", - auth_string, - encoding, - ) + log.debug("Failed basic auth: %r can't be decoded as unicode by %r", auth_string, encoding) return False try: - client_id, client_secret = map( - unquote_plus, auth_string_decoded.split(":", 1) - ) + client_id, client_secret = map(unquote_plus, auth_string_decoded.split(":", 1)) except ValueError: log.debug("Failed basic auth, Invalid base64 encoding.") return False @@ -159,57 +160,48 @@ def _load_application(self, client_id, request): """ # we want to be sure that request has the client attribute! - assert hasattr( - request, "client" - ), '"request" instance has no "client" attribute' + assert hasattr(request, "client"), '"request" instance has no "client" attribute' try: - request.client = request.client or Application.objects.get( - client_id=client_id - ) + request.client = request.client or Application.objects.get(client_id=client_id) # Check that the application can be used (defaults to always True) if not request.client.is_usable(request): - log.debug( - "Failed body authentication: Application %r is disabled" - % (client_id) - ) + log.debug("Failed body authentication: Application %r is disabled" % (client_id)) return None return request.client except Application.DoesNotExist: - log.debug( - "Failed body authentication: Application %r does not exist" - % (client_id) - ) + log.debug("Failed body authentication: Application %r does not exist" % (client_id)) return None def _set_oauth2_error_on_request(self, request, access_token, scopes): if access_token is None: error = OrderedDict( [ - ("error", "invalid_token",), - ("error_description", _("The access token is invalid."),), + ("error", "invalid_token"), + ("error_description", _("The access token is invalid.")), ] ) elif access_token.is_expired(): error = OrderedDict( [ - ("error", "invalid_token",), - ("error_description", _("The access token has expired."),), + ("error", "invalid_token"), + ("error_description", _("The access token has expired.")), ] ) elif not access_token.allow_scopes(scopes): error = OrderedDict( [ - ("error", "insufficient_scope",), - ( - "error_description", - _("The access token is valid but does not have enough scope."), - ), + ("error", "insufficient_scope"), + ("error_description", _("The access token is valid but does not have enough scope.")), ] ) else: log.warning("OAuth2 access token is invalid for an unknown reason.") - error = OrderedDict([("error", "invalid_token",), ]) + error = OrderedDict( + [ + ("error", "invalid_token"), + ] + ) request.oauth2_error = error return request @@ -270,15 +262,11 @@ def authenticate_client_id(self, client_id, request, *args, **kwargs): proceed only if the client exists and is not of type "Confidential". """ if self._load_application(client_id, request) is not None: - log.debug( - "Application %r has type %r" % (client_id, request.client.client_type) - ) + log.debug("Application %r has type %r" % (client_id, request.client.client_type)) return request.client.client_type != AbstractApplication.CLIENT_CONFIDENTIAL return False - def confirm_redirect_uri( - self, client_id, code, redirect_uri, client, *args, **kwargs - ): + def confirm_redirect_uri(self, client_id, code, redirect_uri, client, *args, **kwargs): """ Ensure the redirect_uri is listed in the Application instance redirect_uris field """ @@ -330,12 +318,17 @@ def _get_token_from_authentication_server( headers = {"Authorization": "Basic {}".format(basic_auth.decode("utf-8"))} try: - response = requests.post( - introspection_url, data={"token": token}, headers=headers - ) + response = requests.post(introspection_url, data={"token": token}, headers=headers) except requests.exceptions.RequestException: + log.exception("Introspection: Failed POST to %r in token lookup", introspection_url) + return None + + # Log an exception when response from auth server is not successful + if response.status_code != http.client.OK: log.exception( - "Introspection: Failed POST to %r in token lookup", introspection_url + "Introspection: Failed to get a valid response " + "from authentication server. Status code: {}, " + "Reason: {}.".format(response.status_code, response.reason) ) return None @@ -365,7 +358,7 @@ def _get_token_from_authentication_server( expires = max_caching_time scope = content.get("scope", "") - expires = make_aware(expires) + expires = make_aware(expires) if settings.USE_TZ else expires access_token, _created = AccessToken.objects.update_or_create( token=token, @@ -388,25 +381,15 @@ def validate_bearer_token(self, token, scopes, request): introspection_url = oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL introspection_token = oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN - introspection_credentials = ( - oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS - ) + introspection_credentials = oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS - try: - access_token = AccessToken.objects.select_related( - "application", "user" - ).get(token=token) - except AccessToken.DoesNotExist: - access_token = None + access_token = self._load_access_token(token) # if there is no token or it's invalid then introspect the token if there's an external OAuth server if not access_token or not access_token.is_valid(scopes): if introspection_url and (introspection_token or introspection_credentials): access_token = self._get_token_from_authentication_server( - token, - introspection_url, - introspection_token, - introspection_credentials, + token, introspection_url, introspection_token, introspection_credentials ) if access_token and access_token.is_valid(scopes): @@ -421,38 +404,39 @@ def validate_bearer_token(self, token, scopes, request): self._set_oauth2_error_on_request(request, access_token, scopes) return False + def _load_access_token(self, token): + return AccessToken.objects.select_related("application", "user").filter(token=token).first() + def validate_code(self, client_id, code, client, request, *args, **kwargs): try: grant = Grant.objects.get(code=code, application=client) if not grant.is_expired(): request.scopes = grant.scope.split(" ") request.user = grant.user + if grant.nonce: + request.nonce = grant.nonce + if grant.claims: + request.claims = json.loads(grant.claims) return True return False except Grant.DoesNotExist: return False - def validate_grant_type( - self, client_id, grant_type, client, request, *args, **kwargs - ): + def validate_grant_type(self, client_id, grant_type, client, request, *args, **kwargs): """ Validate both grant_type is a valid string and grant_type is allowed for current workflow """ assert grant_type in GRANT_TYPE_MAPPING # mapping misconfiguration return request.client.allows_grant_type(*GRANT_TYPE_MAPPING[grant_type]) - def validate_response_type( - self, client_id, response_type, client, request, *args, **kwargs - ): + def validate_response_type(self, client_id, response_type, client, request, *args, **kwargs): """ We currently do not support the Authorization Endpoint Response Types registry as in rfc:`8.4`, so validate the response_type only if it matches "code" or "token" """ if response_type == "code": - return client.allows_grant_type( - AbstractApplication.GRANT_AUTHORIZATION_CODE - ) + return client.allows_grant_type(AbstractApplication.GRANT_AUTHORIZATION_CODE) elif response_type == "token": return client.allows_grant_type(AbstractApplication.GRANT_IMPLICIT) elif response_type == "id_token": @@ -472,15 +456,11 @@ def validate_scopes(self, client_id, scopes, client, request, *args, **kwargs): """ Ensure required scopes are permitted (as specified in the settings file) """ - available_scopes = get_scopes_backend().get_available_scopes( - application=client, request=request - ) + available_scopes = get_scopes_backend().get_available_scopes(application=client, request=request) return set(scopes).issubset(set(available_scopes)) def get_default_scopes(self, client_id, request, *args, **kwargs): - default_scopes = get_scopes_backend().get_default_scopes( - application=request.client, request=request - ) + default_scopes = get_scopes_backend().get_default_scopes(application=request.client, request=request) return default_scopes def validate_redirect_uri(self, client_id, redirect_uri, request, *args, **kwargs): @@ -506,37 +486,13 @@ def get_code_challenge_method(self, code, request): return grant.code_challenge_method or None def save_authorization_code(self, client_id, code, request, *args, **kwargs): - expires = timezone.now() + timedelta( - seconds=oauth2_settings.AUTHORIZATION_CODE_EXPIRE_SECONDS - ) - Grant.objects.create( - application=request.client, - user=request.user, - code=code["code"], - expires=expires, - redirect_uri=request.redirect_uri, - scope=" ".join(request.scopes), - code_challenge=request.code_challenge or "", - code_challenge_method=request.code_challenge_method or "", - ) + self._create_authorization_code(request, code) def get_authorization_code_scopes(self, client_id, code, redirect_uri, request): - scopes = [] - fields = { - "code": code, - } - - if client_id: - fields["application__client_id"] = client_id - - if redirect_uri: - fields["redirect_uri"] = redirect_uri - - grant = Grant.objects.filter(**fields).values() - if grant.exists(): - grant_dict = dict(grant[0]) - scopes = utils.scope_to_list(grant_dict["scope"]) - return scopes + scopes = Grant.objects.filter(code=code).values_list("scope", flat=True).first() + if scopes: + return utils.scope_to_list(scopes) + return [] def rotate_refresh_token(self, request): """ @@ -558,9 +514,12 @@ def save_bearer_token(self, token, request, *args, **kwargs): # expires_in is passed to Server on initialization # custom server class can have logic to override this - expires = timezone.now() + timedelta(seconds=token.get( - "expires_in", oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS, - )) + expires = timezone.now() + timedelta( + seconds=token.get( + "expires_in", + oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS, + ) + ) if request.grant_type == "client_credentials": request.user = None @@ -629,17 +588,13 @@ def save_bearer_token(self, token, request, *args, **kwargs): source_refresh_token=refresh_token_instance, ) - self._create_refresh_token( - request, refresh_token_code, access_token - ) + self._create_refresh_token(request, refresh_token_code, access_token) else: # make sure that the token data we're returning matches # the existing token token["access_token"] = previous_access_token.token token["refresh_token"] = ( - RefreshToken.objects.filter(access_token=previous_access_token) - .first() - .token + RefreshToken.objects.filter(access_token=previous_access_token).first().token ) token["scope"] = previous_access_token.scope @@ -650,7 +605,7 @@ def save_bearer_token(self, token, request, *args, **kwargs): def _create_access_token(self, expires, request, token, source_refresh_token=None): id_token = token.get("id_token", None) if id_token: - id_token = IDToken.objects.get(token=id_token) + id_token = self._load_id_token(id_token) return AccessToken.objects.create( user=request.user, scope=token["scope"], @@ -661,12 +616,25 @@ def _create_access_token(self, expires, request, token, source_refresh_token=Non source_refresh_token=source_refresh_token, ) + def _create_authorization_code(self, request, code, expires=None): + if not expires: + expires = timezone.now() + timedelta(seconds=oauth2_settings.AUTHORIZATION_CODE_EXPIRE_SECONDS) + return Grant.objects.create( + application=request.client, + user=request.user, + code=code["code"], + expires=expires, + redirect_uri=request.redirect_uri, + scope=" ".join(request.scopes), + code_challenge=request.code_challenge or "", + code_challenge_method=request.code_challenge_method or "", + nonce=request.nonce or "", + claims=json.dumps(request.claims or {}), + ) + def _create_refresh_token(self, request, refresh_token_code, access_token): return RefreshToken.objects.create( - user=request.user, - token=refresh_token_code, - application=request.client, - access_token=access_token, + user=request.user, token=refresh_token_code, application=request.client, access_token=access_token ) def revoke_token(self, token, token_type_hint, request, *args, **kwargs): @@ -697,7 +665,15 @@ def validate_user(self, username, password, client, request, *args, **kwargs): """ Check username and password correspond to a valid and active User """ - u = authenticate(username=username, password=password) + # Passing the optional HttpRequest adds compatibility for backends + # which depend on its presence. Create one with attributes likely + # to be used. + http_request = HttpRequest() + http_request.path = request.uri + http_request.method = request.http_method + getattr(http_request, request.http_method).update(dict(request.decoded_body)) + http_request.META = request.headers + u = authenticate(http_request, username=username, password=password) if u is not None and u.is_active: request.user = u return True @@ -719,12 +695,13 @@ def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs """ null_or_recent = Q(revoked__isnull=True) | Q( - revoked__gt=timezone.now() - - timedelta(seconds=oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS) + revoked__gt=timezone.now() - timedelta(seconds=oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS) + ) + rt = ( + RefreshToken.objects.filter(null_or_recent, token=refresh_token) + .select_related("access_token") + .first() ) - rt = RefreshToken.objects.filter(null_or_recent, token=refresh_token).select_related( - "access_token" - ).first() if not rt: return False @@ -736,18 +713,14 @@ def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs return rt.application == client @transaction.atomic - def _save_id_token(self, token, request, expires, *args, **kwargs): - + def _save_id_token(self, jti, request, expires, *args, **kwargs): scopes = request.scope or " ".join(request.scopes) - if request.grant_type == "client_credentials": - request.user = None - id_token = IDToken.objects.create( user=request.user, scope=scopes, expires=expires, - token=token.serialize(), + jti=jti, application=request.client, ) return id_token @@ -755,65 +728,67 @@ def _save_id_token(self, token, request, expires, *args, **kwargs): def get_jwt_bearer_token(self, token, token_handler, request): return self.get_id_token(token, token_handler, request) - def get_id_token(self, token, token_handler, request): + def get_oidc_claims(self, token, token_handler, request): + # Required OIDC claims + claims = { + "sub": str(request.user.id), + } - key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) + # https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims + claims.update(**self.get_additional_claims(request)) - # TODO: http://openid.net/specs/openid-connect-core-1_0.html#HybridIDToken2 - # Save the id_token on database bound to code when the request come to - # Authorization Endpoint and return the same one when request come to - # Token Endpoint + return claims - # TODO: Check if at this point this request parameters are alredy validated + def get_id_token_dictionary(self, token, token_handler, request): + """ + Get the claims to put in the ID Token. - expiration_time = timezone.now() + timedelta( - seconds=oauth2_settings.ID_TOKEN_EXPIRE_SECONDS - ) + These claims are in addition to the claims automatically added by + ``oauthlib`` - aud, iat, nonce, at_hash, c_hash. + + This function adds in iss, exp and auth_time, plus any claims added from + calling ``get_oidc_claims()`` + """ + claims = self.get_oidc_claims(token, token_handler, request) + + expiration_time = timezone.now() + timedelta(seconds=oauth2_settings.ID_TOKEN_EXPIRE_SECONDS) # Required ID Token claims - claims = { - "iss": oauth2_settings.OIDC_ISS_ENDPOINT, - "sub": str(request.user.id), - "aud": request.client_id, - "exp": int(dateformat.format(expiration_time, "U")), - "iat": int(dateformat.format(datetime.utcnow(), "U")), - "auth_time": int(dateformat.format(request.user.last_login, "U")), - } + claims.update( + **{ + "iss": self.get_oidc_issuer_endpoint(request), + "exp": int(dateformat.format(expiration_time, "U")), + "auth_time": int(dateformat.format(request.user.last_login, "U")), + "jti": str(uuid.uuid4()), + } + ) - nonce = getattr(request, "nonce", None) - if nonce: - claims["nonce"] = nonce - - # TODO: create a function to check if we should add at_hash - # http://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken - # http://openid.net/specs/openid-connect-core-1_0.html#ImplicitIDToken - # if request.grant_type in 'authorization_code' and 'access_token' in token: - if ( - (request.grant_type == "authorization_code" and "access_token" in token) - or request.response_type == "code id_token token" - or (request.response_type == "id_token token" and "access_token" in token) - ): - acess_token = token["access_token"] - sha256 = hashlib.sha256(acess_token.encode("ascii")) - bits128 = sha256.hexdigest()[:16] - at_hash = base64.urlsafe_b64encode(bits128.encode("ascii")) - claims["at_hash"] = at_hash.decode("utf8") - - # TODO: create a function to check if we should include c_hash - # http://openid.net/specs/openid-connect-core-1_0.html#HybridIDToken - if request.response_type in ("code id_token", "code id_token token"): - code = token["code"] - sha256 = hashlib.sha256(code.encode("ascii")) - bits256 = sha256.hexdigest()[:32] - c_hash = base64.urlsafe_b64encode(bits256.encode("ascii")) - claims["c_hash"] = c_hash.decode("utf8") + return claims, expiration_time + + def get_oidc_issuer_endpoint(self, request): + return oauth2_settings.oidc_issuer(request) + + def finalize_id_token(self, id_token, token, token_handler, request): + claims, expiration_time = self.get_id_token_dictionary(token, token_handler, request) + id_token.update(**claims) + # Workaround for oauthlib bug #746 + # https://github.com/oauthlib/oauthlib/issues/746 + if "nonce" not in id_token and request.nonce: + id_token["nonce"] = request.nonce + + header = { + "typ": "JWT", + "alg": request.client.algorithm, + } + # RS256 consumers expect a kid in the header for verifying the token + if request.client.algorithm == AbstractApplication.RS256_ALGORITHM: + header["kid"] = request.client.jwk_key.thumbprint() jwt_token = jwt.JWT( - header=json.dumps({"alg": "RS256"}, default=str), - claims=json.dumps(claims, default=str), + header=json.dumps(header, default=str), + claims=json.dumps(id_token, default=str), ) - jwt_token.make_signed_token(key) - - id_token = self._save_id_token(jwt_token, request, expiration_time) + jwt_token.make_signed_token(request.client.jwk_key) + id_token = self._save_id_token(id_token["jti"], request, expiration_time) # this is needed by django rest framework request.access_token = id_token request.id_token = id_token @@ -829,22 +804,54 @@ def validate_id_token(self, token, scopes, request): if not token: return False - key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) + id_token = self._load_id_token(token) + if not id_token: + return False + if not id_token.allow_scopes(scopes): + return False + + request.client = id_token.application + request.user = id_token.user + request.scopes = scopes + # this is needed by django rest framework + request.access_token = id_token + return True + + def _load_id_token(self, token): + key = self._get_key_for_token(token) + if not key: + return None try: jwt_token = jwt.JWT(key=key, jwt=token) - id_token = IDToken.objects.get(token=jwt_token.serialize()) - request.client = id_token.application - request.user = id_token.user - request.scopes = scopes - # this is needed by django rest framework - request.access_token = id_token - return True - except (JWException, JWTExpired): - # TODO: This is the base exception of all jwcrypto - return False + claims = json.loads(jwt_token.claims) + return IDToken.objects.get(jti=claims["jti"]) + except (JWException, JWTExpired, IDToken.DoesNotExist): + return None - return False + def _get_key_for_token(self, token): + """ + Peek at the unvalidated token to discover who it was issued for + and then use that to load that application and its key. + """ + unverified_token = jws.JWS() + unverified_token.deserialize(token) + claims = json.loads(unverified_token.objects["payload"].decode("utf-8")) + if "aud" not in claims: + return None + application = self._get_client_by_audience(claims["aud"]) + if application: + return application.jwk_key + + def _get_client_by_audience(self, audience): + """ + Load a client by the aud claim in a JWT. + aud may be multi-valued, if your provider makes it so. + This function is separate to allow further customization. + """ + if isinstance(audience, str): + audience = [audience] + return Application.objects.filter(client_id__in=audience).first() def validate_user_match(self, id_token_hint, scopes, claims, request): # TODO: Fix to validate when necessary acording @@ -853,7 +860,7 @@ def validate_user_match(self, id_token_hint, scopes, claims, request): return True def get_authorization_code_nonce(self, client_id, code, redirect_uri, request): - """ Extracts nonce from saved authorization code. + """Extracts nonce from saved authorization code. If present in the Authentication Request, Authorization Servers MUST include a nonce Claim in the ID Token with the Claim Value being the nonce value sent in the Authentication @@ -870,5 +877,17 @@ def get_authorization_code_nonce(self, client_id, code, redirect_uri, request): Method is used by: - Authorization Token Grant Dispatcher """ - # TODO: Fix this ;) - return "" + nonce = Grant.objects.filter(code=code).values_list("nonce", flat=True).first() + if nonce: + return nonce + + def get_userinfo_claims(self, request): + """ + Generates and saves a new JWT for this request, and returns it as the + current user's claims. + + """ + return self.get_oidc_claims(None, None, request) + + def get_additional_claims(self, request): + return {} diff --git a/oauth2_provider/scopes.py b/oauth2_provider/scopes.py index d30f43e..5fc1276 100644 --- a/oauth2_provider/scopes.py +++ b/oauth2_provider/scopes.py @@ -1,7 +1,7 @@ from .settings import oauth2_settings -class BaseScopes(object): +class BaseScopes: def get_all_scopes(self): """ Return a dict-like object with all the scopes available in the diff --git a/oauth2_provider/settings.py b/oauth2_provider/settings.py index d770cbd..b862fca 100644 --- a/oauth2_provider/settings.py +++ b/oauth2_provider/settings.py @@ -15,27 +15,23 @@ OAuth2 Provider settings, checking for user settings first, then falling back to the defaults. """ -import importlib from django.conf import settings from django.core.exceptions import ImproperlyConfigured +from django.http import HttpRequest +from django.test.signals import setting_changed +from django.urls import reverse +from django.utils.module_loading import import_string +from oauthlib.common import Request USER_SETTINGS = getattr(settings, "OAUTH2_PROVIDER", None) -APPLICATION_MODEL = getattr( - settings, "OAUTH2_PROVIDER_APPLICATION_MODEL", "oauth2_provider.Application" -) -ACCESS_TOKEN_MODEL = getattr( - settings, "OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL", "oauth2_provider.AccessToken" -) -ID_TOKEN_MODEL = getattr( - settings, "OAUTH2_PROVIDER_ID_TOKEN_MODEL", "oauth2_provider.IDToken" -) +APPLICATION_MODEL = getattr(settings, "OAUTH2_PROVIDER_APPLICATION_MODEL", "oauth2_provider.Application") +ACCESS_TOKEN_MODEL = getattr(settings, "OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL", "oauth2_provider.AccessToken") +ID_TOKEN_MODEL = getattr(settings, "OAUTH2_PROVIDER_ID_TOKEN_MODEL", "oauth2_provider.IDToken") GRANT_MODEL = getattr(settings, "OAUTH2_PROVIDER_GRANT_MODEL", "oauth2_provider.Grant") -REFRESH_TOKEN_MODEL = getattr( - settings, "OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL", "oauth2_provider.RefreshToken" -) +REFRESH_TOKEN_MODEL = getattr(settings, "OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL", "oauth2_provider.RefreshToken") DEFAULTS = { "CLIENT_ID_GENERATOR_CLASS": "oauth2_provider.generators.ClientIdGenerator", @@ -44,7 +40,8 @@ "ACCESS_TOKEN_GENERATOR": None, "REFRESH_TOKEN_GENERATOR": None, "EXTRA_SERVER_KWARGS": {}, - "OAUTH2_SERVER_CLASS": "oauthlib.openid.connect.core.endpoints.pre_configured.Server", + "OAUTH2_SERVER_CLASS": "oauthlib.oauth2.Server", + "OIDC_SERVER_CLASS": "oauthlib.openid.Server", "OAUTH2_VALIDATOR_CLASS": "oauth2_provider.oauth2_validators.OAuth2Validator", "OAUTH2_BACKEND_CLASS": "oauth2_provider.oauth2_backends.OAuthLibCore", "SCOPES": {"read": "Reading scope", "write": "Writing scope"}, @@ -64,8 +61,14 @@ "ID_TOKEN_MODEL": ID_TOKEN_MODEL, "GRANT_MODEL": GRANT_MODEL, "REFRESH_TOKEN_MODEL": REFRESH_TOKEN_MODEL, + "APPLICATION_ADMIN_CLASS": "oauth2_provider.admin.ApplicationAdmin", + "ACCESS_TOKEN_ADMIN_CLASS": "oauth2_provider.admin.AccessTokenAdmin", + "GRANT_ADMIN_CLASS": "oauth2_provider.admin.GrantAdmin", + "ID_TOKEN_ADMIN_CLASS": "oauth2_provider.admin.IDTokenAdmin", + "REFRESH_TOKEN_ADMIN_CLASS": "oauth2_provider.admin.RefreshTokenAdmin", "REQUEST_APPROVAL_PROMPT": "force", "ALLOWED_REDIRECT_URI_SCHEMES": ["http", "https"], + "OIDC_ENABLED": False, "OIDC_ISS_ENDPOINT": "", "OIDC_USERINFO_ENDPOINT": "", "OIDC_RSA_PRIVATE_KEY": "", @@ -79,7 +82,6 @@ "code id_token token", ], "OIDC_SUBJECT_TYPES_SUPPORTED": ["public"], - "OIDC_ID_TOKEN_SIGNING_ALG_VALUES_SUPPORTED": ["RS256", "HS256"], "OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED": [ "client_secret_post", "client_secret_basic", @@ -94,6 +96,9 @@ "RESOURCE_SERVER_TOKEN_CACHING_SECONDS": 36000, # Whether or not PKCE is required "PKCE_REQUIRED": False, + # Whether to re-create OAuthlibCore on every request. + # Should only be required in testing. + "ALWAYS_RELOAD_OAUTHLIB_CORE": False, } # List of settings that cannot be empty @@ -105,12 +110,8 @@ "OAUTH2_BACKEND_CLASS", "SCOPES", "ALLOWED_REDIRECT_URI_SCHEMES", - "OIDC_ISS_ENDPOINT", - "OIDC_USERINFO_ENDPOINT", - "OIDC_RSA_PRIVATE_KEY", "OIDC_RESPONSE_TYPES_SUPPORTED", "OIDC_SUBJECT_TYPES_SUPPORTED", - "OIDC_ID_TOKEN_SIGNING_ALG_VALUES_SUPPORTED", "OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED", ) @@ -124,6 +125,11 @@ "OAUTH2_VALIDATOR_CLASS", "OAUTH2_BACKEND_CLASS", "SCOPES_BACKEND_CLASS", + "APPLICATION_ADMIN_CLASS", + "ACCESS_TOKEN_ADMIN_CLASS", + "GRANT_ADMIN_CLASS", + "ID_TOKEN_ADMIN_CLASS", + "REFRESH_TOKEN_ADMIN_CLASS", ) @@ -132,12 +138,13 @@ def perform_import(val, setting_name): If the given setting is a string import notation, then perform the necessary import or imports. """ - if isinstance(val, (list, tuple)): - return [import_from_string(item, setting_name) for item in val] - elif "." in val: + if val is None: + return None + elif isinstance(val, str): return import_from_string(val, setting_name) - else: - raise ImproperlyConfigured("Bad value for %r: %r" % (setting_name, val)) + elif isinstance(val, (list, tuple)): + return [import_from_string(item, setting_name) for item in val] + return val def import_from_string(val, setting_name): @@ -145,21 +152,20 @@ def import_from_string(val, setting_name): Attempt to import a class from a string representation. """ try: - parts = val.split(".") - module_path, class_name = ".".join(parts[:-1]), parts[-1] - module = importlib.import_module(module_path) - return getattr(module, class_name) + return import_string(val) except ImportError as e: - msg = "Could not import %r for setting %r. %s: %s." % ( - val, - setting_name, - e.__class__.__name__, - e, - ) + msg = "Could not import %r for setting %r. %s: %s." % (val, setting_name, e.__class__.__name__, e) raise ImportError(msg) -class OAuth2ProviderSettings(object): +class _PhonyHttpRequest(HttpRequest): + _scheme = "http" + + def _get_scheme(self): + return self._scheme + + +class OAuth2ProviderSettings: """ A settings object, that allows OAuth2 Provider settings to be accessed as properties. @@ -167,24 +173,33 @@ class OAuth2ProviderSettings(object): and return the class, rather than the string literal. """ - def __init__( - self, user_settings=None, defaults=None, import_strings=None, mandatory=None - ): - self.user_settings = user_settings or {} - self.defaults = defaults or {} - self.import_strings = import_strings or () + def __init__(self, user_settings=None, defaults=None, import_strings=None, mandatory=None): + self._user_settings = user_settings or {} + self.defaults = defaults or DEFAULTS + self.import_strings = import_strings or IMPORT_STRINGS self.mandatory = mandatory or () + self._cached_attrs = set() - def __getattr__(self, attr): - if attr not in self.defaults.keys(): - raise AttributeError("Invalid OAuth2Provider setting: %r" % (attr)) + @property + def user_settings(self): + if not hasattr(self, "_user_settings"): + self._user_settings = getattr(settings, "OAUTH2_PROVIDER", {}) + return self._user_settings + def __getattr__(self, attr): + if attr not in self.defaults: + raise AttributeError("Invalid OAuth2Provider setting: %s" % attr) try: # Check if present in user settings val = self.user_settings[attr] except KeyError: # Fall back to defaults - val = self.defaults[attr] + # Special case OAUTH2_SERVER_CLASS - if not specified, and OIDC is + # enabled, use the OIDC_SERVER_CLASS setting instead + if attr == "OAUTH2_SERVER_CLASS" and self.OIDC_ENABLED: + val = self.defaults["OIDC_SERVER_CLASS"] + else: + val = self.defaults[attr] # Coerce import strings into classes if val and attr in self.import_strings: @@ -204,19 +219,18 @@ def __getattr__(self, attr): if scope in self._SCOPES: val.append(scope) else: - raise ImproperlyConfigured( - "Defined DEFAULT_SCOPES not present in SCOPES" - ) + raise ImproperlyConfigured("Defined DEFAULT_SCOPES not present in SCOPES") self.validate_setting(attr, val) # Cache the result + self._cached_attrs.add(attr) setattr(self, attr, val) return val def validate_setting(self, attr, val): if not val and attr in self.mandatory: - raise AttributeError("OAuth2Provider setting: %r is mandatory" % (attr)) + raise AttributeError("OAuth2Provider setting: %s is mandatory" % attr) @property def server_kwargs(self): @@ -244,5 +258,43 @@ def server_kwargs(self): kwargs.update(self.EXTRA_SERVER_KWARGS) return kwargs + def reload(self): + for attr in self._cached_attrs: + delattr(self, attr) + self._cached_attrs.clear() + if hasattr(self, "_user_settings"): + delattr(self, "_user_settings") + + def oidc_issuer(self, request): + """ + Helper function to get the OIDC issuer URL, either from the settings + or constructing it from the passed request. + + If only an oauthlib request is available, a dummy django request is + built from that and used to generate the URL. + """ + if self.OIDC_ISS_ENDPOINT: + return self.OIDC_ISS_ENDPOINT + if isinstance(request, HttpRequest): + django_request = request + elif isinstance(request, Request): + django_request = _PhonyHttpRequest() + django_request.META = request.headers + if request.headers.get("X_DJANGO_OAUTH_TOOLKIT_SECURE", False): + django_request._scheme = "https" + else: + raise TypeError("request must be a django or oauthlib request: got %r" % request) + abs_url = django_request.build_absolute_uri(reverse("oauth2_provider:oidc-connect-discovery-info")) + return abs_url[: -len("/.well-known/openid-configuration/")] + oauth2_settings = OAuth2ProviderSettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS, MANDATORY) + + +def reload_oauth2_settings(*args, **kwargs): + setting = kwargs["setting"] + if setting == "OAUTH2_PROVIDER": + oauth2_settings.reload() + + +setting_changed.connect(reload_oauth2_settings) diff --git a/oauth2_provider/signals.py b/oauth2_provider/signals.py index 060db8c..1640bda 100644 --- a/oauth2_provider/signals.py +++ b/oauth2_provider/signals.py @@ -1,4 +1,4 @@ from django.dispatch import Signal -app_authorized = Signal(providing_args=["request", "token"]) +app_authorized = Signal() # providing_args=["request", "token"] diff --git a/oauth2_provider/templates/oauth2_provider/application_confirm_delete.html b/oauth2_provider/templates/oauth2_provider/application_confirm_delete.html index 35b961a..4716dc5 100644 --- a/oauth2_provider/templates/oauth2_provider/application_confirm_delete.html +++ b/oauth2_provider/templates/oauth2_provider/application_confirm_delete.html @@ -10,7 +10,7 @@

    {% trans "Are you sure to delete the applicatio diff --git a/oauth2_provider/templates/oauth2_provider/application_form.html b/oauth2_provider/templates/oauth2_provider/application_form.html index 43926e1..dd8a644 100644 --- a/oauth2_provider/templates/oauth2_provider/application_form.html +++ b/oauth2_provider/templates/oauth2_provider/application_form.html @@ -34,7 +34,7 @@

    {% trans "Go Back" %} - + diff --git a/oauth2_provider/templates/oauth2_provider/application_list.html b/oauth2_provider/templates/oauth2_provider/application_list.html index 34b299a..b8e4f3a 100644 --- a/oauth2_provider/templates/oauth2_provider/application_list.html +++ b/oauth2_provider/templates/oauth2_provider/application_list.html @@ -11,8 +11,9 @@

    {% trans "Your applications" %}

    {% endfor %}
- New Application + {% trans "New Application" %} {% else %} +

{% trans "No applications defined" %}. {% trans "Click here" %} {% trans "if you want to register a new one" %}

{% endif %} diff --git a/oauth2_provider/templates/oauth2_provider/authorize.html b/oauth2_provider/templates/oauth2_provider/authorize.html index 6e6a2a9..dcbcda7 100644 --- a/oauth2_provider/templates/oauth2_provider/authorize.html +++ b/oauth2_provider/templates/oauth2_provider/authorize.html @@ -14,7 +14,7 @@

{% trans "Authorize" %} {{ application.name }}? {% endif %} {% endfor %} -

{% trans "Application requires following permissions" %}

+

{% trans "Application requires the following permissions" %}

    {% for scope in scopes_descriptions %}
  • {{ scope }}
  • @@ -26,8 +26,8 @@

    {% trans "Authorize" %} {{ application.name }}?
    - - + +
    diff --git a/oauth2_provider/templates/oauth2_provider/authorized-token-delete.html b/oauth2_provider/templates/oauth2_provider/authorized-token-delete.html index e08233a..02a6ff4 100644 --- a/oauth2_provider/templates/oauth2_provider/authorized-token-delete.html +++ b/oauth2_provider/templates/oauth2_provider/authorized-token-delete.html @@ -4,6 +4,6 @@ {% block content %}
    {% csrf_token %}

    {% trans "Are you sure you want to delete this token?" %}

    - +
    {% endblock %} diff --git a/oauth2_provider/templates/oauth2_provider/authorized-tokens.html b/oauth2_provider/templates/oauth2_provider/authorized-tokens.html index 2c6a028..0f27325 100644 --- a/oauth2_provider/templates/oauth2_provider/authorized-tokens.html +++ b/oauth2_provider/templates/oauth2_provider/authorized-tokens.html @@ -8,7 +8,7 @@

    {% trans "Tokens" %}

    {% for authorized_token in authorized_tokens %}
  • {{ authorized_token.application }} - (revoke) + ({% trans "revoke" %})
    • {% for scope_name, scope_description in authorized_token.scopes.items %} diff --git a/oauth2_provider/urls.py b/oauth2_provider/urls.py index 4baef47..508f97c 100644 --- a/oauth2_provider/urls.py +++ b/oauth2_provider/urls.py @@ -1,4 +1,4 @@ -from django.conf.urls import url +from django.urls import re_path from . import views @@ -7,31 +7,37 @@ base_urlpatterns = [ - url(r"^authorize/$", views.AuthorizationView.as_view(), name="authorize"), - url(r"^token/$", views.TokenView.as_view(), name="token"), - url(r"^revoke_token/$", views.RevokeTokenView.as_view(), name="revoke-token"), - url(r"^introspect/$", views.IntrospectTokenView.as_view(), name="introspect"), + re_path(r"^authorize/$", views.AuthorizationView.as_view(), name="authorize"), + re_path(r"^token/$", views.TokenView.as_view(), name="token"), + re_path(r"^revoke_token/$", views.RevokeTokenView.as_view(), name="revoke-token"), + re_path(r"^introspect/$", views.IntrospectTokenView.as_view(), name="introspect"), ] management_urlpatterns = [ # Application management views - url(r"^applications/$", views.ApplicationList.as_view(), name="list"), - url(r"^applications/register/$", views.ApplicationRegistration.as_view(), name="register"), - url(r"^applications/(?P[\w-]+)/$", views.ApplicationDetail.as_view(), name="detail"), - url(r"^applications/(?P[\w-]+)/delete/$", views.ApplicationDelete.as_view(), name="delete"), - url(r"^applications/(?P[\w-]+)/update/$", views.ApplicationUpdate.as_view(), name="update"), + re_path(r"^applications/$", views.ApplicationList.as_view(), name="list"), + re_path(r"^applications/register/$", views.ApplicationRegistration.as_view(), name="register"), + re_path(r"^applications/(?P[\w-]+)/$", views.ApplicationDetail.as_view(), name="detail"), + re_path(r"^applications/(?P[\w-]+)/delete/$", views.ApplicationDelete.as_view(), name="delete"), + re_path(r"^applications/(?P[\w-]+)/update/$", views.ApplicationUpdate.as_view(), name="update"), # Token management views - url(r"^authorized_tokens/$", views.AuthorizedTokensListView.as_view(), name="authorized-token-list"), - url(r"^authorized_tokens/(?P[\w-]+)/delete/$", views.AuthorizedTokenDeleteView.as_view(), - name="authorized-token-delete"), + re_path(r"^authorized_tokens/$", views.AuthorizedTokensListView.as_view(), name="authorized-token-list"), + re_path( + r"^authorized_tokens/(?P[\w-]+)/delete/$", + views.AuthorizedTokenDeleteView.as_view(), + name="authorized-token-delete", + ), ] oidc_urlpatterns = [ - url(r"^\.well-known/openid-configuration/$", views.ConnectDiscoveryInfoView.as_view(), - name="oidc-connect-discovery-info"), - url(r"^jwks/$", views.JwksInfoView.as_view(), name="jwks-info"), - url(r"^userinfo/$", views.UserInfoView.as_view(), name="user-info") + re_path( + r"^\.well-known/openid-configuration/$", + views.ConnectDiscoveryInfoView.as_view(), + name="oidc-connect-discovery-info", + ), + re_path(r"^\.well-known/jwks.json$", views.JwksInfoView.as_view(), name="jwks-info"), + re_path(r"^userinfo/$", views.UserInfoView.as_view(), name="user-info"), ] diff --git a/oauth2_provider/validators.py b/oauth2_provider/validators.py index 4a4fabf..6c8fa38 100644 --- a/oauth2_provider/validators.py +++ b/oauth2_provider/validators.py @@ -10,12 +10,9 @@ class URIValidator(URLValidator): scheme_re = r"^(?:[a-z][a-z0-9\.\-\+]*)://" dotless_domain_re = r"(?!-)[A-Z\d-]{1,63}(?= 2.2.0 + django >= 2.2 requests >= 2.13.0 - oauthlib >= 3.0.1 - jwcrypto >= 0.4.2 + oauthlib >= 3.1.0 + jwcrypto >= 0.8.0 + six [options.packages.find] exclude = tests diff --git a/tests/admin.py b/tests/admin.py new file mode 100644 index 0000000..f071769 --- /dev/null +++ b/tests/admin.py @@ -0,0 +1,21 @@ +from django.contrib import admin + + +class CustomApplicationAdmin(admin.ModelAdmin): + list_display = ("id",) + + +class CustomAccessTokenAdmin(admin.ModelAdmin): + list_display = ("id",) + + +class CustomGrantAdmin(admin.ModelAdmin): + list_display = ("id",) + + +class CustomIDTokenAdmin(admin.ModelAdmin): + list_display = ("id",) + + +class CustomRefreshTokenAdmin(admin.ModelAdmin): + list_display = ("id",) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..a3274aa --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,156 @@ +from types import SimpleNamespace +from urllib.parse import parse_qs, urlparse + +import pytest +from django.conf import settings as test_settings +from django.contrib.auth import get_user_model +from django.urls import reverse +from jwcrypto import jwk + +from oauth2_provider.models import get_application_model +from oauth2_provider.settings import oauth2_settings as _oauth2_settings + +from . import presets + + +Application = get_application_model() +UserModel = get_user_model() + + +class OAuthSettingsWrapper: + """ + A wrapper around oauth2_settings to ensure that when an overridden value is + set, it also records it in _cached_attrs, so that the settings can be reset. + """ + + def __init__(self, settings, user_settings): + self.settings = settings + if not user_settings: + user_settings = {} + self.update(user_settings) + + def update(self, user_settings): + self.settings.OAUTH2_PROVIDER = user_settings + _oauth2_settings.reload() + # Reload OAuthlibCore for every view request during tests + self.ALWAYS_RELOAD_OAUTHLIB_CORE = True + + def __setattr__(self, attr, value): + if attr == "settings": + super().__setattr__(attr, value) + else: + setattr(_oauth2_settings, attr, value) + _oauth2_settings._cached_attrs.add(attr) + + def __delattr__(self, attr): + delattr(_oauth2_settings, attr) + if attr in _oauth2_settings._cached_attrs: + _oauth2_settings._cached_attrs.remove(attr) + + def __getattr__(self, attr): + return getattr(_oauth2_settings, attr) + + def finalize(self): + self.settings.finalize() + _oauth2_settings.reload() + + +@pytest.fixture +def oauth2_settings(request, settings): + """ + A fixture that provides a simple way to override OAUTH2_PROVIDER settings. + + It can be used two ways - either setting things on the fly, or by reading + configuration data from the pytest marker oauth2_settings. + + If used on a standard pytest function, you can use argument dependency + injection to get the wrapper. If used on a unittest.TestCase, the wrapper + is made available on the class instance, as `oauth2_settings`. + + Anything overridden will be restored at the end of the test case, ensuring + that there is no configuration leakage between test cases. + """ + marker = request.node.get_closest_marker("oauth2_settings") + user_settings = {} + if marker is not None: + user_settings = marker.args[0] + wrapper = OAuthSettingsWrapper(settings, user_settings) + if request.instance is not None: + request.instance.oauth2_settings = wrapper + yield wrapper + wrapper.finalize() + + +@pytest.fixture(scope="session") +def oidc_key_(): + return jwk.JWK.from_pem(test_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) + + +@pytest.fixture +def oidc_key(request, oidc_key_): + if request.instance is not None: + request.instance.key = oidc_key_ + return oidc_key_ + + +@pytest.fixture +def application(): + return Application.objects.create( + name="Test Application", + redirect_uris="http://example.org", + client_type=Application.CLIENT_CONFIDENTIAL, + authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, + algorithm=Application.RS256_ALGORITHM, + ) + + +@pytest.fixture +def hybrid_application(application): + application.authorization_grant_type = application.GRANT_OPENID_HYBRID + application.save() + return application + + +@pytest.fixture +def test_user(): + return UserModel.objects.create_user("test_user", "test@example.com", "123456") + + +@pytest.fixture +def oidc_tokens(oauth2_settings, application, test_user, client): + oauth2_settings.update(presets.OIDC_SETTINGS_RW) + client.force_login(test_user) + auth_rsp = client.post( + reverse("oauth2_provider:authorize"), + data={ + "client_id": application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + }, + ) + assert auth_rsp.status_code == 302 + code = parse_qs(urlparse(auth_rsp["Location"]).query)["code"] + client.logout() + token_rsp = client.post( + reverse("oauth2_provider:token"), + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": "http://example.org", + "client_id": application.client_id, + "client_secret": application.client_secret, + "scope": "openid", + }, + ) + assert token_rsp.status_code == 200 + token_data = token_rsp.json() + return SimpleNamespace( + user=test_user, + application=application, + access_token=token_data["access_token"], + id_token=token_data["id_token"], + oauth2_settings=oauth2_settings, + ) diff --git a/tests/migrations/0001_initial.py b/tests/migrations/0001_initial.py index eef6dba..8903a5a 100644 --- a/tests/migrations/0001_initial.py +++ b/tests/migrations/0001_initial.py @@ -33,6 +33,8 @@ class Migration(migrations.Migration): ('custom_field', models.CharField(max_length=255)), ('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.OAUTH2_PROVIDER_APPLICATION_MODEL)), ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='tests_samplegrant', to=settings.AUTH_USER_MODEL)), + ("nonce", models.CharField(blank=True, max_length=255, default="")), + ("claims", models.TextField(blank=True)), ], options={ 'abstract': False, diff --git a/tests/models.py b/tests/models.py index 7ca0c57..32f9a1b 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,8 +1,10 @@ from django.db import models from oauth2_provider.models import ( - AbstractAccessToken, AbstractApplication, - AbstractGrant, AbstractRefreshToken + AbstractAccessToken, + AbstractApplication, + AbstractGrant, + AbstractRefreshToken, ) from oauth2_provider.settings import oauth2_settings @@ -13,7 +15,7 @@ class BaseTestApplication(AbstractApplication): def get_allowed_schemes(self): if self.allowed_schemes: return self.allowed_schemes.split() - return super(BaseTestApplication, self).get_allowed_schemes() + return super().get_allowed_schemes() class SampleApplication(AbstractApplication): @@ -24,16 +26,22 @@ class SampleAccessToken(AbstractAccessToken): custom_field = models.CharField(max_length=255) source_refresh_token = models.OneToOneField( # unique=True implied by the OneToOneField - oauth2_settings.REFRESH_TOKEN_MODEL, on_delete=models.SET_NULL, blank=True, null=True, - related_name="s_refreshed_access_token" + oauth2_settings.REFRESH_TOKEN_MODEL, + on_delete=models.SET_NULL, + blank=True, + null=True, + related_name="s_refreshed_access_token", ) class SampleRefreshToken(AbstractRefreshToken): custom_field = models.CharField(max_length=255) access_token = models.OneToOneField( - oauth2_settings.ACCESS_TOKEN_MODEL, on_delete=models.SET_NULL, blank=True, null=True, - related_name="s_refresh_token" + oauth2_settings.ACCESS_TOKEN_MODEL, + on_delete=models.SET_NULL, + blank=True, + null=True, + related_name="s_refresh_token", ) diff --git a/tests/presets.py b/tests/presets.py new file mode 100644 index 0000000..214f804 --- /dev/null +++ b/tests/presets.py @@ -0,0 +1,45 @@ +from copy import deepcopy + +from django.conf import settings + + +# A set of OAUTH2_PROVIDER settings dicts that can be used in tests + +DEFAULT_SCOPES_RW = {"DEFAULT_SCOPES": ["read", "write"]} +DEFAULT_SCOPES_RO = {"DEFAULT_SCOPES": ["read"]} +OIDC_SETTINGS_RW = { + "OIDC_ENABLED": True, + "OIDC_ISS_ENDPOINT": "http://localhost/o", + "OIDC_USERINFO_ENDPOINT": "http://localhost/o/userinfo/", + "OIDC_RSA_PRIVATE_KEY": settings.OIDC_RSA_PRIVATE_KEY, + "SCOPES": { + "read": "Reading scope", + "write": "Writing scope", + "openid": "OpenID connect", + }, + "DEFAULT_SCOPES": ["read", "write"], +} +OIDC_SETTINGS_RO = deepcopy(OIDC_SETTINGS_RW) +OIDC_SETTINGS_RO["DEFAULT_SCOPES"] = ["read"] +OIDC_SETTINGS_HS256_ONLY = deepcopy(OIDC_SETTINGS_RW) +del OIDC_SETTINGS_HS256_ONLY["OIDC_RSA_PRIVATE_KEY"] +REST_FRAMEWORK_SCOPES = { + "SCOPES": { + "read": "Read scope", + "write": "Write scope", + "scope1": "Scope 1", + "scope2": "Scope 2", + "resource1": "Resource 1", + }, +} +INTROSPECTION_SETTINGS = { + "SCOPES": { + "read": "Read scope", + "write": "Write scope", + "introspection": "Introspection scope", + "dolphin": "eek eek eek scope", + }, + "RESOURCE_SERVER_INTROSPECTION_URL": "http://example.org/introspection", + "READ_SCOPE": "read", + "WRITE_SCOPE": "write", +} diff --git a/tests/settings.py b/tests/settings.py index edd1ae6..1d29598 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -80,7 +80,6 @@ "django.contrib.staticfiles", "django.contrib.admin", "django.contrib.messages", - "oauth2_provider", "tests", ) @@ -89,29 +88,17 @@ "version": 1, "disable_existing_loggers": False, "formatters": { - "verbose": { - "format": "%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s" - }, - "simple": { - "format": "%(levelname)s %(message)s" - }, - }, - "filters": { - "require_debug_false": { - "()": "django.utils.log.RequireDebugFalse" - } + "verbose": {"format": "%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s"}, + "simple": {"format": "%(levelname)s %(message)s"}, }, + "filters": {"require_debug_false": {"()": "django.utils.log.RequireDebugFalse"}}, "handlers": { "mail_admins": { "level": "ERROR", "filters": ["require_debug_false"], - "class": "django.utils.log.AdminEmailHandler" - }, - "console": { - "level": "DEBUG", - "class": "logging.StreamHandler", - "formatter": "simple" + "class": "django.utils.log.AdminEmailHandler", }, + "console": {"level": "DEBUG", "class": "logging.StreamHandler", "formatter": "simple"}, "null": { "level": "DEBUG", "class": "logging.NullHandler", @@ -128,7 +115,7 @@ "level": "DEBUG", "propagate": True, }, - } + }, } OIDC_RSA_PRIVATE_KEY = """-----BEGIN RSA PRIVATE KEY----- @@ -147,12 +134,6 @@ dTnvCVtA59ne4LEVie/PMH/odQWY0SxVm/76uBZv/1vY -----END RSA PRIVATE KEY-----""" -OAUTH2_PROVIDER = { - "OIDC_ISS_ENDPOINT": "http://localhost", - "OIDC_USERINFO_ENDPOINT": "http://localhost/userinfo/", - "OIDC_RSA_PRIVATE_KEY": OIDC_RSA_PRIVATE_KEY, -} - OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL = "oauth2_provider.AccessToken" OAUTH2_PROVIDER_APPLICATION_MODEL = "oauth2_provider.Application" OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL = "oauth2_provider.RefreshToken" diff --git a/tests/test_application_views.py b/tests/test_application_views.py index 64e112d..42eb17f 100644 --- a/tests/test_application_views.py +++ b/tests/test_application_views.py @@ -1,9 +1,9 @@ +import pytest from django.contrib.auth import get_user_model from django.test import TestCase from django.urls import reverse from oauth2_provider.models import get_application_model -from oauth2_provider.settings import oauth2_settings from oauth2_provider.views.application import ApplicationRegistration from .models import SampleApplication @@ -23,22 +23,19 @@ def tearDown(self): self.bar_user.delete() +@pytest.mark.usefixtures("oauth2_settings") class TestApplicationRegistrationView(BaseTest): - + @pytest.mark.oauth2_settings({"APPLICATION_MODEL": "tests.SampleApplication"}) def test_get_form_class(self): """ Tests that the form class returned by the "get_form_class" method is bound to custom application model defined in the "OAUTH2_PROVIDER_APPLICATION_MODEL" setting. """ - # Patch oauth2 settings to use a custom Application model - oauth2_settings.APPLICATION_MODEL = "tests.SampleApplication" # Create a registration view and tests that the model form is bound # to the custom Application model application_form_class = ApplicationRegistration().get_form_class() self.assertEqual(SampleApplication, application_form_class._meta.model) - # Revert oauth2 settings - oauth2_settings.APPLICATION_MODEL = "oauth2_provider.Application" def test_application_registration_user(self): self.client.login(username="foo_user", password="123456") @@ -50,7 +47,7 @@ def test_application_registration_user(self): "client_type": Application.CLIENT_CONFIDENTIAL, "redirect_uris": "http://example.com", "authorization_grant_type": Application.GRANT_AUTHORIZATION_CODE, - "algorithm": "RS256", + "algorithm": "", } response = self.client.post(reverse("oauth2_provider:register"), form_data) @@ -63,15 +60,16 @@ def test_application_registration_user(self): class TestApplicationViews(BaseTest): def _create_application(self, name, user): app = Application.objects.create( - name=name, redirect_uris="http://example.com", + name=name, + redirect_uris="http://example.com", client_type=Application.CLIENT_CONFIDENTIAL, authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, - user=user + user=user, ) return app def setUp(self): - super(TestApplicationViews, self).setUp() + super().setUp() self.app_foo_1 = self._create_application("app foo_user 1", self.foo_user) self.app_foo_2 = self._create_application("app foo_user 2", self.foo_user) self.app_foo_3 = self._create_application("app foo_user 3", self.foo_user) @@ -80,7 +78,7 @@ def setUp(self): self.app_bar_2 = self._create_application("app bar_user 2", self.bar_user) def tearDown(self): - super(TestApplicationViews, self).tearDown() + super().tearDown() get_application_model().objects.all().delete() def test_application_list(self): diff --git a/tests/test_auth_backends.py b/tests/test_auth_backends.py index 530caa7..151fc30 100644 --- a/tests/test_auth_backends.py +++ b/tests/test_auth_backends.py @@ -19,17 +19,17 @@ class BaseTest(TestCase): """ Base class for cases in this module """ + def setUp(self): self.user = UserModel.objects.create_user("user", "test@example.com", "123456") self.app = ApplicationModel.objects.create( name="app", client_type=ApplicationModel.CLIENT_CONFIDENTIAL, authorization_grant_type=ApplicationModel.GRANT_CLIENT_CREDENTIALS, - user=self.user + user=self.user, ) self.token = AccessTokenModel.objects.create( - user=self.user, token="tokstr", application=self.app, - expires=now() + timedelta(days=365) + user=self.user, token="tokstr", application=self.app, expires=now() + timedelta(days=365) ) self.factory = RequestFactory() @@ -40,7 +40,6 @@ def tearDown(self): class TestOAuth2Backend(BaseTest): - def test_authenticate(self): auth_headers = { "HTTP_AUTHORIZATION": "Bearer " + "tokstr", @@ -83,58 +82,62 @@ def test_get_user(self): } ) class TestOAuth2Middleware(BaseTest): - def setUp(self): - super(TestOAuth2Middleware, self).setUp() + super().setUp() self.anon_user = AnonymousUser() + def dummy_get_response(self, request): + return HttpResponse() + def test_middleware_wrong_headers(self): - m = OAuth2TokenMiddleware() + m = OAuth2TokenMiddleware(self.dummy_get_response) request = self.factory.get("/a-resource") - self.assertIsNone(m.process_request(request)) + m(request) + self.assertFalse(hasattr(request, "user")) auth_headers = { "HTTP_AUTHORIZATION": "Beerer " + "badstring", # a Beer token for you! } request = self.factory.get("/a-resource", **auth_headers) - self.assertIsNone(m.process_request(request)) + m(request) + self.assertFalse(hasattr(request, "user")) def test_middleware_user_is_set(self): - m = OAuth2TokenMiddleware() + m = OAuth2TokenMiddleware(self.dummy_get_response) auth_headers = { "HTTP_AUTHORIZATION": "Bearer " + "tokstr", } request = self.factory.get("/a-resource", **auth_headers) request.user = self.user - self.assertIsNone(m.process_request(request)) + m(request) + self.assertIs(request.user, self.user) request.user = self.anon_user - self.assertIsNone(m.process_request(request)) + m(request) + self.assertEqual(request.user.pk, self.user.pk) def test_middleware_success(self): - m = OAuth2TokenMiddleware() + m = OAuth2TokenMiddleware(self.dummy_get_response) auth_headers = { "HTTP_AUTHORIZATION": "Bearer " + "tokstr", } request = self.factory.get("/a-resource", **auth_headers) - m.process_request(request) + m(request) self.assertEqual(request.user, self.user) def test_middleware_response(self): - m = OAuth2TokenMiddleware() + m = OAuth2TokenMiddleware(self.dummy_get_response) auth_headers = { "HTTP_AUTHORIZATION": "Bearer " + "tokstr", } request = self.factory.get("/a-resource", **auth_headers) - response = HttpResponse() - processed = m.process_response(request, response) - self.assertIs(response, processed) + response = m(request) + self.assertIsInstance(response, HttpResponse) def test_middleware_response_header(self): - m = OAuth2TokenMiddleware() + m = OAuth2TokenMiddleware(self.dummy_get_response) auth_headers = { "HTTP_AUTHORIZATION": "Bearer " + "tokstr", } request = self.factory.get("/a-resource", **auth_headers) - response = HttpResponse() - m.process_response(request, response) + response = m(request) self.assertIn("Vary", response) self.assertIn("Authorization", response["Vary"]) diff --git a/tests/test_authorization_code.py b/tests/test_authorization_code.py index 0c6c717..ea1bee8 100644 --- a/tests/test_authorization_code.py +++ b/tests/test_authorization_code.py @@ -3,22 +3,26 @@ import hashlib import json import re -from urllib.parse import parse_qs, urlencode, urlparse +from urllib.parse import parse_qs, urlparse +import pytest from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse from django.utils import timezone from django.utils.crypto import get_random_string +from jwcrypto import jwt from oauthlib.oauth2.rfc6749 import errors as oauthlib_errors from oauth2_provider.models import ( - get_access_token_model, get_application_model, - get_grant_model, get_refresh_token_model + get_access_token_model, + get_application_model, + get_grant_model, + get_refresh_token_model, ) -from oauth2_provider.settings import oauth2_settings from oauth2_provider.views import ProtectedResourceView +from . import presets from .utils import get_basic_auth_header @@ -38,17 +42,14 @@ def get(self, request, *args, **kwargs): return "This is a protected resource" +@pytest.mark.usefixtures("oauth2_settings") class BaseTest(TestCase): def setUp(self): self.factory = RequestFactory() - self.test_user = UserModel.objects.create_user( - "test_user", "test@example.com", "123456" - ) - self.dev_user = UserModel.objects.create_user( - "dev_user", "dev@example.com", "123456" - ) + self.test_user = UserModel.objects.create_user("test_user", "test@example.com", "123456") + self.dev_user = UserModel.objects.create_user("dev_user", "dev@example.com", "123456") - oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["http", "custom-scheme"] + self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["http", "custom-scheme"] self.application = Application.objects.create( name="Test Application", @@ -61,14 +62,6 @@ def setUp(self): authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, ) - oauth2_settings._SCOPES = ["read", "write", "openid"] - oauth2_settings._DEFAULT_SCOPES = ["read", "write"] - oauth2_settings.SCOPES = { - "read": "Reading scope", - "write": "Writing scope", - "openid": "OpenID connect", - } - def tearDown(self): self.application.delete() self.test_user.delete() @@ -83,24 +76,21 @@ class TestRegressionIssue315(BaseTest): def test_request_is_not_overwritten(self): self.client.login(username="test_user", password="123456") - query_string = urlencode( + response = self.client.get( + reverse("oauth2_provider:authorize"), { "client_id": self.application.client_id, "response_type": "code", "state": "random_state_string", "scope": "read write", "redirect_uri": "http://example.org", - } + }, ) - url = "{url}?{qs}".format( - url=reverse("oauth2_provider:authorize"), qs=query_string - ) - - response = self.client.get(url) self.assertEqual(response.status_code, 200) assert "request" not in response.context_data +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RW) class TestAuthorizationCodeView(BaseTest): def test_skip_authorization_completely(self): """ @@ -110,44 +100,16 @@ def test_skip_authorization_completely(self): self.application.skip_authorization = True self.application.save() - query_string = urlencode( + response = self.client.get( + reverse("oauth2_provider:authorize"), { "client_id": self.application.client_id, "response_type": "code", "state": "random_state_string", "scope": "read write", "redirect_uri": "http://example.org", - } - ) - url = "{url}?{qs}".format( - url=reverse("oauth2_provider:authorize"), qs=query_string - ) - - response = self.client.get(url) - self.assertEqual(response.status_code, 302) - - def test_id_token_skip_authorization_completely(self): - """ - If application.skip_authorization = True, should skip the authorization page. - """ - self.client.login(username="test_user", password="123456") - self.application.skip_authorization = True - self.application.save() - - query_string = urlencode( - { - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - } + }, ) - url = "{url}?{qs}".format( - url=reverse("oauth2_provider:authorize"), qs=query_string - ) - - response = self.client.get(url) self.assertEqual(response.status_code, 302) def test_pre_auth_invalid_client(self): @@ -156,14 +118,12 @@ def test_pre_auth_invalid_client(self): """ self.client.login(username="test_user", password="123456") - query_string = urlencode( - {"client_id": "fakeclientid", "response_type": "code", } - ) - url = "{url}?{qs}".format( - url=reverse("oauth2_provider:authorize"), qs=query_string - ) + query_data = { + "client_id": "fakeclientid", + "response_type": "code", + } - response = self.client.get(url) + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 400) self.assertEqual( response.context_data["url"], @@ -176,20 +136,15 @@ def test_pre_auth_valid_client(self): """ self.client.login(username="test_user", password="123456") - query_string = urlencode( - { - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - } - ) - url = "{url}?{qs}".format( - url=reverse("oauth2_provider:authorize"), qs=query_string - ) + query_data = { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + } - response = self.client.get(url) + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 200) # check form is in context and form params are valid @@ -201,37 +156,6 @@ def test_pre_auth_valid_client(self): self.assertEqual(form["scope"].value(), "read write") self.assertEqual(form["client_id"].value(), self.application.client_id) - def test_id_token_pre_auth_valid_client(self): - """ - Test response for a valid client_id with response_type: code - """ - self.client.login(username="test_user", password="123456") - - query_string = urlencode( - { - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - } - ) - url = "{url}?{qs}".format( - url=reverse("oauth2_provider:authorize"), qs=query_string - ) - - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - - # check form is in context and form params are valid - self.assertIn("form", response.context) - - form = response.context["form"] - self.assertEqual(form["redirect_uri"].value(), "http://example.org") - self.assertEqual(form["state"].value(), "random_state_string") - self.assertEqual(form["scope"].value(), "openid") - self.assertEqual(form["client_id"].value(), self.application.client_id) - def test_pre_auth_valid_client_custom_redirect_uri_scheme(self): """ Test response for a valid client_id with response_type: code @@ -239,20 +163,15 @@ def test_pre_auth_valid_client_custom_redirect_uri_scheme(self): """ self.client.login(username="test_user", password="123456") - query_string = urlencode( - { - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "custom-scheme://example.com", - } - ) - url = "{url}?{qs}".format( - url=reverse("oauth2_provider:authorize"), qs=query_string - ) + query_data = { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "custom-scheme://example.com", + } - response = self.client.get(url) + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 200) # check form is in context and form params are valid @@ -273,29 +192,26 @@ def test_pre_auth_approval_prompt(self): scope="read write", ) self.client.login(username="test_user", password="123456") - query_string = urlencode( - { - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - "approval_prompt": "auto", - } - ) - url = "{url}?{qs}".format( - url=reverse("oauth2_provider:authorize"), qs=query_string - ) - response = self.client.get(url) + + query_data = { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "approval_prompt": "auto", + } + url = reverse("oauth2_provider:authorize") + response = self.client.get(url, data=query_data) self.assertEqual(response.status_code, 302) # user already authorized the application, but with different scopes: prompt them. tok.scope = "read" tok.save() - response = self.client.get(url) + response = self.client.get(url, data=query_data) self.assertEqual(response.status_code, 200) def test_pre_auth_approval_prompt_default(self): - self.assertEqual(oauth2_settings.REQUEST_APPROVAL_PROMPT, "force") + self.assertEqual(self.oauth2_settings.REQUEST_APPROVAL_PROMPT, "force") AccessToken.objects.create( user=self.test_user, @@ -305,23 +221,18 @@ def test_pre_auth_approval_prompt_default(self): scope="read write", ) self.client.login(username="test_user", password="123456") - query_string = urlencode( - { - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - } - ) - url = "{url}?{qs}".format( - url=reverse("oauth2_provider:authorize"), qs=query_string - ) - response = self.client.get(url) + query_data = { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + } + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 200) def test_pre_auth_approval_prompt_default_override(self): - oauth2_settings.REQUEST_APPROVAL_PROMPT = "auto" + self.oauth2_settings.REQUEST_APPROVAL_PROMPT = "auto" AccessToken.objects.create( user=self.test_user, @@ -331,19 +242,14 @@ def test_pre_auth_approval_prompt_default_override(self): scope="read write", ) self.client.login(username="test_user", password="123456") - query_string = urlencode( - { - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - } - ) - url = "{url}?{qs}".format( - url=reverse("oauth2_provider:authorize"), qs=query_string - ) - response = self.client.get(url) + query_data = { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + } + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 302) def test_pre_auth_default_redirect(self): @@ -352,14 +258,12 @@ def test_pre_auth_default_redirect(self): """ self.client.login(username="test_user", password="123456") - query_string = urlencode( - {"client_id": self.application.client_id, "response_type": "code", } - ) - url = "{url}?{qs}".format( - url=reverse("oauth2_provider:authorize"), qs=query_string - ) + query_data = { + "client_id": self.application.client_id, + "response_type": "code", + } - response = self.client.get(url) + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 200) form = response.context["form"] @@ -371,18 +275,13 @@ def test_pre_auth_forbibben_redirect(self): """ self.client.login(username="test_user", password="123456") - query_string = urlencode( - { - "client_id": self.application.client_id, - "response_type": "code", - "redirect_uri": "http://forbidden.it", - } - ) - url = "{url}?{qs}".format( - url=reverse("oauth2_provider:authorize"), qs=query_string - ) + query_data = { + "client_id": self.application.client_id, + "response_type": "code", + "redirect_uri": "http://forbidden.it", + } - response = self.client.get(url) + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 400) def test_pre_auth_wrong_response_type(self): @@ -391,14 +290,12 @@ def test_pre_auth_wrong_response_type(self): """ self.client.login(username="test_user", password="123456") - query_string = urlencode( - {"client_id": self.application.client_id, "response_type": "WRONG", } - ) - url = "{url}?{qs}".format( - url=reverse("oauth2_provider:authorize"), qs=query_string - ) + query_data = { + "client_id": self.application.client_id, + "response_type": "WRONG", + } - response = self.client.get(url) + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 302) self.assertIn("error=unsupported_response_type", response["Location"]) @@ -417,32 +314,7 @@ def test_code_post_auth_allow(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org?", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - self.assertIn("code=", response["Location"]) - - def test_id_token_code_post_auth_allow(self): - """ - Test authorization code is given for an allowed request with response_type: code - """ - self.client.login(username="test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - "response_type": "code", - "allow": True, - } - - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("http://example.org?", response["Location"]) self.assertIn("state=random_state_string", response["Location"]) @@ -463,9 +335,7 @@ def test_code_post_auth_deny(self): "allow": False, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("error=access_denied", response["Location"]) self.assertIn("state=random_state_string", response["Location"]) @@ -484,9 +354,7 @@ def test_code_post_auth_deny_no_state(self): "allow": False, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("error=access_denied", response["Location"]) self.assertNotIn("state", response["Location"]) @@ -506,9 +374,7 @@ def test_code_post_auth_bad_responsetype(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("http://example.org?error", response["Location"]) @@ -527,9 +393,7 @@ def test_code_post_auth_forbidden_redirect_uri(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 400) def test_code_post_auth_malicious_redirect_uri(self): @@ -547,9 +411,7 @@ def test_code_post_auth_malicious_redirect_uri(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 400) def test_code_post_auth_allow_custom_redirect_uri_scheme(self): @@ -568,9 +430,7 @@ def test_code_post_auth_allow_custom_redirect_uri_scheme(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("custom-scheme://example.com?", response["Location"]) self.assertIn("state=random_state_string", response["Location"]) @@ -592,9 +452,7 @@ def test_code_post_auth_deny_custom_redirect_uri_scheme(self): "allow": False, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("custom-scheme://example.com?", response["Location"]) self.assertIn("error=access_denied", response["Location"]) @@ -617,9 +475,7 @@ def test_code_post_auth_redirection_uri_with_querystring(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("http://example.com?foo=bar", response["Location"]) self.assertIn("code=", response["Location"]) @@ -642,9 +498,7 @@ def test_code_post_auth_failing_redirection_uri_with_querystring(self): "allow": False, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("http://example.com?", response["Location"]) self.assertIn("error=access_denied", response["Location"]) @@ -666,13 +520,80 @@ def test_code_post_auth_fails_when_redirect_uri_path_is_invalid(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 400) -class TestAuthorizationCodeTokenView(BaseTest): +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +class TestOIDCAuthorizationCodeView(BaseTest): + def test_id_token_skip_authorization_completely(self): + """ + If application.skip_authorization = True, should skip the authorization page. + """ + self.client.login(username="test_user", password="123456") + self.application.skip_authorization = True + self.application.save() + + query_data = { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + } + + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) + self.assertEqual(response.status_code, 302) + + def test_id_token_pre_auth_valid_client(self): + """ + Test response for a valid client_id with response_type: code + """ + self.client.login(username="test_user", password="123456") + + query_data = { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + } + + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) + self.assertEqual(response.status_code, 200) + + # check form is in context and form params are valid + self.assertIn("form", response.context) + + form = response.context["form"] + self.assertEqual(form["redirect_uri"].value(), "http://example.org") + self.assertEqual(form["state"].value(), "random_state_string") + self.assertEqual(form["scope"].value(), "openid") + self.assertEqual(form["client_id"].value(), self.application.client_id) + + def test_id_token_code_post_auth_allow(self): + """ + Test authorization code is given for an allowed request with response_type: code + """ + self.client.login(username="test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org?", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + + +class BaseAuthorizationCodeTokenView(BaseTest): def get_auth(self, scope="read write"): """ Helper method to retrieve a valid authorization code @@ -686,9 +607,7 @@ def get_auth(self, scope="read write"): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) query_dict = parse_qs(urlparse(response["Location"]).query) return query_dict["code"].pop() @@ -699,11 +618,7 @@ def generate_pkce_codes(self, algorithm, length=43): code_verifier = get_random_string(length) if algorithm == "S256": code_challenge = ( - base64.urlsafe_b64encode( - hashlib.sha256(code_verifier.encode()).digest() - ) - .decode() - .rstrip("=") + base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).decode().rstrip("=") ) else: code_challenge = code_verifier @@ -713,7 +628,7 @@ def get_pkce_auth(self, code_challenge, code_challenge_method): """ Helper method to retrieve a valid authorization code using pkce """ - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True authcode_data = { "client_id": self.application.client_id, "state": "random_state_string", @@ -725,13 +640,13 @@ def get_pkce_auth(self, code_challenge, code_challenge_method): "code_challenge_method": code_challenge_method, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) query_dict = parse_qs(urlparse(response["Location"]).query) - oauth2_settings.PKCE_REQUIRED = False return query_dict["code"].pop() + +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RW) +class TestAuthorizationCodeTokenView(BaseAuthorizationCodeTokenView): def test_basic_auth(self): """ Request an access token using basic authentication for client authentication @@ -744,21 +659,15 @@ def test_basic_auth(self): "code": authorization_code, "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_refresh(self): """ @@ -772,13 +681,9 @@ def test_refresh(self): "code": authorization_code, "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -789,27 +694,21 @@ def test_refresh(self): "code": authorization_code, "redirect_uri": "http://example.org", } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) token_request_data = { "grant_type": "refresh_token", "refresh_token": content["refresh_token"], "scope": content["scope"], } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertTrue("access_token" in content) # check refresh token cannot be used twice - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) content = json.loads(response.content.decode("utf-8")) self.assertTrue("invalid_grant" in content.values()) @@ -818,7 +717,7 @@ def test_refresh_with_grace_period(self): """ Request an access token using a refresh token """ - oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 120 + self.oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 120 self.client.login(username="test_user", password="123456") authorization_code = self.get_auth() @@ -827,13 +726,9 @@ def test_refresh_with_grace_period(self): "code": authorization_code, "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -844,9 +739,7 @@ def test_refresh_with_grace_period(self): "code": authorization_code, "redirect_uri": "http://example.org", } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) token_request_data = { "grant_type": "refresh_token", @@ -854,9 +747,7 @@ def test_refresh_with_grace_period(self): "scope": content["scope"], } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) @@ -865,9 +756,7 @@ def test_refresh_with_grace_period(self): first_refresh_token = content["refresh_token"] # check access token returns same data if used twice, see #497 - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertTrue("access_token" in content) @@ -875,7 +764,6 @@ def test_refresh_with_grace_period(self): # refresh token should be the same as well self.assertTrue("refresh_token" in content) self.assertEqual(content["refresh_token"], first_refresh_token) - oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 0 def test_refresh_invalidates_old_tokens(self): """ @@ -889,13 +777,9 @@ def test_refresh_invalidates_old_tokens(self): "code": authorization_code, "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) rt = content["refresh_token"] @@ -906,9 +790,7 @@ def test_refresh_invalidates_old_tokens(self): "refresh_token": rt, "scope": content["scope"], } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) refresh_token = RefreshToken.objects.filter(token=rt).first() @@ -927,13 +809,9 @@ def test_refresh_no_scopes(self): "code": authorization_code, "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -941,9 +819,7 @@ def test_refresh_no_scopes(self): "grant_type": "refresh_token", "refresh_token": content["refresh_token"], } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) @@ -961,13 +837,9 @@ def test_refresh_bad_scopes(self): "code": authorization_code, "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -976,9 +848,7 @@ def test_refresh_bad_scopes(self): "refresh_token": content["refresh_token"], "scope": "read write nuke", } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) def test_refresh_fail_repeating_requests(self): @@ -993,13 +863,9 @@ def test_refresh_fail_repeating_requests(self): "code": authorization_code, "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -1008,13 +874,9 @@ def test_refresh_fail_repeating_requests(self): "refresh_token": content["refresh_token"], "scope": content["scope"], } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) def test_refresh_repeating_requests(self): @@ -1022,7 +884,7 @@ def test_refresh_repeating_requests(self): Trying to refresh an access token with the same refresh token more than once succeeds in the grace period and fails outside """ - oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 120 + self.oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 120 self.client.login(username="test_user", password="123456") authorization_code = self.get_auth() @@ -1031,13 +893,9 @@ def test_refresh_repeating_requests(self): "code": authorization_code, "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -1046,28 +904,19 @@ def test_refresh_repeating_requests(self): "refresh_token": content["refresh_token"], "scope": content["scope"], } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) # try refreshing outside the refresh window, see #497 rt = RefreshToken.objects.get(token=content["refresh_token"]) self.assertIsNotNone(rt.revoked) - rt.revoked = timezone.now() - datetime.timedelta( - minutes=10 - ) # instead of mocking out datetime + rt.revoked = timezone.now() - datetime.timedelta(minutes=10) # instead of mocking out datetime rt.save() - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) - oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 0 def test_refresh_repeating_requests_non_rotating_tokens(self): """ @@ -1081,13 +930,9 @@ def test_refresh_repeating_requests_non_rotating_tokens(self): "code": authorization_code, "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -1096,19 +941,13 @@ def test_refresh_repeating_requests_non_rotating_tokens(self): "refresh_token": content["refresh_token"], "scope": content["scope"], } - oauth2_settings.ROTATE_REFRESH_TOKEN = False + self.oauth2_settings.ROTATE_REFRESH_TOKEN = False - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) - oauth2_settings.ROTATE_REFRESH_TOKEN = True - def test_basic_auth_bad_authcode(self): """ Request an access token using a bad authorization code @@ -1120,13 +959,9 @@ def test_basic_auth_bad_authcode(self): "code": "BLAH", "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) def test_basic_auth_bad_granttype(self): @@ -1135,18 +970,10 @@ def test_basic_auth_bad_granttype(self): """ self.client.login(username="test_user", password="123456") - token_request_data = { - "grant_type": "UNKNOWN", - "code": "BLAH", - "redirect_uri": "http://example.org", - } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + token_request_data = {"grant_type": "UNKNOWN", "code": "BLAH", "redirect_uri": "http://example.org"} + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) def test_basic_auth_grant_expired(self): @@ -1169,13 +996,9 @@ def test_basic_auth_grant_expired(self): "code": "BLAH", "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) def test_basic_auth_bad_secret(self): @@ -1192,9 +1015,7 @@ def test_basic_auth_bad_secret(self): } auth_headers = get_basic_auth_header(self.application.client_id, "BOOM!") - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 401) def test_basic_auth_wrong_auth_type(self): @@ -1210,17 +1031,13 @@ def test_basic_auth_wrong_auth_type(self): "redirect_uri": "http://example.org", } - user_pass = "{0}:{1}".format( - self.application.client_id, self.application.client_secret - ) + user_pass = "{0}:{1}".format(self.application.client_id, self.application.client_secret) auth_string = base64.b64encode(user_pass.encode("utf-8")) auth_headers = { "HTTP_AUTHORIZATION": "Wrong " + auth_string.decode("utf-8"), } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 401) def test_request_body_params(self): @@ -1238,17 +1055,13 @@ def test_request_body_params(self): "client_secret": self.application.client_secret, } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_public(self): """ @@ -1267,115 +1080,69 @@ def test_public(self): "client_id": self.application.client_id, } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) - - def test_id_token_public(self): - """ - Request an access token using client_type: public - """ - self.client.login(username="test_user", password="123456") - - self.application.client_type = Application.CLIENT_PUBLIC - self.application.save() - authorization_code = self.get_auth(scope="openid") - - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org", - "client_id": self.application.client_id, - "scope": "openid", - } - - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "openid") - self.assertIn("access_token", content) - self.assertIn("id_token", content) - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_public_pkce_S256_authorize_get(self): """ Request an access token using client_type: public and PKCE enabled. Tests if the authorize get is successfull - for the S256 algorithm + for the S256 algorithm and form data are properly passed. """ self.client.login(username="test_user", password="123456") self.application.client_type = Application.CLIENT_PUBLIC self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("S256") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True - query_string = urlencode( - { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - "response_type": "code", - "allow": True, - "code_challenge": code_challenge, - "code_challenge_method": "S256", - } - ) - url = "{url}?{qs}".format( - url=reverse("oauth2_provider:authorize"), qs=query_string - ) + query_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + } - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - oauth2_settings.PKCE_REQUIRED = False + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) + self.assertContains(response, 'value="S256"', count=1, status_code=200) + self.assertContains(response, 'value="{0}"'.format(code_challenge), count=1, status_code=200) def test_public_pkce_plain_authorize_get(self): """ Request an access token using client_type: public and PKCE enabled. Tests if the authorize get is successfull - for the plain algorithm + for the plain algorithm and form data are properly passed. """ self.client.login(username="test_user", password="123456") self.application.client_type = Application.CLIENT_PUBLIC self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("plain") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True - query_string = urlencode( - { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - "response_type": "code", - "allow": True, - "code_challenge": code_challenge, - "code_challenge_method": "plain", - } - ) - url = "{url}?{qs}".format( - url=reverse("oauth2_provider:authorize"), qs=query_string - ) + query_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + "code_challenge": code_challenge, + "code_challenge_method": "plain", + } - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - oauth2_settings.PKCE_REQUIRED = False + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) + self.assertContains(response, 'value="plain"', count=1, status_code=200) + self.assertContains(response, 'value="{0}"'.format(code_challenge), count=1, status_code=200) def test_public_pkce_S256(self): """ @@ -1388,7 +1155,7 @@ def test_public_pkce_S256(self): self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("S256") authorization_code = self.get_pkce_auth(code_challenge, "S256") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True token_request_data = { "grant_type": "authorization_code", @@ -1398,18 +1165,13 @@ def test_public_pkce_S256(self): "code_verifier": code_verifier, } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) - oauth2_settings.PKCE_REQUIRED = False + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_public_pkce_plain(self): """ @@ -1422,7 +1184,7 @@ def test_public_pkce_plain(self): self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("plain") authorization_code = self.get_pkce_auth(code_challenge, "plain") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True token_request_data = { "grant_type": "authorization_code", @@ -1432,18 +1194,13 @@ def test_public_pkce_plain(self): "code_verifier": code_verifier, } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) - oauth2_settings.PKCE_REQUIRED = False + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_public_pkce_invalid_algorithm(self): """ @@ -1455,28 +1212,22 @@ def test_public_pkce_invalid_algorithm(self): self.application.client_type = Application.CLIENT_PUBLIC self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("invalid") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True - query_string = urlencode( - { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - "response_type": "code", - "allow": True, - "code_challenge": code_challenge, - "code_challenge_method": "invalid", - } - ) - url = "{url}?{qs}".format( - url=reverse("oauth2_provider:authorize"), qs=query_string - ) + query_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + "code_challenge": code_challenge, + "code_challenge_method": "invalid", + } - response = self.client.get(url) + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 302) self.assertIn("error=invalid_request", response["Location"]) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_missing_code_challenge(self): """ @@ -1489,27 +1240,21 @@ def test_public_pkce_missing_code_challenge(self): self.application.skip_authorization = True self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("S256") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True - query_string = urlencode( - { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - "response_type": "code", - "allow": True, - "code_challenge_method": "S256", - } - ) - url = "{url}?{qs}".format( - url=reverse("oauth2_provider:authorize"), qs=query_string - ) + query_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + "code_challenge_method": "S256", + } - response = self.client.get(url) + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 302) self.assertIn("error=invalid_request", response["Location"]) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_missing_code_challenge_method(self): """ @@ -1521,26 +1266,20 @@ def test_public_pkce_missing_code_challenge_method(self): self.application.client_type = Application.CLIENT_PUBLIC self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("S256") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True - query_string = urlencode( - { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - "response_type": "code", - "allow": True, - "code_challenge": code_challenge, - } - ) - url = "{url}?{qs}".format( - url=reverse("oauth2_provider:authorize"), qs=query_string - ) + query_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + "code_challenge": code_challenge, + } - response = self.client.get(url) + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 200) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_S256_invalid_code_verifier(self): """ @@ -1553,7 +1292,7 @@ def test_public_pkce_S256_invalid_code_verifier(self): self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("S256") authorization_code = self.get_pkce_auth(code_challenge, "S256") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True token_request_data = { "grant_type": "authorization_code", @@ -1563,11 +1302,8 @@ def test_public_pkce_S256_invalid_code_verifier(self): "code_verifier": "invalid", } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_plain_invalid_code_verifier(self): """ @@ -1580,7 +1316,7 @@ def test_public_pkce_plain_invalid_code_verifier(self): self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("plain") authorization_code = self.get_pkce_auth(code_challenge, "plain") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True token_request_data = { "grant_type": "authorization_code", @@ -1590,11 +1326,8 @@ def test_public_pkce_plain_invalid_code_verifier(self): "code_verifier": "invalid", } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_S256_missing_code_verifier(self): """ @@ -1607,7 +1340,7 @@ def test_public_pkce_S256_missing_code_verifier(self): self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("S256") authorization_code = self.get_pkce_auth(code_challenge, "S256") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True token_request_data = { "grant_type": "authorization_code", @@ -1616,11 +1349,8 @@ def test_public_pkce_S256_missing_code_verifier(self): "client_id": self.application.client_id, } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_plain_missing_code_verifier(self): """ @@ -1633,7 +1363,7 @@ def test_public_pkce_plain_missing_code_verifier(self): self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("plain") authorization_code = self.get_pkce_auth(code_challenge, "plain") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True token_request_data = { "grant_type": "authorization_code", @@ -1642,11 +1372,8 @@ def test_public_pkce_plain_missing_code_verifier(self): "client_id": self.application.client_id, } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) - oauth2_settings.PKCE_REQUIRED = False def test_malicious_redirect_uri(self): """ @@ -1666,9 +1393,7 @@ def test_malicious_redirect_uri(self): "client_id": self.application.client_id, } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) data = response.json() self.assertEqual(data["error"], "invalid_request") @@ -1692,9 +1417,7 @@ def test_code_exchange_succeed_when_redirect_uri_match(self): "response_type": "code", "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) query_dict = parse_qs(urlparse(response["Location"]).query) authorization_code = query_dict["code"].pop() @@ -1704,21 +1427,15 @@ def test_code_exchange_succeed_when_redirect_uri_match(self): "code": authorization_code, "redirect_uri": "http://example.org?foo=bar", } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_code_exchange_fails_when_redirect_uri_does_not_match(self): """ @@ -1735,9 +1452,7 @@ def test_code_exchange_fails_when_redirect_uri_does_not_match(self): "response_type": "code", "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) query_dict = parse_qs(urlparse(response["Location"]).query) authorization_code = query_dict["code"].pop() @@ -1747,13 +1462,9 @@ def test_code_exchange_fails_when_redirect_uri_does_not_match(self): "code": authorization_code, "redirect_uri": "http://example.org?foo=baraa", } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) data = response.json() self.assertEqual(data["error"], "invalid_request") @@ -1781,9 +1492,7 @@ def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_param "response_type": "code", "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) query_dict = parse_qs(urlparse(response["Location"]).query) authorization_code = query_dict["code"].pop() @@ -1793,70 +1502,15 @@ def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_param "code": authorization_code, "redirect_uri": "http://example.com?bar=baz&foo=bar", } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) - - def test_id_token_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params( - self, - ): - """ - Tests code exchange succeed when redirect uri matches the one used for code request - """ - self.client.login(username="test_user", password="123456") - self.application.redirect_uris = "http://localhost http://example.com?foo=bar" - self.application.save() - - # retrieve a valid authorization code - authcode_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.com?bar=baz&foo=bar", - "response_type": "code", - "allow": True, - } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) - query_dict = parse_qs(urlparse(response["Location"]).query) - authorization_code = query_dict["code"].pop() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.com?bar=baz&foo=bar", - } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) - - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "openid") - self.assertIn("access_token", content) - self.assertIn("id_token", content) - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_oob_as_html(self): """ @@ -1902,7 +1556,7 @@ def test_oob_as_html(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_oob_as_json(self): """ @@ -1919,9 +1573,7 @@ def test_oob_as_json(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) self.assertEqual(response.status_code, 200) self.assertRegex(response["Content-Type"], "^application/json") @@ -1938,19 +1590,136 @@ def test_oob_as_json(self): "client_secret": self.application.client_secret, } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +class TestOIDCAuthorizationCodeTokenView(BaseAuthorizationCodeTokenView): + def setUp(self): + super().setUp() + self.application.algorithm = Application.RS256_ALGORITHM + self.application.save() + + def test_id_token_public(self): + """ + Request an access token using client_type: public + """ + self.client.login(username="test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + authorization_code = self.get_auth(scope="openid") + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id, + "scope": "openid", + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid") + self.assertIn("access_token", content) + self.assertIn("id_token", content) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + def test_id_token_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params( + self, + ): + """ + Tests code exchange succeed when redirect uri matches the one used for code request + """ + self.client.login(username="test_user", password="123456") + self.application.redirect_uris = "http://localhost http://example.com?foo=bar" + self.application.save() + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.com?bar=baz&foo=bar", + "response_type": "code", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + query_dict = parse_qs(urlparse(response["Location"]).query) + authorization_code = query_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.com?bar=baz&foo=bar", + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid") + self.assertIn("access_token", content) + self.assertIn("id_token", content) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +class TestOIDCAuthorizationCodeHSAlgorithm(BaseAuthorizationCodeTokenView): + def setUp(self): + super().setUp() + self.oauth2_settings.OIDC_RSA_PRIVATE_KEY = None + self.application.algorithm = Application.HS256_ALGORITHM + self.application.save() + + def test_id_token(self): + """ + Request an access token using an HS256 application + """ + self.client.login(username="test_user", password="123456") + + authorization_code = self.get_auth(scope="openid") + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id, + "client_secret": self.application.client_secret, + "scope": "openid", + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 200) + + content = response.json() + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid") + self.assertIn("access_token", content) + self.assertIn("id_token", content) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + # Check decoding JWT using HS256 + key = self.application.jwk_key + assert key.key_type == "oct" + jwt_token = jwt.JWT(key=key, jwt=content["id_token"]) + claims = json.loads(jwt_token.claims) + assert claims["sub"] == "1" + + +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RW) class TestAuthorizationCodeProtectedResource(BaseTest): def test_resource_access_allowed(self): self.client.login(username="test_user", password="123456") @@ -1964,9 +1733,7 @@ def test_resource_access_allowed(self): "response_type": "code", "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) query_dict = parse_qs(urlparse(response["Location"]).query) authorization_code = query_dict["code"].pop() @@ -1976,13 +1743,9 @@ def test_resource_access_allowed(self): "code": authorization_code, "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) access_token = content["access_token"] @@ -1997,6 +1760,25 @@ def test_resource_access_allowed(self): response = view(request) self.assertEqual(response, "This is a protected resource") + def test_resource_access_deny(self): + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + "faketoken", + } + request = self.factory.get("/fake-resource", **auth_headers) + request.user = self.test_user + + view = ResourceView.as_view() + response = view(request) + self.assertEqual(response.status_code, 403) + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +class TestOIDCAuthorizationCodeProtectedResource(BaseTest): + def setUp(self): + super().setUp() + self.application.algorithm = Application.RS256_ALGORITHM + self.application.save() + def test_id_token_resource_access_allowed(self): self.client.login(username="test_user", password="123456") @@ -2009,9 +1791,7 @@ def test_id_token_resource_access_allowed(self): "response_type": "code", "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) query_dict = parse_qs(urlparse(response["Location"]).query) authorization_code = query_dict["code"].pop() @@ -2021,13 +1801,9 @@ def test_id_token_resource_access_allowed(self): "code": authorization_code, "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) access_token = content["access_token"] id_token = content["id_token"] @@ -2054,39 +1830,23 @@ def test_id_token_resource_access_allowed(self): response = view(request) self.assertEqual(response, "This is a protected resource") - def test_resource_access_deny(self): - auth_headers = { - "HTTP_AUTHORIZATION": "Bearer " + "faketoken", - } - request = self.factory.get("/fake-resource", **auth_headers) - request.user = self.test_user - - view = ResourceView.as_view() - response = view(request) - self.assertEqual(response.status_code, 403) - +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RO) class TestDefaultScopes(BaseTest): def test_pre_auth_default_scopes(self): """ Test response for a valid client_id with response_type: code using default scopes """ self.client.login(username="test_user", password="123456") - oauth2_settings._DEFAULT_SCOPES = ["read"] - query_string = urlencode( - { - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "redirect_uri": "http://example.org", - } - ) - url = "{url}?{qs}".format( - url=reverse("oauth2_provider:authorize"), qs=query_string - ) + query_data = { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "redirect_uri": "http://example.org", + } - response = self.client.get(url) + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 200) # check form is in context and form params are valid @@ -2097,4 +1857,3 @@ def test_pre_auth_default_scopes(self): self.assertEqual(form["state"].value(), "random_state_string") self.assertEqual(form["scope"].value(), "read") self.assertEqual(form["client_id"].value(), self.application.client_id) - oauth2_settings._DEFAULT_SCOPES = ["read", "write"] diff --git a/tests/test_client_credential.py b/tests/test_client_credential.py index 09401cf..8b9aa3b 100644 --- a/tests/test_client_credential.py +++ b/tests/test_client_credential.py @@ -1,6 +1,7 @@ import json from urllib.parse import quote_plus +import pytest from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse @@ -10,10 +11,10 @@ from oauth2_provider.models import get_access_token_model, get_application_model from oauth2_provider.oauth2_backends import OAuthLibCore from oauth2_provider.oauth2_validators import OAuth2Validator -from oauth2_provider.settings import oauth2_settings from oauth2_provider.views import ProtectedResourceView from oauth2_provider.views.mixins import OAuthLibMixin +from . import presets from .utils import get_basic_auth_header @@ -28,6 +29,8 @@ def get(self, request, *args, **kwargs): return "This is a protected resource" +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RW) class BaseTest(TestCase): def setUp(self): self.factory = RequestFactory() @@ -41,9 +44,6 @@ def setUp(self): authorization_grant_type=Application.GRANT_CLIENT_CREDENTIALS, ) - oauth2_settings._SCOPES = ["read", "write"] - oauth2_settings._DEFAULT_SCOPES = ["read", "write"] - def tearDown(self): self.application.delete() self.test_user.delete() @@ -105,7 +105,7 @@ class TestExtendedRequest(BaseTest): @classmethod def setUpClass(cls): cls.request_factory = RequestFactory() - super(TestExtendedRequest, cls).setUpClass() + super().setUpClass() def test_extended_request(self): class TestView(OAuthLibMixin, View): @@ -158,11 +158,7 @@ def test_client_resource_password_based(self): authorization_grant_type=Application.GRANT_PASSWORD, ) - token_request_data = { - "grant_type": "password", - "username": "test_user", - "password": "123456" - } + token_request_data = {"grant_type": "password", "username": "test_user", "password": "123456"} auth_headers = get_basic_auth_header( quote_plus(self.application.client_id), quote_plus(self.application.client_secret) ) diff --git a/tests/test_commands.py b/tests/test_commands.py index 274ecce..ff5deba 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -12,7 +12,6 @@ class CreateApplicationTest(TestCase): - def test_command_creates_application(self): output = StringIO() self.assertEqual(Application.objects.count(), 0) diff --git a/tests/test_decorators.py b/tests/test_decorators.py index 0732b29..ce17a89 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -6,7 +6,6 @@ from oauth2_provider.decorators import protected_resource, rw_protected_resource from oauth2_provider.models import get_access_token_model, get_application_model -from oauth2_provider.settings import oauth2_settings Application = get_application_model() @@ -18,7 +17,7 @@ class TestProtectedResourceDecorator(TestCase): @classmethod def setUpClass(cls): cls.request_factory = RequestFactory() - super(TestProtectedResourceDecorator, cls).setUpClass() + super().setUpClass() def setUp(self): self.user = UserModel.objects.create_user("test_user", "test@example.com", "123456") @@ -34,11 +33,9 @@ def setUp(self): scope="read write", expires=timezone.now() + timedelta(seconds=300), token="secret-access-token-key", - application=self.application + application=self.application, ) - oauth2_settings._SCOPES = ["read", "write"] - def test_access_denied(self): @protected_resource() def view(request, *args, **kwargs): diff --git a/tests/test_generator.py b/tests/test_generator.py index 211713b..cc79280 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -1,10 +1,7 @@ +import pytest from django.test import TestCase -from oauth2_provider.generators import ( - BaseHashGenerator, ClientIdGenerator, ClientSecretGenerator, - generate_client_id, generate_client_secret -) -from oauth2_provider.settings import oauth2_settings +from oauth2_provider.generators import BaseHashGenerator, generate_client_id, generate_client_secret class MockHashGenerator(BaseHashGenerator): @@ -12,23 +9,20 @@ def hash(self): return 42 +@pytest.mark.usefixtures("oauth2_settings") class TestGenerators(TestCase): - def tearDown(self): - oauth2_settings.CLIENT_ID_GENERATOR_CLASS = ClientIdGenerator - oauth2_settings.CLIENT_SECRET_GENERATOR_CLASS = ClientSecretGenerator - def test_generate_client_id(self): - g = oauth2_settings.CLIENT_ID_GENERATOR_CLASS() + g = self.oauth2_settings.CLIENT_ID_GENERATOR_CLASS() self.assertEqual(len(g.hash()), 40) - oauth2_settings.CLIENT_ID_GENERATOR_CLASS = MockHashGenerator + self.oauth2_settings.CLIENT_ID_GENERATOR_CLASS = MockHashGenerator self.assertEqual(generate_client_id(), 42) def test_generate_secret_id(self): - g = oauth2_settings.CLIENT_SECRET_GENERATOR_CLASS() + g = self.oauth2_settings.CLIENT_SECRET_GENERATOR_CLASS() self.assertEqual(len(g.hash()), 128) - oauth2_settings.CLIENT_SECRET_GENERATOR_CLASS = MockHashGenerator + self.oauth2_settings.CLIENT_SECRET_GENERATOR_CLASS = MockHashGenerator self.assertEqual(generate_client_secret(), 42) def test_basegen_misuse(self): diff --git a/tests/test_hybrid.py b/tests/test_hybrid.py index 1f45aee..d198988 100644 --- a/tests/test_hybrid.py +++ b/tests/test_hybrid.py @@ -3,20 +3,25 @@ import json from urllib.parse import parse_qs, urlencode, urlparse +import pytest from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse from django.utils import timezone +from jwcrypto import jwt from oauthlib.oauth2.rfc6749 import errors as oauthlib_errors from oauth2_provider.models import ( - get_access_token_model, get_application_model, - get_grant_model, get_refresh_token_model + get_access_token_model, + get_application_model, + get_grant_model, + get_refresh_token_model, ) -from oauth2_provider.settings import oauth2_settings -from oauth2_provider.views import ProtectedResourceView +from oauth2_provider.oauth2_validators import OAuth2Validator +from oauth2_provider.views import ProtectedResourceView, ScopedProtectedResourceView -from .utils import get_basic_auth_header +from . import presets +from .utils import get_basic_auth_header, spy_on Application = get_application_model() @@ -32,13 +37,21 @@ def get(self, request, *args, **kwargs): return "This is a protected resource" +class ScopedResourceView(ScopedProtectedResourceView): + required_scopes = ["read"] + + def get(self, request, *args, **kwargs): + return "This is a protected resource" + + +@pytest.mark.usefixtures("oauth2_settings") class BaseTest(TestCase): def setUp(self): self.factory = RequestFactory() self.hy_test_user = UserModel.objects.create_user("hy_test_user", "test_hy@example.com", "123456") self.hy_dev_user = UserModel.objects.create_user("hy_dev_user", "dev_hy@example.com", "123456") - oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["http", "custom-scheme"] + self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["http", "custom-scheme"] self.application = Application( name="Hybrid Test Application", @@ -48,23 +61,17 @@ def setUp(self): user=self.hy_dev_user, client_type=Application.CLIENT_CONFIDENTIAL, authorization_grant_type=Application.GRANT_OPENID_HYBRID, + algorithm=Application.RS256_ALGORITHM, ) self.application.save() - oauth2_settings._SCOPES = ["read", "write", "openid"] - oauth2_settings._DEFAULT_SCOPES = ["read", "write"] - oauth2_settings.SCOPES = { - "read": "Reading scope", - "write": "Writing scope", - "openid": "OpenID connect" - } - def tearDown(self): self.application.delete() self.hy_test_user.delete() self.hy_dev_user.delete() +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) class TestRegressionIssue315Hybrid(BaseTest): """ Test to avoid regression for the issue 315: request object @@ -73,13 +80,15 @@ class TestRegressionIssue315Hybrid(BaseTest): def test_request_is_not_overwritten_code_token(self): self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code token", - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.org", - }) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code token", + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + } + ) url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) response = self.client.get(url) @@ -88,14 +97,16 @@ def test_request_is_not_overwritten_code_token(self): def test_request_is_not_overwritten_code_id_token(self): self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token", - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.org", - "nonce": "nonce", - }) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + "nonce": "nonce", + } + ) url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) response = self.client.get(url) @@ -104,14 +115,16 @@ def test_request_is_not_overwritten_code_id_token(self): def test_request_is_not_overwritten_code_id_token_token(self): self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token token", - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.org", - "nonce": "nonce", - }) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code id_token token", + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + "nonce": "nonce", + } + ) url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) response = self.client.get(url) @@ -119,6 +132,7 @@ def test_request_is_not_overwritten_code_id_token_token(self): assert "request" not in response.context_data +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) class TestHybridView(BaseTest): def test_skip_authorization_completely(self): """ @@ -128,13 +142,15 @@ def test_skip_authorization_completely(self): self.application.skip_authorization = True self.application.save() - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - }) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + } + ) url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) response = self.client.get(url) @@ -148,13 +164,15 @@ def test_id_token_skip_authorization_completely(self): self.application.skip_authorization = True self.application.save() - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - }) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + } + ) url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) response = self.client.get(url) @@ -166,17 +184,19 @@ def test_pre_auth_invalid_client(self): """ self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": "fakeclientid", - "response_type": "code", - }) + query_string = urlencode( + { + "client_id": "fakeclientid", + "response_type": "code", + } + ) url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) response = self.client.get(url) self.assertEqual(response.status_code, 400) self.assertEqual( response.context_data["url"], - "?error=invalid_request&error_description=Invalid+client_id+parameter+value." + "?error=invalid_request&error_description=Invalid+client_id+parameter+value.", ) def test_pre_auth_valid_client(self): @@ -185,13 +205,15 @@ def test_pre_auth_valid_client(self): """ self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - }) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + } + ) url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) response = self.client.get(url) @@ -212,14 +234,16 @@ def test_id_token_pre_auth_valid_client(self): """ self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token", - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - "nonce": "nonce", - }) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "nonce": "nonce", + } + ) url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) response = self.client.get(url) @@ -241,13 +265,15 @@ def test_pre_auth_valid_client_custom_redirect_uri_scheme(self): """ self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "custom-scheme://example.com", - }) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "custom-scheme://example.com", + } + ) url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) response = self.client.get(url) @@ -264,20 +290,23 @@ def test_pre_auth_valid_client_custom_redirect_uri_scheme(self): def test_pre_auth_approval_prompt(self): tok = AccessToken.objects.create( - user=self.hy_test_user, token="1234567890", + user=self.hy_test_user, + token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - "approval_prompt": "auto", - }) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "approval_prompt": "auto", + } + ) url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) response = self.client.get(url) self.assertEqual(response.status_code, 302) @@ -288,44 +317,50 @@ def test_pre_auth_approval_prompt(self): self.assertEqual(response.status_code, 200) def test_pre_auth_approval_prompt_default(self): - oauth2_settings.REQUEST_APPROVAL_PROMPT = "force" - self.assertEqual(oauth2_settings.REQUEST_APPROVAL_PROMPT, "force") + self.oauth2_settings.REQUEST_APPROVAL_PROMPT = "force" + self.assertEqual(self.oauth2_settings.REQUEST_APPROVAL_PROMPT, "force") AccessToken.objects.create( - user=self.hy_test_user, token="1234567890", + user=self.hy_test_user, + token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - }) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + } + ) url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) response = self.client.get(url) self.assertEqual(response.status_code, 200) def test_pre_auth_approval_prompt_default_override(self): - oauth2_settings.REQUEST_APPROVAL_PROMPT = "auto" + self.oauth2_settings.REQUEST_APPROVAL_PROMPT = "auto" AccessToken.objects.create( - user=self.hy_test_user, token="1234567890", + user=self.hy_test_user, + token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - }) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + } + ) url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) response = self.client.get(url) self.assertEqual(response.status_code, 302) @@ -336,10 +371,12 @@ def test_pre_auth_default_redirect(self): """ self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token", - }) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code id_token", + } + ) url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) response = self.client.get(url) @@ -354,11 +391,13 @@ def test_pre_auth_forbibben_redirect(self): """ self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - "redirect_uri": "http://forbidden.it", - }) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code", + "redirect_uri": "http://forbidden.it", + } + ) url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) response = self.client.get(url) @@ -370,10 +409,12 @@ def test_pre_auth_wrong_response_type(self): """ self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "WRONG", - }) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "WRONG", + } + ) url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) response = self.client.get(url) @@ -753,6 +794,7 @@ def test_code_post_auth_fails_when_redirect_uri_path_is_invalid(self): self.assertEqual(response.status_code, 400) +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) class TestHybridTokenView(BaseTest): def get_auth(self, scope="read write"): """ @@ -782,7 +824,7 @@ def test_basic_auth(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) @@ -792,7 +834,7 @@ def test_basic_auth(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_basic_auth_bad_authcode(self): """ @@ -803,7 +845,7 @@ def test_basic_auth_bad_authcode(self): token_request_data = { "grant_type": "authorization_code", "code": "BLAH", - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) @@ -816,11 +858,7 @@ def test_basic_auth_bad_granttype(self): """ self.client.login(username="hy_test_user", password="123456") - token_request_data = { - "grant_type": "UNKNOWN", - "code": "BLAH", - "redirect_uri": "http://example.org" - } + token_request_data = {"grant_type": "UNKNOWN", "code": "BLAH", "redirect_uri": "http://example.org"} auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) @@ -832,14 +870,19 @@ def test_basic_auth_grant_expired(self): """ self.client.login(username="hy_test_user", password="123456") g = Grant( - application=self.application, user=self.hy_test_user, code="BLAH", - expires=timezone.now(), redirect_uri="", scope="") + application=self.application, + user=self.hy_test_user, + code="BLAH", + expires=timezone.now(), + redirect_uri="", + scope="", + ) g.save() token_request_data = { "grant_type": "authorization_code", "code": "BLAH", - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) @@ -856,7 +899,7 @@ def test_basic_auth_bad_secret(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } auth_headers = get_basic_auth_header(self.application.client_id, "BOOM!") @@ -873,7 +916,7 @@ def test_basic_auth_wrong_auth_type(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } user_pass = "{0}:{1}".format(self.application.client_id, self.application.client_secret) @@ -906,7 +949,7 @@ def test_request_body_params(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_public(self): """ @@ -922,7 +965,7 @@ def test_public(self): "grant_type": "authorization_code", "code": authorization_code, "redirect_uri": "http://example.org", - "client_id": self.application.client_id + "client_id": self.application.client_id, } response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) @@ -931,7 +974,7 @@ def test_public(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_id_token_public(self): """ @@ -959,7 +1002,7 @@ def test_id_token_public(self): self.assertEqual(content["scope"], "openid") self.assertIn("access_token", content) self.assertIn("id_token", content) - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_malicious_redirect_uri(self): """ @@ -976,7 +1019,7 @@ def test_malicious_redirect_uri(self): "grant_type": "authorization_code", "code": authorization_code, "redirect_uri": "/../", - "client_id": self.application.client_id + "client_id": self.application.client_id, } response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) @@ -1008,7 +1051,7 @@ def test_code_exchange_succeed_when_redirect_uri_match(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org?foo=bar" + "redirect_uri": "http://example.org?foo=bar", } auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) @@ -1018,7 +1061,7 @@ def test_code_exchange_succeed_when_redirect_uri_match(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "openid read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_code_exchange_fails_when_redirect_uri_does_not_match(self): """ @@ -1043,7 +1086,7 @@ def test_code_exchange_fails_when_redirect_uri_does_not_match(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org?foo=baraa" + "redirect_uri": "http://example.org?foo=baraa", } auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) @@ -1078,7 +1121,7 @@ def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_param token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.com?bar=baz&foo=bar" + "redirect_uri": "http://example.com?bar=baz&foo=bar", } auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) @@ -1088,7 +1131,7 @@ def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_param content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "openid read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_id_token_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params(self): """ @@ -1127,9 +1170,10 @@ def test_id_token_code_exchange_succeed_when_redirect_uri_match_with_multiple_qu self.assertEqual(content["scope"], "openid") self.assertIn("access_token", content) self.assertIn("id_token", content) - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) class TestHybridProtectedResource(BaseTest): def test_resource_access_allowed(self): self.client.login(username="hy_test_user", password="123456") @@ -1151,7 +1195,7 @@ def test_resource_access_allowed(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) @@ -1221,6 +1265,11 @@ def test_id_token_resource_access_allowed(self): response = view(request) self.assertEqual(response, "This is a protected resource") + # If the resource requires more scopes than we requested, we should get an error + view = ScopedResourceView.as_view() + response = view(request) + self.assertEqual(response.status_code, 403) + def test_resource_access_deny(self): auth_headers = { "HTTP_AUTHORIZATION": "Bearer " + "faketoken", @@ -1233,21 +1282,22 @@ def test_resource_access_deny(self): self.assertEqual(response.status_code, 403) +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RO) class TestDefaultScopesHybrid(BaseTest): - def test_pre_auth_default_scopes(self): """ Test response for a valid client_id with response_type: code using default scopes """ self.client.login(username="hy_test_user", password="123456") - oauth2_settings._DEFAULT_SCOPES = ["read"] - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code token", - "state": "random_state_string", - "redirect_uri": "http://example.org", - }) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code token", + "state": "random_state_string", + "redirect_uri": "http://example.org", + } + ) url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) response = self.client.get(url) @@ -1261,4 +1311,121 @@ def test_pre_auth_default_scopes(self): self.assertEqual(form["state"].value(), "random_state_string") self.assertEqual(form["scope"].value(), "read") self.assertEqual(form["client_id"].value(), self.application.client_id) - oauth2_settings._DEFAULT_SCOPES = ["read", "write"] + + +@pytest.mark.django_db +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_id_token_nonce_in_token_response(oauth2_settings, test_user, hybrid_application, client, oidc_key): + client.force_login(test_user) + auth_rsp = client.post( + reverse("oauth2_provider:authorize"), + data={ + "client_id": hybrid_application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "code id_token", + "nonce": "random_nonce_string", + "allow": True, + }, + ) + assert auth_rsp.status_code == 302 + auth_data = parse_qs(urlparse(auth_rsp["Location"]).fragment) + assert "code" in auth_data + assert "id_token" in auth_data + # Decode the id token - is the nonce correct + jwt_token = jwt.JWT(key=oidc_key, jwt=auth_data["id_token"][0]) + claims = json.loads(jwt_token.claims) + assert "nonce" in claims + assert claims["nonce"] == "random_nonce_string" + code = auth_data["code"][0] + client.logout() + # Get the token response using the code + token_rsp = client.post( + reverse("oauth2_provider:token"), + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": "http://example.org", + "client_id": hybrid_application.client_id, + "client_secret": hybrid_application.client_secret, + "scope": "openid", + }, + ) + assert token_rsp.status_code == 200 + token_data = token_rsp.json() + assert "id_token" in token_data + # The nonce should be present in this id token also + jwt_token = jwt.JWT(key=oidc_key, jwt=token_data["id_token"]) + claims = json.loads(jwt_token.claims) + assert "nonce" in claims + assert claims["nonce"] == "random_nonce_string" + + +@pytest.mark.django_db +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_claims_passed_to_code_generation( + oauth2_settings, test_user, hybrid_application, client, mocker, oidc_key +): + # Add a spy on to OAuth2Validator.finalize_id_token + mocker.patch.object( + OAuth2Validator, + "finalize_id_token", + spy_on(OAuth2Validator.finalize_id_token), + ) + claims = {"id_token": {"email": {"essential": True}}} + client.force_login(test_user) + auth_form_rsp = client.get( + reverse("oauth2_provider:authorize"), + data={ + "client_id": hybrid_application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "code id_token", + "nonce": "random_nonce_string", + "claims": json.dumps(claims), + }, + ) + # Check that claims has made it in to the form to be submitted + assert auth_form_rsp.status_code == 200 + form_initial_data = auth_form_rsp.context_data["form"].initial + assert "claims" in form_initial_data + assert json.loads(form_initial_data["claims"]) == claims + # Filter out not specified values + form_data = {key: value for key, value in form_initial_data.items() if value is not None} + # Now submitting the form (with allow=True) should persist requested claims + auth_rsp = client.post( + reverse("oauth2_provider:authorize"), + data={"allow": True, **form_data}, + ) + assert auth_rsp.status_code == 302 + auth_data = parse_qs(urlparse(auth_rsp["Location"]).fragment) + assert "code" in auth_data + assert "id_token" in auth_data + assert OAuth2Validator.finalize_id_token.spy.call_count == 1 + oauthlib_request = OAuth2Validator.finalize_id_token.spy.call_args[0][4] + assert oauthlib_request.claims == claims + assert Grant.objects.get().claims == json.dumps(claims) + OAuth2Validator.finalize_id_token.spy.reset_mock() + + # Get the token response using the code + client.logout() + code = auth_data["code"][0] + token_rsp = client.post( + reverse("oauth2_provider:token"), + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": "http://example.org", + "client_id": hybrid_application.client_id, + "client_secret": hybrid_application.client_secret, + "scope": "openid", + }, + ) + assert token_rsp.status_code == 200 + token_data = token_rsp.json() + assert "id_token" in token_data + assert OAuth2Validator.finalize_id_token.spy.call_count == 1 + oauthlib_request = OAuth2Validator.finalize_id_token.spy.call_args[0][4] + assert oauthlib_request.claims == claims diff --git a/tests/test_implicit.py b/tests/test_implicit.py index 4e8879a..a586340 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -1,15 +1,17 @@ import json -from urllib.parse import parse_qs, urlencode, urlparse +from urllib.parse import parse_qs, urlparse +import pytest from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse -from jwcrypto import jwk, jwt +from jwcrypto import jwt from oauth2_provider.models import get_application_model -from oauth2_provider.settings import oauth2_settings from oauth2_provider.views import ProtectedResourceView +from . import presets + Application = get_application_model() UserModel = get_user_model() @@ -21,6 +23,7 @@ def get(self, request, *args, **kwargs): return "This is a protected resource" +@pytest.mark.usefixtures("oauth2_settings") class BaseTest(TestCase): def setUp(self): self.factory = RequestFactory() @@ -35,36 +38,27 @@ def setUp(self): authorization_grant_type=Application.GRANT_IMPLICIT, ) - oauth2_settings._SCOPES = ["read", "write", "openid"] - oauth2_settings._DEFAULT_SCOPES = ["read"] - oauth2_settings.SCOPES = { - "read": "Reading scope", - "write": "Writing scope", - "openid": "OpenID connect" - } - self.key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) - def tearDown(self): self.application.delete() self.test_user.delete() self.dev_user.delete() +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RO) class TestImplicitAuthorizationCodeView(BaseTest): def test_pre_auth_valid_client_default_scopes(self): """ Test response for a valid client_id with response_type: token and default_scopes """ self.client.login(username="test_user", password="123456") - query_string = urlencode({ + query_data = { "client_id": self.application.client_id, "response_type": "token", "state": "random_state_string", "redirect_uri": "http://example.org", - }) + } - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - response = self.client.get(url) + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 200) self.assertIn("form", response.context) @@ -77,16 +71,15 @@ def test_pre_auth_valid_client(self): """ self.client.login(username="test_user", password="123456") - query_string = urlencode({ + query_data = { "client_id": self.application.client_id, "response_type": "token", "state": "random_state_string", "scope": "read write", "redirect_uri": "http://example.org", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + } - response = self.client.get(url) + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 200) # check form is in context and form params are valid @@ -104,13 +97,12 @@ def test_pre_auth_invalid_client(self): """ self.client.login(username="test_user", password="123456") - query_string = urlencode({ + query_data = { "client_id": "fakeclientid", "response_type": "token", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + } - response = self.client.get(url) + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 400) def test_pre_auth_default_redirect(self): @@ -119,13 +111,12 @@ def test_pre_auth_default_redirect(self): """ self.client.login(username="test_user", password="123456") - query_string = urlencode({ + query_data = { "client_id": self.application.client_id, "response_type": "token", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + } - response = self.client.get(url) + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 200) form = response.context["form"] @@ -137,14 +128,13 @@ def test_pre_auth_forbibben_redirect(self): """ self.client.login(username="test_user", password="123456") - query_string = urlencode({ + query_data = { "client_id": self.application.client_id, "response_type": "token", "redirect_uri": "http://forbidden.it", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + } - response = self.client.get(url) + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 400) def test_post_auth_allow(self): @@ -176,17 +166,15 @@ def test_skip_authorization_completely(self): self.application.skip_authorization = True self.application.save() - query_string = urlencode({ + query_data = { "client_id": self.application.client_id, "response_type": "token", "state": "random_state_string", "scope": "read write", "redirect_uri": "http://example.org", - }) - - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + } - response = self.client.get(url) + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 302) self.assertIn("http://example.org#", response["Location"]) self.assertIn("access_token=", response["Location"]) @@ -252,6 +240,7 @@ def test_implicit_fails_when_redirect_uri_path_is_invalid(self): self.assertEqual(response.status_code, 400) +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RO) class TestImplicitTokenView(BaseTest): def test_resource_access_allowed(self): self.client.login(username="test_user", password="123456") @@ -282,7 +271,14 @@ def test_resource_access_allowed(self): self.assertEqual(response, "This is a protected resource") +@pytest.mark.usefixtures("oidc_key") +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) class TestOpenIDConnectImplicitFlow(BaseTest): + def setUp(self): + super().setUp() + self.application.algorithm = Application.RS256_ALGORITHM + self.application.save() + def test_id_token_post_auth_allow(self): """ Test authorization code is given for an allowed request with response_type: id_token @@ -322,18 +318,16 @@ def test_id_token_skip_authorization_completely(self): self.application.skip_authorization = True self.application.save() - query_string = urlencode({ + query_data = { "client_id": self.application.client_id, "response_type": "id_token", "state": "random_state_string", "nonce": "random_nonce_string", "scope": "openid", "redirect_uri": "http://example.org", - }) - - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + } - response = self.client.get(url) + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 302) self.assertIn("http://example.org#", response["Location"]) self.assertNotIn("access_token=", response["Location"]) @@ -356,17 +350,15 @@ def test_id_token_skip_authorization_completely_missing_nonce(self): self.application.skip_authorization = True self.application.save() - query_string = urlencode({ + query_data = { "client_id": self.application.client_id, "response_type": "id_token", "state": "random_state_string", "scope": "openid", "redirect_uri": "http://example.org", - }) - - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + } - response = self.client.get(url) + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 302) self.assertIn("error=invalid_request", response["Location"]) self.assertIn("error_description=Request+is+missing+mandatory+nonce+paramete", response["Location"]) @@ -430,18 +422,16 @@ def test_access_token_and_id_token_skip_authorization_completely(self): self.application.skip_authorization = True self.application.save() - query_string = urlencode({ + query_data = { "client_id": self.application.client_id, "response_type": "id_token token", "state": "random_state_string", "nonce": "random_nonce_string", "scope": "openid", "redirect_uri": "http://example.org", - }) - - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + } - response = self.client.get(url) + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 302) self.assertIn("http://example.org#", response["Location"]) self.assertIn("access_token=", response["Location"]) diff --git a/tests/test_introspection_auth.py b/tests/test_introspection_auth.py index db37f6c..8b2a6da 100644 --- a/tests/test_introspection_auth.py +++ b/tests/test_introspection_auth.py @@ -1,10 +1,13 @@ import calendar import datetime -from django.conf.urls import include, url +import pytest +from django.conf import settings +from django.conf.urls import include from django.contrib.auth import get_user_model from django.http import HttpResponse from django.test import TestCase, override_settings +from django.urls import path from django.utils import timezone from oauthlib.common import Request @@ -13,6 +16,8 @@ from oauth2_provider.settings import oauth2_settings from oauth2_provider.views import ScopedProtectedResourceView +from . import presets + try: from unittest import mock @@ -41,6 +46,7 @@ def mocked_requests_post(url, data, *args, **kwargs): """ Mock the response from the authentication server """ + class MockResponse: def __init__(self, json_data, status_code): self.json_data = json_data @@ -50,30 +56,39 @@ def json(self): return self.json_data if "token" in data and data["token"] and data["token"] != "12345678900": - return MockResponse({ - "active": True, - "scope": "read write dolphin", - "client_id": "client_id_{}".format(data["token"]), - "username": "{}_user".format(data["token"]), - "exp": int(calendar.timegm(exp.timetuple())), - }, 200) + return MockResponse( + { + "active": True, + "scope": "read write dolphin", + "client_id": "client_id_{}".format(data["token"]), + "username": "{}_user".format(data["token"]), + "exp": int(calendar.timegm(exp.timetuple())), + }, + 200, + ) - return MockResponse({ - "active": False, - }, 200) + return MockResponse( + { + "active": False, + }, + 200, + ) urlpatterns = [ - url(r"^oauth2/", include("oauth2_provider.urls")), - url(r"^oauth2-test-resource/$", ScopeResourceView.as_view()), + path("oauth2/", include("oauth2_provider.urls")), + path("oauth2-test-resource/", ScopeResourceView.as_view()), ] @override_settings(ROOT_URLCONF=__name__) +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(presets.INTROSPECTION_SETTINGS) class TestTokenIntrospectionAuth(TestCase): """ Tests for Authorization through token introspection """ + def setUp(self): self.validator = OAuth2Validator() self.request = mock.MagicMock(wraps=Request) @@ -90,29 +105,24 @@ def setUp(self): ) self.resource_server_token = AccessToken.objects.create( - user=self.resource_server_user, token="12345678900", + user=self.resource_server_user, + token="12345678900", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="introspection" + scope="introspection", ) self.invalid_token = AccessToken.objects.create( - user=self.resource_server_user, token="12345678901", + user=self.resource_server_user, + token="12345678901", application=self.application, expires=timezone.now() + datetime.timedelta(days=-1), - scope="read write dolphin" + scope="read write dolphin", ) - oauth2_settings._SCOPES = ["read", "write", "introspection", "dolphin"] - oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL = "http://example.org/introspection" - oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN = self.resource_server_token.token - oauth2_settings.READ_SCOPE = "read" - oauth2_settings.WRITE_SCOPE = "write" + self.oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN = self.resource_server_token.token def tearDown(self): - oauth2_settings._SCOPES = ["read", "write"] - oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL = None - oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN = None self.resource_server_token.delete() self.application.delete() AccessToken.objects.all().delete() @@ -125,9 +135,9 @@ def test_get_token_from_authentication_server_not_existing_token(self, mock_get) """ token = self.validator._get_token_from_authentication_server( self.resource_server_token.token, - oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL, - oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN, - oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS + self.oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL, + self.oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN, + self.oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS, ) self.assertIsNone(token) @@ -138,14 +148,33 @@ def test_get_token_from_authentication_server_existing_token(self, mock_get): """ token = self.validator._get_token_from_authentication_server( "foo", - oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL, - oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN, - oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS + self.oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL, + self.oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN, + self.oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS, ) self.assertIsInstance(token, AccessToken) self.assertEqual(token.user.username, "foo_user") self.assertEqual(token.scope, "read write dolphin") + @mock.patch("requests.post", side_effect=mocked_requests_post) + def test_get_token_from_authentication_server_expires_timezone(self, mock_get): + """ + Test method _get_token_from_authentication_server for projects with USE_TZ False + """ + settings_use_tz_backup = settings.USE_TZ + settings.USE_TZ = False + try: + self.validator._get_token_from_authentication_server( + "foo", + oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL, + oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN, + oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS, + ) + except ValueError as exception: + self.fail(str(exception)) + finally: + settings.USE_TZ = settings_use_tz_backup + @mock.patch("requests.post", side_effect=mocked_requests_post) def test_validate_bearer_token(self, mock_get): """ diff --git a/tests/test_introspection_view.py b/tests/test_introspection_view.py index a06a73e..0f68320 100644 --- a/tests/test_introspection_view.py +++ b/tests/test_introspection_view.py @@ -1,13 +1,16 @@ import calendar import datetime +import pytest from django.contrib.auth import get_user_model from django.test import TestCase from django.urls import reverse from django.utils import timezone from oauth2_provider.models import get_access_token_model, get_application_model -from oauth2_provider.settings import oauth2_settings + +from . import presets +from .utils import get_basic_auth_header Application = get_application_model() @@ -15,10 +18,13 @@ UserModel = get_user_model() +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(presets.INTROSPECTION_SETTINGS) class TestTokenIntrospectionViews(TestCase): """ Tests for Authorized Token Introspection Views """ + def setUp(self): self.resource_server_user = UserModel.objects.create_user("resource_server", "test@example.com") self.test_user = UserModel.objects.create_user("bar_user", "dev@example.com") @@ -32,46 +38,46 @@ def setUp(self): ) self.resource_server_token = AccessToken.objects.create( - user=self.resource_server_user, token="12345678900", + user=self.resource_server_user, + token="12345678900", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="introspection" + scope="introspection", ) self.valid_token = AccessToken.objects.create( - user=self.test_user, token="12345678901", + user=self.test_user, + token="12345678901", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write dolphin" + scope="read write dolphin", ) self.invalid_token = AccessToken.objects.create( - user=self.test_user, token="12345678902", + user=self.test_user, + token="12345678902", application=self.application, expires=timezone.now() + datetime.timedelta(days=-1), - scope="read write dolphin" + scope="read write dolphin", ) self.token_without_user = AccessToken.objects.create( - user=None, token="12345678903", + user=None, + token="12345678903", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write dolphin" + scope="read write dolphin", ) self.token_without_app = AccessToken.objects.create( - user=self.test_user, token="12345678904", + user=self.test_user, + token="12345678904", application=None, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write dolphin" + scope="read write dolphin", ) - oauth2_settings._SCOPES = ["read", "write", "introspection", "dolphin"] - oauth2_settings.READ_SCOPE = "read" - oauth2_settings.WRITE_SCOPE = "write" - def tearDown(self): - oauth2_settings._SCOPES = ["read", "write"] AccessToken.objects.all().delete() Application.objects.all().delete() UserModel.objects.all().delete() @@ -92,20 +98,22 @@ def test_view_get_valid_token(self): "HTTP_AUTHORIZATION": "Bearer " + self.resource_server_token.token, } response = self.client.get( - reverse("oauth2_provider:introspect"), - {"token": self.valid_token.token}, - **auth_headers) + reverse("oauth2_provider:introspect"), {"token": self.valid_token.token}, **auth_headers + ) self.assertEqual(response.status_code, 200) content = response.json() self.assertIsInstance(content, dict) - self.assertDictEqual(content, { - "active": True, - "scope": self.valid_token.scope, - "client_id": self.valid_token.application.client_id, - "username": self.valid_token.user.get_username(), - "exp": int(calendar.timegm(self.valid_token.expires.timetuple())), - }) + self.assertDictEqual( + content, + { + "active": True, + "scope": self.valid_token.scope, + "client_id": self.valid_token.application.client_id, + "username": self.valid_token.user.get_username(), + "exp": int(calendar.timegm(self.valid_token.expires.timetuple())), + }, + ) def test_view_get_valid_token_without_user(self): """ @@ -116,19 +124,21 @@ def test_view_get_valid_token_without_user(self): "HTTP_AUTHORIZATION": "Bearer " + self.resource_server_token.token, } response = self.client.get( - reverse("oauth2_provider:introspect"), - {"token": self.token_without_user.token}, - **auth_headers) + reverse("oauth2_provider:introspect"), {"token": self.token_without_user.token}, **auth_headers + ) self.assertEqual(response.status_code, 200) content = response.json() self.assertIsInstance(content, dict) - self.assertDictEqual(content, { - "active": True, - "scope": self.token_without_user.scope, - "client_id": self.token_without_user.application.client_id, - "exp": int(calendar.timegm(self.token_without_user.expires.timetuple())), - }) + self.assertDictEqual( + content, + { + "active": True, + "scope": self.token_without_user.scope, + "client_id": self.token_without_user.application.client_id, + "exp": int(calendar.timegm(self.token_without_user.expires.timetuple())), + }, + ) def test_view_get_valid_token_without_app(self): """ @@ -139,19 +149,21 @@ def test_view_get_valid_token_without_app(self): "HTTP_AUTHORIZATION": "Bearer " + self.resource_server_token.token, } response = self.client.get( - reverse("oauth2_provider:introspect"), - {"token": self.token_without_app.token}, - **auth_headers) + reverse("oauth2_provider:introspect"), {"token": self.token_without_app.token}, **auth_headers + ) self.assertEqual(response.status_code, 200) content = response.json() self.assertIsInstance(content, dict) - self.assertDictEqual(content, { - "active": True, - "scope": self.token_without_app.scope, - "username": self.token_without_app.user.get_username(), - "exp": int(calendar.timegm(self.token_without_app.expires.timetuple())), - }) + self.assertDictEqual( + content, + { + "active": True, + "scope": self.token_without_app.scope, + "username": self.token_without_app.user.get_username(), + "exp": int(calendar.timegm(self.token_without_app.expires.timetuple())), + }, + ) def test_view_get_invalid_token(self): """ @@ -162,16 +174,18 @@ def test_view_get_invalid_token(self): "HTTP_AUTHORIZATION": "Bearer " + self.resource_server_token.token, } response = self.client.get( - reverse("oauth2_provider:introspect"), - {"token": self.invalid_token.token}, - **auth_headers) + reverse("oauth2_provider:introspect"), {"token": self.invalid_token.token}, **auth_headers + ) self.assertEqual(response.status_code, 200) content = response.json() self.assertIsInstance(content, dict) - self.assertDictEqual(content, { - "active": False, - }) + self.assertDictEqual( + content, + { + "active": False, + }, + ) def test_view_get_notexisting_token(self): """ @@ -182,16 +196,18 @@ def test_view_get_notexisting_token(self): "HTTP_AUTHORIZATION": "Bearer " + self.resource_server_token.token, } response = self.client.get( - reverse("oauth2_provider:introspect"), - {"token": "kaudawelsch"}, - **auth_headers) + reverse("oauth2_provider:introspect"), {"token": "kaudawelsch"}, **auth_headers + ) self.assertEqual(response.status_code, 401) content = response.json() self.assertIsInstance(content, dict) - self.assertDictEqual(content, { - "active": False, - }) + self.assertDictEqual( + content, + { + "active": False, + }, + ) def test_view_post_valid_token(self): """ @@ -202,20 +218,22 @@ def test_view_post_valid_token(self): "HTTP_AUTHORIZATION": "Bearer " + self.resource_server_token.token, } response = self.client.post( - reverse("oauth2_provider:introspect"), - {"token": self.valid_token.token}, - **auth_headers) + reverse("oauth2_provider:introspect"), {"token": self.valid_token.token}, **auth_headers + ) self.assertEqual(response.status_code, 200) content = response.json() self.assertIsInstance(content, dict) - self.assertDictEqual(content, { - "active": True, - "scope": self.valid_token.scope, - "client_id": self.valid_token.application.client_id, - "username": self.valid_token.user.get_username(), - "exp": int(calendar.timegm(self.valid_token.expires.timetuple())), - }) + self.assertDictEqual( + content, + { + "active": True, + "scope": self.valid_token.scope, + "client_id": self.valid_token.application.client_id, + "username": self.valid_token.user.get_username(), + "exp": int(calendar.timegm(self.valid_token.expires.timetuple())), + }, + ) def test_view_post_invalid_token(self): """ @@ -226,16 +244,18 @@ def test_view_post_invalid_token(self): "HTTP_AUTHORIZATION": "Bearer " + self.resource_server_token.token, } response = self.client.post( - reverse("oauth2_provider:introspect"), - {"token": self.invalid_token.token}, - **auth_headers) + reverse("oauth2_provider:introspect"), {"token": self.invalid_token.token}, **auth_headers + ) self.assertEqual(response.status_code, 200) content = response.json() self.assertIsInstance(content, dict) - self.assertDictEqual(content, { - "active": False, - }) + self.assertDictEqual( + content, + { + "active": False, + }, + ) def test_view_post_notexisting_token(self): """ @@ -246,13 +266,85 @@ def test_view_post_notexisting_token(self): "HTTP_AUTHORIZATION": "Bearer " + self.resource_server_token.token, } response = self.client.post( - reverse("oauth2_provider:introspect"), - {"token": "kaudawelsch"}, - **auth_headers) + reverse("oauth2_provider:introspect"), {"token": "kaudawelsch"}, **auth_headers + ) self.assertEqual(response.status_code, 401) content = response.json() self.assertIsInstance(content, dict) - self.assertDictEqual(content, { - "active": False, - }) + self.assertDictEqual( + content, + { + "active": False, + }, + ) + + def test_view_post_valid_client_creds_basic_auth(self): + """Test HTTP basic auth working""" + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + response = self.client.post( + reverse("oauth2_provider:introspect"), {"token": self.valid_token.token}, **auth_headers + ) + self.assertEqual(response.status_code, 200) + content = response.json() + self.assertIsInstance(content, dict) + self.assertDictEqual( + content, + { + "active": True, + "scope": self.valid_token.scope, + "client_id": self.valid_token.application.client_id, + "username": self.valid_token.user.get_username(), + "exp": int(calendar.timegm(self.valid_token.expires.timetuple())), + }, + ) + + def test_view_post_invalid_client_creds_basic_auth(self): + """Must fail for invalid client credentials""" + auth_headers = get_basic_auth_header( + self.application.client_id, self.application.client_secret + "_so_wrong" + ) + response = self.client.post( + reverse("oauth2_provider:introspect"), {"token": self.valid_token.token}, **auth_headers + ) + self.assertEqual(response.status_code, 403) + + def test_view_post_valid_client_creds_plaintext(self): + """Test introspecting with credentials in request body""" + response = self.client.post( + reverse("oauth2_provider:introspect"), + { + "token": self.valid_token.token, + "client_id": self.application.client_id, + "client_secret": self.application.client_secret, + }, + ) + self.assertEqual(response.status_code, 200) + content = response.json() + self.assertIsInstance(content, dict) + self.assertDictEqual( + content, + { + "active": True, + "scope": self.valid_token.scope, + "client_id": self.valid_token.application.client_id, + "username": self.valid_token.user.get_username(), + "exp": int(calendar.timegm(self.valid_token.expires.timetuple())), + }, + ) + + def test_view_post_invalid_client_creds_plaintext(self): + """Must fail for invalid creds in request body.""" + response = self.client.post( + reverse("oauth2_provider:introspect"), + { + "token": self.valid_token.token, + "client_id": self.application.client_id, + "client_secret": self.application.client_secret + "_so_wrong", + }, + ) + self.assertEqual(response.status_code, 403) + + def test_select_related_in_view_for_less_db_queries(self): + with self.assertNumQueries(1): + self.client.post(reverse("oauth2_provider:introspect")) diff --git a/tests/test_mixins.py b/tests/test_mixins.py index 79988c9..1294b75 100644 --- a/tests/test_mixins.py +++ b/tests/test_mixins.py @@ -1,4 +1,8 @@ +import logging + +import pytest from django.core.exceptions import ImproperlyConfigured +from django.http import HttpResponse from django.test import RequestFactory, TestCase from django.views.generic import View from oauthlib.oauth2 import Server @@ -6,44 +10,73 @@ from oauth2_provider.oauth2_backends import OAuthLibCore from oauth2_provider.oauth2_validators import OAuth2Validator from oauth2_provider.views.mixins import ( - OAuthLibMixin, ProtectedResourceMixin, ScopedResourceMixin + OAuthLibMixin, + OIDCOnlyMixin, + ProtectedResourceMixin, + ScopedResourceMixin, ) +from . import presets + +@pytest.mark.usefixtures("oauth2_settings") class BaseTest(TestCase): @classmethod def setUpClass(cls): cls.request_factory = RequestFactory() - super(BaseTest, cls).setUpClass() + super().setUpClass() class TestOAuthLibMixin(BaseTest): - def test_missing_oauthlib_backend_class(self): + def test_missing_oauthlib_backend_class_uses_fallback(self): + class CustomOauthLibBackend: + def __init__(self, *args, **kwargs): + pass + + self.oauth2_settings.OAUTH2_BACKEND_CLASS = CustomOauthLibBackend + class TestView(OAuthLibMixin, View): server_class = Server validator_class = OAuth2Validator test_view = TestView() - self.assertRaises(ImproperlyConfigured, test_view.get_oauthlib_backend_class) + self.assertEqual(CustomOauthLibBackend, test_view.get_oauthlib_backend_class()) + core = test_view.get_oauthlib_core() + self.assertTrue(isinstance(core, CustomOauthLibBackend)) + + def test_missing_server_class_uses_fallback(self): + class CustomServer: + def __init__(self, *args, **kwargs): + pass + + self.oauth2_settings.OAUTH2_SERVER_CLASS = CustomServer - def test_missing_server_class(self): class TestView(OAuthLibMixin, View): validator_class = OAuth2Validator oauthlib_backend_class = OAuthLibCore test_view = TestView() - self.assertRaises(ImproperlyConfigured, test_view.get_server) + self.assertEqual(CustomServer, test_view.get_server_class()) + core = test_view.get_oauthlib_core() + self.assertTrue(isinstance(core.server, CustomServer)) + + def test_missing_validator_class_uses_fallback(self): + class CustomValidator: + pass + + self.oauth2_settings.OAUTH2_VALIDATOR_CLASS = CustomValidator - def test_missing_validator_class(self): class TestView(OAuthLibMixin, View): server_class = Server oauthlib_backend_class = OAuthLibCore test_view = TestView() - self.assertRaises(ImproperlyConfigured, test_view.get_server) + self.assertEqual(CustomValidator, test_view.get_validator_class()) + core = test_view.get_oauthlib_core() + self.assertTrue(isinstance(core.server.request_validator, CustomValidator)) def test_correct_server(self): class TestView(OAuthLibMixin, View): @@ -58,7 +91,7 @@ class TestView(OAuthLibMixin, View): self.assertIsInstance(test_view.get_server(), Server) def test_custom_backend(self): - class AnotherOauthLibBackend(object): + class AnotherOauthLibBackend: pass class TestView(OAuthLibMixin, View): @@ -70,9 +103,7 @@ class TestView(OAuthLibMixin, View): request.user = "fake" test_view = TestView() - self.assertEqual( - test_view.get_oauthlib_backend_class(), AnotherOauthLibBackend - ) + self.assertEqual(test_view.get_oauthlib_backend_class(), AnotherOauthLibBackend) class TestScopedResourceMixin(BaseTest): @@ -103,3 +134,38 @@ class TestView(ProtectedResourceMixin, View): view = TestView.as_view() response = view(request) self.assertEqual(response.status_code, 200) + + +@pytest.fixture +def oidc_only_view(): + class TView(OIDCOnlyMixin, View): + def get(self, *args, **kwargs): + return HttpResponse("OK") + + return TView.as_view() + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_oidc_only_mixin_oidc_enabled(oauth2_settings, rf, oidc_only_view): + assert oauth2_settings.OIDC_ENABLED + rsp = oidc_only_view(rf.get("/")) + assert rsp.status_code == 200 + assert rsp.content.decode("utf-8") == "OK" + + +def test_oidc_only_mixin_oidc_disabled_debug(oauth2_settings, rf, settings, oidc_only_view): + assert oauth2_settings.OIDC_ENABLED is False + settings.DEBUG = True + with pytest.raises(ImproperlyConfigured) as exc: + oidc_only_view(rf.get("/")) + assert "OIDC views are not enabled" in str(exc.value) + + +def test_oidc_only_mixin_oidc_disabled_no_debug(oauth2_settings, rf, settings, oidc_only_view, caplog): + assert oauth2_settings.OIDC_ENABLED is False + settings.DEBUG = False + with caplog.at_level(logging.WARNING, logger="oauth2_provider"): + rsp = oidc_only_view(rf.get("/")) + assert rsp.status_code == 404 + assert len(caplog.records) == 1 + assert "OIDC views are not enabled" in caplog.records[0].message diff --git a/tests/test_models.py b/tests/test_models.py index 95e8eb4..7b37486 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -6,10 +6,15 @@ from django.utils import timezone from oauth2_provider.models import ( - clear_expired, get_access_token_model, get_application_model, - get_grant_model, get_refresh_token_model + clear_expired, + get_access_token_model, + get_application_model, + get_grant_model, + get_id_token_model, + get_refresh_token_model, ) -from oauth2_provider.settings import oauth2_settings + +from . import presets Application = get_application_model() @@ -17,13 +22,18 @@ AccessToken = get_access_token_model() RefreshToken = get_refresh_token_model() UserModel = get_user_model() +IDToken = get_id_token_model() -class TestModels(TestCase): - +class BaseTestModels(TestCase): def setUp(self): self.user = UserModel.objects.create_user("test_user", "test@example.com", "123456") + def tearDown(self): + self.user.delete() + + +class TestModels(BaseTestModels): def test_allow_scopes(self): self.client.login(username="test_user", password="123456") app = Application.objects.create( @@ -34,13 +44,7 @@ def test_allow_scopes(self): authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, ) - access_token = AccessToken( - user=self.user, - scope="read write", - expires=0, - token="", - application=app - ) + access_token = AccessToken(user=self.user, scope="read write", expires=0, token="", application=app) self.assertTrue(access_token.allow_scopes(["read", "write"])) self.assertTrue(access_token.allow_scopes(["write", "read"])) @@ -93,21 +97,9 @@ def test_scopes_property(self): authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, ) - access_token = AccessToken( - user=self.user, - scope="read write", - expires=0, - token="", - application=app - ) + access_token = AccessToken(user=self.user, scope="read write", expires=0, token="", application=app) - access_token2 = AccessToken( - user=self.user, - scope="write", - expires=0, - token="", - application=app - ) + access_token2 = AccessToken(user=self.user, scope="write", expires=0, token="", application=app) self.assertEqual(access_token.scopes, {"read": "Reading scope", "write": "Writing scope"}) self.assertEqual(access_token2.scopes, {"write": "Writing scope"}) @@ -117,13 +109,10 @@ def test_scopes_property(self): OAUTH2_PROVIDER_APPLICATION_MODEL="tests.SampleApplication", OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL="tests.SampleAccessToken", OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL="tests.SampleRefreshToken", - OAUTH2_PROVIDER_GRANT_MODEL="tests.SampleGrant" + OAUTH2_PROVIDER_GRANT_MODEL="tests.SampleGrant", ) -class TestCustomModels(TestCase): - - def setUp(self): - self.user = UserModel.objects.create_user("test_user", "test@example.com", "123456") - +@pytest.mark.usefixtures("oauth2_settings") +class TestCustomModels(BaseTestModels): def test_custom_application_model(self): """ If a custom application model is installed, it should be present in @@ -132,7 +121,8 @@ def test_custom_application_model(self): See issue #90 (https://github.com/jazzband/django-oauth-toolkit/issues/90) """ related_object_names = [ - f.name for f in UserModel._meta.get_fields() + f.name + for f in UserModel._meta.get_fields() if (f.one_to_many or f.one_to_one) and f.auto_created and not f.concrete ] self.assertNotIn("oauth2_provider:application", related_object_names) @@ -140,22 +130,16 @@ def test_custom_application_model(self): def test_custom_application_model_incorrect_format(self): # Patch oauth2 settings to use a custom Application model - oauth2_settings.APPLICATION_MODEL = "IncorrectApplicationFormat" + self.oauth2_settings.APPLICATION_MODEL = "IncorrectApplicationFormat" self.assertRaises(ValueError, get_application_model) - # Revert oauth2 settings - oauth2_settings.APPLICATION_MODEL = "oauth2_provider.Application" - def test_custom_application_model_not_installed(self): # Patch oauth2 settings to use a custom Application model - oauth2_settings.APPLICATION_MODEL = "tests.ApplicationNotInstalled" + self.oauth2_settings.APPLICATION_MODEL = "tests.ApplicationNotInstalled" self.assertRaises(LookupError, get_application_model) - # Revert oauth2 settings - oauth2_settings.APPLICATION_MODEL = "oauth2_provider.Application" - def test_custom_access_token_model(self): """ If a custom access token model is installed, it should be present in @@ -163,7 +147,8 @@ def test_custom_access_token_model(self): """ # Django internals caches the related objects. related_object_names = [ - f.name for f in UserModel._meta.get_fields() + f.name + for f in UserModel._meta.get_fields() if (f.one_to_many or f.one_to_one) and f.auto_created and not f.concrete ] self.assertNotIn("oauth2_provider:access_token", related_object_names) @@ -171,22 +156,16 @@ def test_custom_access_token_model(self): def test_custom_access_token_model_incorrect_format(self): # Patch oauth2 settings to use a custom AccessToken model - oauth2_settings.ACCESS_TOKEN_MODEL = "IncorrectAccessTokenFormat" + self.oauth2_settings.ACCESS_TOKEN_MODEL = "IncorrectAccessTokenFormat" self.assertRaises(ValueError, get_access_token_model) - # Revert oauth2 settings - oauth2_settings.ACCESS_TOKEN_MODEL = "oauth2_provider.AccessToken" - def test_custom_access_token_model_not_installed(self): # Patch oauth2 settings to use a custom AccessToken model - oauth2_settings.ACCESS_TOKEN_MODEL = "tests.AccessTokenNotInstalled" + self.oauth2_settings.ACCESS_TOKEN_MODEL = "tests.AccessTokenNotInstalled" self.assertRaises(LookupError, get_access_token_model) - # Revert oauth2 settings - oauth2_settings.ACCESS_TOKEN_MODEL = "oauth2_provider.AccessToken" - def test_custom_refresh_token_model(self): """ If a custom refresh token model is installed, it should be present in @@ -194,7 +173,8 @@ def test_custom_refresh_token_model(self): """ # Django internals caches the related objects. related_object_names = [ - f.name for f in UserModel._meta.get_fields() + f.name + for f in UserModel._meta.get_fields() if (f.one_to_many or f.one_to_one) and f.auto_created and not f.concrete ] self.assertNotIn("oauth2_provider:refresh_token", related_object_names) @@ -202,22 +182,16 @@ def test_custom_refresh_token_model(self): def test_custom_refresh_token_model_incorrect_format(self): # Patch oauth2 settings to use a custom RefreshToken model - oauth2_settings.REFRESH_TOKEN_MODEL = "IncorrectRefreshTokenFormat" + self.oauth2_settings.REFRESH_TOKEN_MODEL = "IncorrectRefreshTokenFormat" self.assertRaises(ValueError, get_refresh_token_model) - # Revert oauth2 settings - oauth2_settings.REFRESH_TOKEN_MODEL = "oauth2_provider.RefreshToken" - def test_custom_refresh_token_model_not_installed(self): # Patch oauth2 settings to use a custom AccessToken model - oauth2_settings.REFRESH_TOKEN_MODEL = "tests.RefreshTokenNotInstalled" + self.oauth2_settings.REFRESH_TOKEN_MODEL = "tests.RefreshTokenNotInstalled" self.assertRaises(LookupError, get_refresh_token_model) - # Revert oauth2 settings - oauth2_settings.REFRESH_TOKEN_MODEL = "oauth2_provider.RefreshToken" - def test_custom_grant_model(self): """ If a custom grant model is installed, it should be present in @@ -225,7 +199,8 @@ def test_custom_grant_model(self): """ # Django internals caches the related objects. related_object_names = [ - f.name for f in UserModel._meta.get_fields() + f.name + for f in UserModel._meta.get_fields() if (f.one_to_many or f.one_to_one) and f.auto_created and not f.concrete ] self.assertNotIn("oauth2_provider:grant", related_object_names) @@ -233,24 +208,31 @@ def test_custom_grant_model(self): def test_custom_grant_model_incorrect_format(self): # Patch oauth2 settings to use a custom Grant model - oauth2_settings.GRANT_MODEL = "IncorrectGrantFormat" + self.oauth2_settings.GRANT_MODEL = "IncorrectGrantFormat" self.assertRaises(ValueError, get_grant_model) - # Revert oauth2 settings - oauth2_settings.GRANT_MODEL = "oauth2_provider.Grant" - def test_custom_grant_model_not_installed(self): # Patch oauth2 settings to use a custom AccessToken model - oauth2_settings.GRANT_MODEL = "tests.GrantNotInstalled" + self.oauth2_settings.GRANT_MODEL = "tests.GrantNotInstalled" self.assertRaises(LookupError, get_grant_model) - # Revert oauth2 settings - oauth2_settings.GRANT_MODEL = "oauth2_provider.Grant" +class TestGrantModel(BaseTestModels): + def setUp(self): + super().setUp() + self.application = Application.objects.create( + name="Test Application", + redirect_uris="", + user=self.user, + client_type=Application.CLIENT_CONFIDENTIAL, + authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, + ) -class TestGrantModel(TestCase): + def tearDown(self): + self.application.delete() + super().tearDown() def test_str(self): grant = Grant(code="test_code") @@ -261,12 +243,26 @@ def test_expires_can_be_none(self): self.assertIsNone(grant.expires) self.assertTrue(grant.is_expired()) + def test_redirect_uri_can_be_longer_than_255_chars(self): + long_redirect_uri = "http://example.com/{}".format("authorized/" * 25) + self.assertTrue(len(long_redirect_uri) > 255) + grant = Grant.objects.create( + user=self.user, + code="test_code", + application=self.application, + expires=timezone.now(), + redirect_uri=long_redirect_uri, + scope="", + ) + grant.refresh_from_db() -class TestAccessTokenModel(TestCase): + # It would be necessary to run test using another DB engine than sqlite + # that transform varchar(255) into text data type. + # https://sqlite.org/datatype3.html#affinity_name_examples + self.assertEqual(grant.redirect_uri, long_redirect_uri) - def setUp(self): - self.user = UserModel.objects.create_user("test_user", "test@example.com", "123456") +class TestAccessTokenModel(BaseTestModels): def test_str(self): access_token = AccessToken(token="test_token") self.assertEqual("%s" % access_token, access_token.token) @@ -288,17 +284,16 @@ def test_expires_can_be_none(self): self.assertTrue(access_token.is_expired()) -class TestRefreshTokenModel(TestCase): - +class TestRefreshTokenModel(BaseTestModels): def test_str(self): refresh_token = RefreshToken(token="test_token") self.assertEqual("%s" % refresh_token, refresh_token.token) -class TestClearExpired(TestCase): - +@pytest.mark.usefixtures("oauth2_settings") +class TestClearExpired(BaseTestModels): def setUp(self): - self.user = UserModel.objects.create_user("test_user", "test@example.com", "123456") + super().setUp() # Insert two tokens on database. app = Application.objects.create( name="test_app", @@ -315,7 +310,7 @@ def setUp(self): user=self.user, created=timezone.now(), updated=timezone.now(), - ) + ) AccessToken.objects.create( token="666", expires=timezone.now(), @@ -324,14 +319,14 @@ def setUp(self): user=self.user, created=timezone.now(), updated=timezone.now(), - ) + ) def test_clear_expired_tokens(self): - oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = 60 + self.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = 60 assert clear_expired() is None def test_clear_expired_tokens_incorect_timetype(self): - oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = "A" + self.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = "A" with pytest.raises(ImproperlyConfigured) as excinfo: clear_expired() result = excinfo.value.__class__.__name__ @@ -339,7 +334,7 @@ def test_clear_expired_tokens_incorect_timetype(self): def test_clear_expired_tokens_with_tokens(self): self.client.login(username="test_user", password="123456") - oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = 0 + self.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = 0 ttokens = AccessToken.objects.count() expiredt = AccessToken.objects.filter(expires__lte=timezone.now()).count() assert ttokens == 2 @@ -347,3 +342,93 @@ def test_clear_expired_tokens_with_tokens(self): clear_expired() expiredt = AccessToken.objects.filter(expires__lte=timezone.now()).count() assert expiredt == 0 + + +@pytest.mark.django_db +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_id_token_methods(oidc_tokens, rf): + id_token = IDToken.objects.get() + + # Token was just created, so should be valid + assert id_token.is_valid() + + # if expires is None, it should always be expired + # the column is NOT NULL, but could be NULL in sub-classes + id_token.expires = None + assert id_token.is_expired() + + # if no scopes are passed, they should be valid + assert id_token.allow_scopes(None) + + # if the requested scopes are in the token, they should be valid + assert id_token.allow_scopes(["openid"]) + + # if the requested scopes are not in the token, they should not be valid + assert id_token.allow_scopes(["fizzbuzz"]) is False + + # we should be able to get a list of the scopes on the token + assert id_token.scopes == {"openid": "OpenID connect"} + + # the id token should stringify as the JWT token + id_token_str = str(id_token) + assert str(id_token.jti) in id_token_str + assert id_token_str.endswith(str(id_token.user_id)) + + # revoking the token should delete it + id_token.revoke() + assert IDToken.objects.filter(jti=id_token.jti).count() == 0 + + +@pytest.mark.django_db +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_application_key(oauth2_settings, application): + # RS256 key + key = application.jwk_key + assert key.key_type == "RSA" + + # RS256 key, but not configured + oauth2_settings.OIDC_RSA_PRIVATE_KEY = None + with pytest.raises(ImproperlyConfigured) as exc: + application.jwk_key + assert "You must set OIDC_RSA_PRIVATE_KEY" in str(exc.value) + + # HS256 key + application.algorithm = Application.HS256_ALGORITHM + key = application.jwk_key + assert key.key_type == "oct" + + # No algorithm + application.algorithm = Application.NO_ALGORITHM + with pytest.raises(ImproperlyConfigured) as exc: + application.jwk_key + assert "This application does not support signed tokens" == str(exc.value) + + +@pytest.mark.django_db +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_application_clean(oauth2_settings, application): + # RS256, RSA key is configured + application.clean() + + # RS256, RSA key is not configured + oauth2_settings.OIDC_RSA_PRIVATE_KEY = None + with pytest.raises(ValidationError) as exc: + application.clean() + assert "You must set OIDC_RSA_PRIVATE_KEY" in str(exc.value) + + # HS256 algorithm, auth code + confidential -> allowed + application.algorithm = Application.HS256_ALGORITHM + application.clean() + + # HS256, auth code + public -> forbidden + application.client_type = Application.CLIENT_PUBLIC + with pytest.raises(ValidationError) as exc: + application.clean() + assert "You cannot use HS256" in str(exc.value) + + # HS256, hybrid + confidential -> forbidden + application.client_type = Application.CLIENT_CONFIDENTIAL + application.authorization_grant_type = Application.GRANT_OPENID_HYBRID + with pytest.raises(ValidationError) as exc: + application.clean() + assert "You cannot use HS256" in str(exc.value) diff --git a/tests/test_oauth2_backends.py b/tests/test_oauth2_backends.py index 0d98dad..acff2ca 100644 --- a/tests/test_oauth2_backends.py +++ b/tests/test_oauth2_backends.py @@ -1,8 +1,10 @@ import json +import pytest from django.test import RequestFactory, TestCase from oauth2_provider.backends import get_oauthlib_core +from oauth2_provider.models import redirect_to_uri_allowed from oauth2_provider.oauth2_backends import JSONOAuthLibCore, OAuthLibCore @@ -12,16 +14,16 @@ import mock +@pytest.mark.usefixtures("oauth2_settings") class TestOAuthLibCoreBackend(TestCase): - def setUp(self): self.factory = RequestFactory() self.oauthlib_core = OAuthLibCore() def test_swappable_server_class(self): - with mock.patch("oauth2_provider.oauth2_backends.oauth2_settings.OAUTH2_SERVER_CLASS"): - oauthlib_core = OAuthLibCore() - self.assertTrue(isinstance(oauthlib_core.server, mock.MagicMock)) + self.oauth2_settings.OAUTH2_SERVER_CLASS = mock.MagicMock + oauthlib_core = OAuthLibCore() + self.assertTrue(isinstance(oauthlib_core.server, mock.MagicMock)) def test_form_urlencoded_extract_params(self): payload = "grant_type=password&username=john&password=123456" @@ -33,11 +35,13 @@ def test_form_urlencoded_extract_params(self): self.assertIn("password=123456", body) def test_application_json_extract_params(self): - payload = json.dumps({ - "grant_type": "password", - "username": "john", - "password": "123456", - }) + payload = json.dumps( + { + "grant_type": "password", + "username": "john", + "password": "123456", + } + ) request = self.factory.post("/o/token/", payload, content_type="application/json") uri, http_method, body, headers = self.oauthlib_core._extract_params(request) @@ -51,6 +55,7 @@ class TestCustomOAuthLibCoreBackend(TestCase): Tests that the public API behaves as expected when we override the OAuthLibCoreBackend core methods. """ + class MyOAuthLibCore(OAuthLibCore): def _get_extra_credentials(self, request): return 1 @@ -65,9 +70,7 @@ def test_create_token_response_gets_extra_credentials(self): payload = "grant_type=password&username=john&password=123456" request = self.factory.post("/o/token/", payload, content_type="application/x-www-form-urlencoded") - with mock.patch( - "oauthlib.openid.connect.core.endpoints.pre_configured.Server.create_token_response" - ) as create_token_response: + with mock.patch("oauthlib.oauth2.Server.create_token_response") as create_token_response: mocked = mock.MagicMock() create_token_response.return_value = mocked, mocked, mocked core = self.MyOAuthLibCore() @@ -81,11 +84,13 @@ def setUp(self): self.oauthlib_core = JSONOAuthLibCore() def test_application_json_extract_params(self): - payload = json.dumps({ - "grant_type": "password", - "username": "john", - "password": "123456", - }) + payload = json.dumps( + { + "grant_type": "password", + "username": "john", + "password": "123456", + } + ) request = self.factory.post("/o/token/", payload, content_type="application/json") uri, http_method, body, headers = self.oauthlib_core._extract_params(request) @@ -106,3 +111,23 @@ def test_validate_authorization_request_unsafe_query(self): oauthlib_core = get_oauthlib_core() oauthlib_core.verify_request(request, scopes=[]) + + +@pytest.mark.parametrize( + "uri, expected_result", + # localhost is _not_ a loopback URI + [ + ("http://localhost:3456", False), + # only http scheme is supported for loopback URIs + ("https://127.0.0.1:3456", False), + ("http://127.0.0.1:3456", True), + ("http://[::1]", True), + ("http://[::1]:34", True), + ], +) +def test_uri_loopback_redirect_check(uri, expected_result): + allowed_uris = ["http://127.0.0.1", "http://[::1]"] + if expected_result: + assert redirect_to_uri_allowed(uri, allowed_uris) + else: + assert not redirect_to_uri_allowed(uri, allowed_uris) diff --git a/tests/test_oauth2_validators.py b/tests/test_oauth2_validators.py index d924823..7997d3b 100644 --- a/tests/test_oauth2_validators.py +++ b/tests/test_oauth2_validators.py @@ -1,17 +1,21 @@ import contextlib import datetime +import json +import pytest from django.contrib.auth import get_user_model -from django.test import TransactionTestCase +from django.test import TestCase, TransactionTestCase from django.utils import timezone +from jwcrypto import jwt from oauthlib.common import Request from oauth2_provider.exceptions import FatalClientError -from oauth2_provider.models import ( - get_access_token_model, get_application_model, get_refresh_token_model -) +from oauth2_provider.models import get_access_token_model, get_application_model, get_refresh_token_model +from oauth2_provider.oauth2_backends import get_oauthlib_core from oauth2_provider.oauth2_validators import OAuth2Validator +from . import presets + try: from unittest import mock @@ -46,8 +50,12 @@ def setUp(self): self.request.grant_type = "not client" self.validator = OAuth2Validator() self.application = Application.objects.create( - client_id="client_id", client_secret="client_secret", user=self.user, - client_type=Application.CLIENT_PUBLIC, authorization_grant_type=Application.GRANT_PASSWORD) + client_id="client_id", + client_secret="client_secret", + user=self.user, + client_type=Application.CLIENT_PUBLIC, + authorization_grant_type=Application.GRANT_PASSWORD, + ) self.request.client = self.application def tearDown(self): @@ -163,13 +171,10 @@ def test_save_bearer_token__with_existing_tokens__does_not_create_new_tokens(sel token="123", user=self.user, expires=timezone.now() + datetime.timedelta(seconds=60), - application=self.application + application=self.application, ) refresh_token = RefreshToken.objects.create( - access_token=access_token, - token="abc", - user=self.user, - application=self.application + access_token=access_token, token="abc", user=self.user, application=self.application ) self.request.refresh_token_instance = refresh_token token = { @@ -196,13 +201,10 @@ def test_save_bearer_token__checks_to_rotate_tokens(self): token="123", user=self.user, expires=timezone.now() + datetime.timedelta(seconds=60), - application=self.application + application=self.application, ) refresh_token = RefreshToken.objects.create( - access_token=access_token, - token="abc", - user=self.user, - application=self.application + access_token=access_token, token="abc", user=self.user, application=self.application ) self.request.refresh_token_instance = refresh_token token = { @@ -234,13 +236,10 @@ def test_save_bearer_token__with_new_token_equal_to_existing_token__revokes_old_ token="123", user=self.user, expires=timezone.now() + datetime.timedelta(seconds=60), - application=self.application + application=self.application, ) refresh_token = RefreshToken.objects.create( - access_token=access_token, - token="abc", - user=self.user, - application=self.application + access_token=access_token, token="abc", user=self.user, application=self.application ) self.request.refresh_token_instance = refresh_token @@ -318,7 +317,9 @@ class TestOAuth2ValidatorProvidesErrorData(TransactionTestCase): def setUp(self): self.user = UserModel.objects.create_user( - "user", "test@example.com", "123456", + "user", + "test@example.com", + "123456", ) self.request = mock.MagicMock(wraps=Request) self.request.user = self.user @@ -340,13 +341,20 @@ def test_validate_bearer_token_does_not_add_error_when_no_token_is_provided(self def test_validate_bearer_token_adds_error_to_the_request_when_an_invalid_token_is_provided(self): access_token = mock.MagicMock(token="some_invalid_token") - self.assertFalse(self.validator.validate_bearer_token( - access_token.token, [], self.request, - )) - self.assertDictEqual(self.request.oauth2_error, { - "error": "invalid_token", - "error_description": "The access token is invalid.", - }) + self.assertFalse( + self.validator.validate_bearer_token( + access_token.token, + [], + self.request, + ) + ) + self.assertDictEqual( + self.request.oauth2_error, + { + "error": "invalid_token", + "error_description": "The access token is invalid.", + }, + ) def test_validate_bearer_token_adds_error_to_the_request_when_an_expired_token_is_provided(self): access_token = AccessToken.objects.create( @@ -355,13 +363,20 @@ def test_validate_bearer_token_adds_error_to_the_request_when_an_expired_token_i expires=timezone.now() - datetime.timedelta(seconds=1), application=self.application, ) - self.assertFalse(self.validator.validate_bearer_token( - access_token.token, [], self.request, - )) - self.assertDictEqual(self.request.oauth2_error, { - "error": "invalid_token", - "error_description": "The access token has expired.", - }) + self.assertFalse( + self.validator.validate_bearer_token( + access_token.token, + [], + self.request, + ) + ) + self.assertDictEqual( + self.request.oauth2_error, + { + "error": "invalid_token", + "error_description": "The access token has expired.", + }, + ) def test_validate_bearer_token_adds_error_to_the_request_when_a_valid_token_has_insufficient_scope(self): access_token = AccessToken.objects.create( @@ -370,13 +385,20 @@ def test_validate_bearer_token_adds_error_to_the_request_when_a_valid_token_has_ expires=timezone.now() + datetime.timedelta(seconds=1), application=self.application, ) - self.assertFalse(self.validator.validate_bearer_token( - access_token.token, ["some_extra_scope"], self.request, - )) - self.assertDictEqual(self.request.oauth2_error, { - "error": "insufficient_scope", - "error_description": "The access token is valid but does not have enough scope.", - }) + self.assertFalse( + self.validator.validate_bearer_token( + access_token.token, + ["some_extra_scope"], + self.request, + ) + ) + self.assertDictEqual( + self.request.oauth2_error, + { + "error": "insufficient_scope", + "error_description": "The access token is valid but does not have enough scope.", + }, + ) def test_validate_bearer_token_adds_error_to_the_request_when_a_invalid_custom_token_is_provided(self): access_token = AccessToken.objects.create( @@ -386,9 +408,115 @@ def test_validate_bearer_token_adds_error_to_the_request_when_a_invalid_custom_t application=self.application, ) with always_invalid_token(): - self.assertFalse(self.validator.validate_bearer_token( - access_token.token, [], self.request, - )) - self.assertDictEqual(self.request.oauth2_error, { - "error": "invalid_token", - }) + self.assertFalse( + self.validator.validate_bearer_token( + access_token.token, + [], + self.request, + ) + ) + self.assertDictEqual( + self.request.oauth2_error, + { + "error": "invalid_token", + }, + ) + + +class TestOAuth2ValidatorErrorResourceToken(TestCase): + """The following tests check logger information when response from oauth2 + is unsuccessful. + """ + + def setUp(self): + self.token = "test_token" + self.introspection_url = "http://example.com/token/introspection/" + self.introspection_token = "test_introspection_token" + self.validator = OAuth2Validator() + + def test_response_when_auth_server_response_return_404(self): + with self.assertLogs(logger="oauth2_provider") as mock_log: + self.validator._get_token_from_authentication_server( + self.token, self.introspection_url, self.introspection_token, None + ) + self.assertIn( + "ERROR:oauth2_provider:Introspection: Failed to " + "get a valid response from authentication server. " + "Status code: 404, Reason: " + "Not Found.\nNoneType: None", + mock_log.output, + ) + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_oidc_endpoint_generation(oauth2_settings, rf): + oauth2_settings.OIDC_ISS_ENDPOINT = "" + django_request = rf.get("/") + request = Request("/", headers=django_request.META) + validator = OAuth2Validator() + oidc_issuer_endpoint = validator.get_oidc_issuer_endpoint(request) + assert oidc_issuer_endpoint == "http://testserver/o" + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_oidc_endpoint_generation_ssl(oauth2_settings, rf, settings): + oauth2_settings.OIDC_ISS_ENDPOINT = "" + django_request = rf.get("/", secure=True) + # Calling the settings method with a django https request should generate a https url + oidc_issuer_endpoint = oauth2_settings.oidc_issuer(django_request) + assert oidc_issuer_endpoint == "https://testserver/o" + + # Should also work with an oauthlib request (via validator) + core = get_oauthlib_core() + uri, http_method, body, headers = core._extract_params(django_request) + request = Request(uri=uri, http_method=http_method, body=body, headers=headers) + validator = OAuth2Validator() + oidc_issuer_endpoint = validator.get_oidc_issuer_endpoint(request) + assert oidc_issuer_endpoint == "https://testserver/o" + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_get_jwt_bearer_token(oauth2_settings, mocker): + # oauthlib instructs us to make get_jwt_bearer_token call get_id_token + request = mocker.MagicMock(wraps=Request) + validator = OAuth2Validator() + mock_get_id_token = mocker.patch.object(validator, "get_id_token") + validator.get_jwt_bearer_token(None, None, request) + assert mock_get_id_token.call_count == 1 + assert mock_get_id_token.call_args[0] == (None, None, request) + assert mock_get_id_token.call_args[1] == {} + + +@pytest.mark.django_db +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_validate_id_token_expired_jwt(oauth2_settings, mocker, oidc_tokens): + mocker.patch("oauth2_provider.oauth2_validators.jwt.JWT", side_effect=jwt.JWTExpired) + validator = OAuth2Validator() + status = validator.validate_id_token(oidc_tokens.id_token, ["openid"], mocker.sentinel.request) + assert status is False + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_validate_id_token_no_token(oauth2_settings, mocker): + validator = OAuth2Validator() + status = validator.validate_id_token("", ["openid"], mocker.sentinel.request) + assert status is False + + +@pytest.mark.django_db +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_validate_id_token_app_removed(oauth2_settings, mocker, oidc_tokens): + oidc_tokens.application.delete() + validator = OAuth2Validator() + status = validator.validate_id_token(oidc_tokens.id_token, ["openid"], mocker.sentinel.request) + assert status is False + + +@pytest.mark.django_db +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_validate_id_token_bad_token_no_aud(oauth2_settings, mocker, oidc_key): + token = jwt.JWT(header=json.dumps({"alg": "RS256"}), claims=json.dumps({"bad": "token"})) + token.make_signed_token(oidc_key) + validator = OAuth2Validator() + status = validator.validate_id_token(token.serialize(), ["openid"], mocker.sentinel.request) + assert status is False diff --git a/tests/test_oidc_views.py b/tests/test_oidc_views.py index 43e46d2..5cbae54 100644 --- a/tests/test_oidc_views.py +++ b/tests/test_oidc_views.py @@ -1,17 +1,48 @@ -from __future__ import unicode_literals - +import pytest from django.test import TestCase from django.urls import reverse +from oauth2_provider.oauth2_validators import OAuth2Validator + +from . import presets + +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) class TestConnectDiscoveryInfoView(TestCase): def test_get_connect_discovery_info(self): expected_response = { - "issuer": "http://localhost", + "issuer": "http://localhost/o", "authorization_endpoint": "http://localhost/o/authorize/", "token_endpoint": "http://localhost/o/token/", - "userinfo_endpoint": "http://localhost/userinfo/", - "jwks_uri": "http://localhost/o/jwks/", + "userinfo_endpoint": "http://localhost/o/userinfo/", + "jwks_uri": "http://localhost/o/.well-known/jwks.json", + "response_types_supported": [ + "code", + "token", + "id_token", + "id_token token", + "code token", + "code id_token", + "code id_token token", + ], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256", "HS256"], + "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"], + } + response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info")) + self.assertEqual(response.status_code, 200) + assert response.json() == expected_response + + def test_get_connect_discovery_info_without_issuer_url(self): + self.oauth2_settings.OIDC_ISS_ENDPOINT = None + self.oauth2_settings.OIDC_USERINFO_ENDPOINT = None + expected_response = { + "issuer": "http://testserver/o", + "authorization_endpoint": "http://testserver/o/authorize/", + "token_endpoint": "http://testserver/o/token/", + "userinfo_endpoint": "http://testserver/o/userinfo/", + "jwks_uri": "http://testserver/o/.well-known/jwks.json", "response_types_supported": [ "code", "token", @@ -19,29 +50,90 @@ def test_get_connect_discovery_info(self): "id_token token", "code token", "code id_token", - "code id_token token" + "code id_token token", ], "subject_types_supported": ["public"], "id_token_signing_alg_values_supported": ["RS256", "HS256"], - "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"] + "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"], } response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info")) self.assertEqual(response.status_code, 200) assert response.json() == expected_response + def test_get_connect_discovery_info_without_rsa_key(self): + self.oauth2_settings.OIDC_RSA_PRIVATE_KEY = None + response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info")) + self.assertEqual(response.status_code, 200) + assert response.json()["id_token_signing_alg_values_supported"] == ["HS256"] + +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) class TestJwksInfoView(TestCase): def test_get_jwks_info(self): expected_response = { - "keys": [{ - "alg": "RS256", - "use": "sig", - "kid": "s4a1o8mFEd1tATAIH96caMlu4hOxzBUaI2QTqbYNBHs", - "e": "AQAB", - "kty": "RSA", - "n": "mwmIeYdjZkLgalTuhvvwjvnB5vVQc7G9DHgOm20Hw524bLVTk49IXJ2Scw42HOmowWWX-oMVT_ca3ZvVIeffVSN1-TxVy2zB65s0wDMwhiMoPv35z9IKHGMZgl9vlyso_2b7daVF_FQDdgIayUn8TQylBxEU1RFfW0QSYOBdAt8" # noqa - }] + "keys": [ + { + "alg": "RS256", + "use": "sig", + "kid": "s4a1o8mFEd1tATAIH96caMlu4hOxzBUaI2QTqbYNBHs", + "e": "AQAB", + "kty": "RSA", + "n": "mwmIeYdjZkLgalTuhvvwjvnB5vVQc7G9DHgOm20Hw524bLVTk49IXJ2Scw42HOmowWWX-oMVT_ca3ZvVIeffVSN1-TxVy2zB65s0wDMwhiMoPv35z9IKHGMZgl9vlyso_2b7daVF_FQDdgIayUn8TQylBxEU1RFfW0QSYOBdAt8", # noqa + } + ] } response = self.client.get(reverse("oauth2_provider:jwks-info")) self.assertEqual(response.status_code, 200) assert response.json() == expected_response + + def test_get_jwks_info_no_rsa_key(self): + self.oauth2_settings.OIDC_RSA_PRIVATE_KEY = None + response = self.client.get(reverse("oauth2_provider:jwks-info")) + self.assertEqual(response.status_code, 200) + assert response.json() == {"keys": []} + + +@pytest.mark.django_db +@pytest.mark.parametrize("method", ["get", "post"]) +def test_userinfo_endpoint(oidc_tokens, client, method): + auth_header = "Bearer %s" % oidc_tokens.access_token + rsp = getattr(client, method)( + reverse("oauth2_provider:user-info"), + HTTP_AUTHORIZATION=auth_header, + ) + data = rsp.json() + assert "sub" in data + assert data["sub"] == str(oidc_tokens.user.pk) + + +@pytest.mark.django_db +def test_userinfo_endpoint_bad_token(oidc_tokens, client): + # No access token + rsp = client.get(reverse("oauth2_provider:user-info")) + assert rsp.status_code == 401 + # Bad access token + rsp = client.get( + reverse("oauth2_provider:user-info"), + HTTP_AUTHORIZATION="Bearer not-a-real-token", + ) + assert rsp.status_code == 401 + + +@pytest.mark.django_db +def test_userinfo_endpoint_custom_claims(oidc_tokens, client, oauth2_settings): + class CustomValidator(OAuth2Validator): + def get_additional_claims(self, request): + return {"state": "very nice"} + + oidc_tokens.oauth2_settings.OAUTH2_VALIDATOR_CLASS = CustomValidator + auth_header = "Bearer %s" % oidc_tokens.access_token + rsp = client.get( + reverse("oauth2_provider:user-info"), + HTTP_AUTHORIZATION=auth_header, + ) + data = rsp.json() + assert "sub" in data + assert data["sub"] == str(oidc_tokens.user.pk) + assert "state" in data + assert data["state"] == "very nice" diff --git a/tests/test_password.py b/tests/test_password.py index f50404f..953b076 100644 --- a/tests/test_password.py +++ b/tests/test_password.py @@ -1,11 +1,11 @@ import json +import pytest from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse from oauth2_provider.models import get_application_model -from oauth2_provider.settings import oauth2_settings from oauth2_provider.views import ProtectedResourceView from .utils import get_basic_auth_header @@ -21,6 +21,7 @@ def get(self, request, *args, **kwargs): return "This is a protected resource" +@pytest.mark.usefixtures("oauth2_settings") class BaseTest(TestCase): def setUp(self): self.factory = RequestFactory() @@ -34,9 +35,6 @@ def setUp(self): authorization_grant_type=Application.GRANT_PASSWORD, ) - oauth2_settings._SCOPES = ["read", "write"] - oauth2_settings._DEFAULT_SCOPES = ["read", "write"] - def tearDown(self): self.application.delete() self.test_user.delete() @@ -60,8 +58,8 @@ def test_get_token(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(set(content["scope"].split()), {"read", "write"}) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_bad_credentials(self): """ diff --git a/tests/test_rest_framework.py b/tests/test_rest_framework.py index 21a6ccd..a25611b 100644 --- a/tests/test_rest_framework.py +++ b/tests/test_rest_framework.py @@ -1,11 +1,13 @@ from datetime import timedelta -from django.conf.urls import include, url +import pytest +from django.conf.urls import include from django.contrib.auth import get_user_model from django.core.exceptions import ImproperlyConfigured from django.http import HttpResponse from django.test import TestCase from django.test.utils import override_settings +from django.urls import path, re_path from django.utils import timezone from rest_framework import permissions from rest_framework.authentication import BaseAuthentication @@ -13,18 +15,16 @@ from rest_framework.views import APIView from oauth2_provider.contrib.rest_framework import ( - IsAuthenticatedOrTokenHasScope, OAuth2Authentication, - TokenHasReadWriteScope, TokenHasResourceScope, - TokenHasScope, TokenMatchesOASRequirements + IsAuthenticatedOrTokenHasScope, + OAuth2Authentication, + TokenHasReadWriteScope, + TokenHasResourceScope, + TokenHasScope, + TokenMatchesOASRequirements, ) from oauth2_provider.models import get_access_token_model, get_application_model -from oauth2_provider.settings import oauth2_settings - -try: - from unittest import mock -except ImportError: - import mock +from . import presets Application = get_application_model() @@ -84,7 +84,10 @@ class MethodScopeAltViewBad(OAuth2View): class MissingAuthentication(BaseAuthentication): def authenticate(self, request): - return ("junk", "junk",) + return ( + "junk", + "junk", + ) class BrokenOAuth2View(MockView): @@ -109,25 +112,25 @@ class AuthenticationNoneOAuth2View(MockView): urlpatterns = [ - url(r"^oauth2/", include("oauth2_provider.urls")), - url(r"^oauth2-test/$", OAuth2View.as_view()), - url(r"^oauth2-scoped-test/$", ScopedView.as_view()), - url(r"^oauth2-scoped-missing-auth/$", TokenHasScopeViewWrongAuth.as_view()), - url(r"^oauth2-read-write-test/$", ReadWriteScopedView.as_view()), - url(r"^oauth2-resource-scoped-test/$", ResourceScopedView.as_view()), - url(r"^oauth2-authenticated-or-scoped-test/$", AuthenticatedOrScopedView.as_view()), - url(r"^oauth2-method-scope-test/.*$", MethodScopeAltView.as_view()), - url(r"^oauth2-method-scope-fail/$", MethodScopeAltViewBad.as_view()), - url(r"^oauth2-method-scope-missing-auth/$", MethodScopeAltViewWrongAuth.as_view()), - url(r"^oauth2-authentication-none/$", AuthenticationNoneOAuth2View.as_view()), + path("oauth2/", include("oauth2_provider.urls")), + path("oauth2-test/", OAuth2View.as_view()), + path("oauth2-scoped-test/", ScopedView.as_view()), + path("oauth2-scoped-missing-auth/", TokenHasScopeViewWrongAuth.as_view()), + path("oauth2-read-write-test/", ReadWriteScopedView.as_view()), + path("oauth2-resource-scoped-test/", ResourceScopedView.as_view()), + path("oauth2-authenticated-or-scoped-test/", AuthenticatedOrScopedView.as_view()), + re_path(r"oauth2-method-scope-test/.*$", MethodScopeAltView.as_view()), + path("oauth2-method-scope-fail/", MethodScopeAltViewBad.as_view()), + path("oauth2-method-scope-missing-auth/", MethodScopeAltViewWrongAuth.as_view()), + path("oauth2-authentication-none/", AuthenticationNoneOAuth2View.as_view()), ] @override_settings(ROOT_URLCONF=__name__) +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(presets.REST_FRAMEWORK_SCOPES) class TestOAuth2Authentication(TestCase): def setUp(self): - oauth2_settings._SCOPES = ["read", "write", "scope1", "scope2", "resource1"] - self.test_user = UserModel.objects.create_user("test_user", "test@example.com", "123456") self.dev_user = UserModel.objects.create_user("dev_user", "dev@example.com", "123456") @@ -144,12 +147,9 @@ def setUp(self): scope="read write", expires=timezone.now() + timedelta(seconds=300), token="secret-access-token-key", - application=self.application + application=self.application, ) - def tearDown(self): - oauth2_settings._SCOPES = ["read", "write"] - def _create_authorization_header(self, token): return "Bearer {0}".format(token) @@ -304,8 +304,8 @@ def test_resource_scoped_permission_post_denied(self): response = self.client.post("/oauth2-resource-scoped-test/", HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 403) - @mock.patch.object(oauth2_settings, "ERROR_RESPONSE_WITH_SCOPES", new=True) def test_required_scope_in_response(self): + self.oauth2_settings.ERROR_RESPONSE_WITH_SCOPES = True self.access_token.scope = "scope2" self.access_token.save() diff --git a/tests/test_scopes.py b/tests/test_scopes.py index f744d67..a310e22 100644 --- a/tests/test_scopes.py +++ b/tests/test_scopes.py @@ -1,18 +1,14 @@ import json from urllib.parse import parse_qs, urlparse +import pytest from django.contrib.auth import get_user_model from django.core.exceptions import ImproperlyConfigured from django.test import RequestFactory, TestCase from django.urls import reverse -from oauth2_provider.models import ( - get_access_token_model, get_application_model, get_grant_model -) -from oauth2_provider.settings import oauth2_settings -from oauth2_provider.views import ( - ReadWriteScopedResourceView, ScopedProtectedResourceView -) +from oauth2_provider.models import get_access_token_model, get_application_model, get_grant_model +from oauth2_provider.views import ReadWriteScopedResourceView, ScopedProtectedResourceView from .utils import get_basic_auth_header @@ -46,6 +42,19 @@ def post(self, request, *args, **kwargs): return "This is a write protected resource" +SCOPE_SETTINGS = { + "SCOPES": { + "read": "Read scope", + "write": "Write scope", + "scope1": "Custom scope 1", + "scope2": "Custom scope 2", + "scope3": "Custom scope 3", + }, +} + + +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(SCOPE_SETTINGS) class BaseTest(TestCase): def setUp(self): self.factory = RequestFactory() @@ -60,12 +69,7 @@ def setUp(self): authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, ) - oauth2_settings._SCOPES = ["read", "write", "scope1", "scope2", "scope3"] - oauth2_settings.READ_SCOPE = "read" - oauth2_settings.WRITE_SCOPE = "write" - def tearDown(self): - oauth2_settings._SCOPES = ["read", "write"] self.application.delete() self.test_user.delete() self.dev_user.delete() @@ -117,7 +121,7 @@ def test_scopes_save_in_access_token(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) @@ -153,7 +157,7 @@ def test_scopes_protection_valid(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) @@ -195,7 +199,7 @@ def test_scopes_protection_fail(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) @@ -237,7 +241,7 @@ def test_multi_scope_fail(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) @@ -279,7 +283,7 @@ def test_multi_scope_valid(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) @@ -320,7 +324,7 @@ def get_access_token(self, scopes): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) @@ -329,27 +333,27 @@ def get_access_token(self, scopes): return content["access_token"] def test_improperly_configured(self): - oauth2_settings.SCOPES = {"scope1": "Scope 1"} + self.oauth2_settings.SCOPES = {"scope1": "Scope 1"} request = self.factory.get("/fake") view = ReadWriteResourceView.as_view() self.assertRaises(ImproperlyConfigured, view, request) - oauth2_settings.SCOPES = {"read": "Read Scope", "write": "Write Scope"} - oauth2_settings.READ_SCOPE = "ciccia" + self.oauth2_settings.SCOPES = {"read": "Read Scope", "write": "Write Scope"} + self.oauth2_settings.READ_SCOPE = "ciccia" view = ReadWriteResourceView.as_view() self.assertRaises(ImproperlyConfigured, view, request) def test_properly_configured(self): - oauth2_settings.SCOPES = {"scope1": "Scope 1"} + self.oauth2_settings.SCOPES = {"scope1": "Scope 1"} request = self.factory.get("/fake") view = ReadWriteResourceView.as_view() self.assertRaises(ImproperlyConfigured, view, request) - oauth2_settings.SCOPES = {"read": "Read Scope", "write": "Write Scope"} - oauth2_settings.READ_SCOPE = "ciccia" + self.oauth2_settings.SCOPES = {"read": "Read Scope", "write": "Write Scope"} + self.oauth2_settings.READ_SCOPE = "ciccia" view = ReadWriteResourceView.as_view() self.assertRaises(ImproperlyConfigured, view, request) diff --git a/tests/test_scopes_backend.py b/tests/test_scopes_backend.py index 5f62961..925a4e3 100644 --- a/tests/test_scopes_backend.py +++ b/tests/test_scopes_backend.py @@ -3,9 +3,9 @@ def test_settings_scopes_get_available_scopes(): scopes = SettingsScopes() - assert scopes.get_available_scopes() == ["read", "write"] + assert set(scopes.get_available_scopes()) == {"read", "write"} def test_settings_scopes_get_default_scopes(): scopes = SettingsScopes() - assert scopes.get_default_scopes() == ["read", "write"] + assert set(scopes.get_default_scopes()) == {"read", "write"} diff --git a/tests/test_settings.py b/tests/test_settings.py new file mode 100644 index 0000000..52bdafe --- /dev/null +++ b/tests/test_settings.py @@ -0,0 +1,169 @@ +import pytest +from django.core.exceptions import ImproperlyConfigured +from django.test import TestCase +from django.test.utils import override_settings +from oauthlib.common import Request + +from oauth2_provider.admin import ( + get_access_token_admin_class, + get_application_admin_class, + get_grant_admin_class, + get_id_token_admin_class, + get_refresh_token_admin_class, +) +from oauth2_provider.settings import OAuth2ProviderSettings, oauth2_settings, perform_import +from tests.admin import ( + CustomAccessTokenAdmin, + CustomApplicationAdmin, + CustomGrantAdmin, + CustomIDTokenAdmin, + CustomRefreshTokenAdmin, +) + +from . import presets + + +class TestAdminClass(TestCase): + def test_import_error_message_maintained(self): + """ + Make sure import errors are captured and raised sensibly. + """ + settings = OAuth2ProviderSettings({"CLIENT_ID_GENERATOR_CLASS": "invalid_module.InvalidClassName"}) + with self.assertRaises(ImportError): + settings.CLIENT_ID_GENERATOR_CLASS + + def test_get_application_admin_class(self): + """ + Test for getting class for application admin. + """ + application_admin_class = get_application_admin_class() + default_application_admin_class = oauth2_settings.APPLICATION_ADMIN_CLASS + assert application_admin_class == default_application_admin_class + + def test_get_access_token_admin_class(self): + """ + Test for getting class for access token admin. + """ + access_token_admin_class = get_access_token_admin_class() + default_access_token_admin_class = oauth2_settings.ACCESS_TOKEN_ADMIN_CLASS + assert access_token_admin_class == default_access_token_admin_class + + def test_get_grant_admin_class(self): + """ + Test for getting class for grant admin. + """ + grant_admin_class = get_grant_admin_class() + default_grant_admin_class = oauth2_settings.GRANT_ADMIN_CLASS + assert grant_admin_class == default_grant_admin_class + + def test_get_id_token_admin_class(self): + """ + Test for getting class for ID token admin. + """ + id_token_admin_class = get_id_token_admin_class() + default_id_token_admin_class = oauth2_settings.ID_TOKEN_ADMIN_CLASS + assert id_token_admin_class == default_id_token_admin_class + + def test_get_refresh_token_admin_class(self): + """ + Test for getting class for refresh token admin. + """ + refresh_token_admin_class = get_refresh_token_admin_class() + default_refresh_token_admin_class = oauth2_settings.REFRESH_TOKEN_ADMIN_CLASS + assert refresh_token_admin_class == default_refresh_token_admin_class + + @override_settings(OAUTH2_PROVIDER={"APPLICATION_ADMIN_CLASS": "tests.admin.CustomApplicationAdmin"}) + def test_get_custom_application_admin_class(self): + """ + Test for getting custom class for application admin. + """ + application_admin_class = get_application_admin_class() + assert application_admin_class == CustomApplicationAdmin + + @override_settings(OAUTH2_PROVIDER={"ACCESS_TOKEN_ADMIN_CLASS": "tests.admin.CustomAccessTokenAdmin"}) + def test_get_custom_access_token_admin_class(self): + """ + Test for getting custom class for access token admin. + """ + access_token_admin_class = get_access_token_admin_class() + assert access_token_admin_class == CustomAccessTokenAdmin + + @override_settings(OAUTH2_PROVIDER={"GRANT_ADMIN_CLASS": "tests.admin.CustomGrantAdmin"}) + def test_get_custom_grant_admin_class(self): + """ + Test for getting custom class for grant admin. + """ + grant_admin_class = get_grant_admin_class() + assert grant_admin_class == CustomGrantAdmin + + @override_settings(OAUTH2_PROVIDER={"ID_TOKEN_ADMIN_CLASS": "tests.admin.CustomIDTokenAdmin"}) + def test_get_custom_id_token_admin_class(self): + """ + Test for getting custom class for ID token admin. + """ + id_token_admin_class = get_id_token_admin_class() + assert id_token_admin_class == CustomIDTokenAdmin + + @override_settings(OAUTH2_PROVIDER={"REFRESH_TOKEN_ADMIN_CLASS": "tests.admin.CustomRefreshTokenAdmin"}) + def test_get_custom_refresh_token_admin_class(self): + """ + Test for getting custom class for refresh token admin. + """ + refresh_token_admin_class = get_refresh_token_admin_class() + assert refresh_token_admin_class == CustomRefreshTokenAdmin + + +def test_perform_import_when_none(): + assert perform_import(None, "REFRESH_TOKEN_ADMIN_CLASS") is None + + +def test_perform_import_list(): + imports = ["tests.admin.CustomIDTokenAdmin", "tests.admin.CustomGrantAdmin"] + assert perform_import(imports, "SOME_CLASSES") == [CustomIDTokenAdmin, CustomGrantAdmin] + + +def test_perform_import_already_imported(): + cls = perform_import(CustomRefreshTokenAdmin, "REFRESH_TOKEN_ADMIN_CLASS") + assert cls == CustomRefreshTokenAdmin + + +def test_invalid_scopes_raises_error(): + settings = OAuth2ProviderSettings( + { + "SCOPES": {"foo": "foo scope"}, + "DEFAULT_SCOPES": ["bar"], + } + ) + with pytest.raises(ImproperlyConfigured) as exc: + settings._DEFAULT_SCOPES + assert str(exc.value) == "Defined DEFAULT_SCOPES not present in SCOPES" + + +def test_missing_mandatory_setting_raises_error(): + settings = OAuth2ProviderSettings( + user_settings={}, defaults={"very_important": None}, mandatory=["very_important"] + ) + with pytest.raises(AttributeError) as exc: + settings.very_important + assert str(exc.value) == "OAuth2Provider setting: very_important is mandatory" + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +@pytest.mark.parametrize("issuer_setting", ["http://foo.com/", None]) +@pytest.mark.parametrize("request_type", ["django", "oauthlib"]) +def test_generating_iss_endpoint(oauth2_settings, issuer_setting, request_type, rf): + oauth2_settings.OIDC_ISS_ENDPOINT = issuer_setting + if request_type == "django": + request = rf.get("/") + elif request_type == "oauthlib": + request = Request("/", headers=rf.get("/").META) + expected = issuer_setting or "http://testserver/o" + assert oauth2_settings.oidc_issuer(request) == expected + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_generating_iss_endpoint_type_error(oauth2_settings): + oauth2_settings.OIDC_ISS_ENDPOINT = None + with pytest.raises(TypeError) as exc: + oauth2_settings.oidc_issuer(None) + assert str(exc.value) == "request must be a django or oauthlib request: got None" diff --git a/tests/test_token_revocation.py b/tests/test_token_revocation.py index fdbc072..1ed1c91 100644 --- a/tests/test_token_revocation.py +++ b/tests/test_token_revocation.py @@ -5,10 +5,7 @@ from django.urls import reverse from django.utils import timezone -from oauth2_provider.models import ( - get_access_token_model, get_application_model, get_refresh_token_model -) -from oauth2_provider.settings import oauth2_settings +from oauth2_provider.models import get_access_token_model, get_application_model, get_refresh_token_model Application = get_application_model() @@ -31,8 +28,6 @@ def setUp(self): authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, ) - oauth2_settings._SCOPES = ["read", "write"] - def tearDown(self): self.application.delete() self.test_user.delete() @@ -41,15 +36,14 @@ def tearDown(self): class TestRevocationView(BaseTest): def test_revoke_access_token(self): - """ - - """ tok = AccessToken.objects.create( - user=self.test_user, token="1234567890", + user=self.test_user, + token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) + data = { "client_id": self.application.client_id, "client_secret": self.application.client_secret, @@ -72,9 +66,11 @@ def test_revoke_access_token_public(self): public_app.save() tok = AccessToken.objects.create( - user=self.test_user, token="1234567890", application=public_app, + user=self.test_user, + token="1234567890", + application=public_app, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) data = { @@ -87,21 +83,21 @@ def test_revoke_access_token_public(self): self.assertEqual(response.status_code, 200) def test_revoke_access_token_with_hint(self): - """ - - """ tok = AccessToken.objects.create( - user=self.test_user, token="1234567890", + user=self.test_user, + token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) + data = { "client_id": self.application.client_id, "client_secret": self.application.client_secret, "token": tok.token, - "token_type_hint": "access_token" + "token_type_hint": "access_token", } + url = reverse("oauth2_provider:revoke-token") response = self.client.post(url, data=data) self.assertEqual(response.status_code, 200) @@ -109,18 +105,21 @@ def test_revoke_access_token_with_hint(self): def test_revoke_access_token_with_invalid_hint(self): tok = AccessToken.objects.create( - user=self.test_user, token="1234567890", + user=self.test_user, + token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) # invalid hint should have no effect + data = { "client_id": self.application.client_id, "client_secret": self.application.client_secret, "token": tok.token, - "token_type_hint": "bad_hint" + "token_type_hint": "bad_hint", } + url = reverse("oauth2_provider:revoke-token") response = self.client.post(url, data=data) self.assertEqual(response.status_code, 200) @@ -128,20 +127,22 @@ def test_revoke_access_token_with_invalid_hint(self): def test_revoke_refresh_token(self): tok = AccessToken.objects.create( - user=self.test_user, token="1234567890", + user=self.test_user, + token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) rtok = RefreshToken.objects.create( - user=self.test_user, token="999999999", - application=self.application, access_token=tok + user=self.test_user, token="999999999", application=self.application, access_token=tok ) + data = { "client_id": self.application.client_id, "client_secret": self.application.client_secret, "token": rtok.token, } + url = reverse("oauth2_provider:revoke-token") response = self.client.post(url, data=data) self.assertEqual(response.status_code, 200) @@ -151,14 +152,14 @@ def test_revoke_refresh_token(self): def test_revoke_refresh_token_with_revoked_access_token(self): tok = AccessToken.objects.create( - user=self.test_user, token="1234567890", + user=self.test_user, + token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) rtok = RefreshToken.objects.create( - user=self.test_user, token="999999999", - application=self.application, access_token=tok + user=self.test_user, token="999999999", application=self.application, access_token=tok ) for token in (tok.token, rtok.token): data = { @@ -166,6 +167,7 @@ def test_revoke_refresh_token_with_revoked_access_token(self): "client_secret": self.application.client_secret, "token": token, } + url = reverse("oauth2_provider:revoke-token") response = self.client.post(url, data=data) self.assertEqual(response.status_code, 200) @@ -183,18 +185,20 @@ def test_revoke_token_with_wrong_hint(self): .. _`Section 4.1.2`: http://tools.ietf.org/html/draft-ietf-oauth-revocation-11#section-4.1.2 """ tok = AccessToken.objects.create( - user=self.test_user, token="1234567890", + user=self.test_user, + token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) data = { "client_id": self.application.client_id, "client_secret": self.application.client_secret, "token": tok.token, - "token_type_hint": "refresh_token" + "token_type_hint": "refresh_token", } + url = reverse("oauth2_provider:revoke-token") response = self.client.post(url, data=data) self.assertEqual(response.status_code, 200) diff --git a/tests/test_token_view.py b/tests/test_token_view.py index fc3044c..784ea3b 100644 --- a/tests/test_token_view.py +++ b/tests/test_token_view.py @@ -17,6 +17,7 @@ class TestAuthorizedTokenViews(TestCase): """ TestCase superclass for Authorized Token Views" Test Cases """ + def setUp(self): self.foo_user = UserModel.objects.create_user("foo_user", "test@example.com", "123456") self.bar_user = UserModel.objects.create_user("bar_user", "dev@example.com", "123456") @@ -38,6 +39,7 @@ class TestAuthorizedTokenListView(TestAuthorizedTokenViews): """ Tests for the Authorized Token ListView """ + def test_list_view_authorization_required(self): """ Test that the view redirects to login page if user is not logged-in. @@ -62,10 +64,11 @@ def test_list_view_one_token(self): """ self.client.login(username="bar_user", password="123456") AccessToken.objects.create( - user=self.bar_user, token="1234567890", + user=self.bar_user, + token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) response = self.client.get(reverse("oauth2_provider:authorized-token-list")) @@ -80,16 +83,18 @@ def test_list_view_two_tokens(self): """ self.client.login(username="bar_user", password="123456") AccessToken.objects.create( - user=self.bar_user, token="1234567890", + user=self.bar_user, + token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) AccessToken.objects.create( - user=self.bar_user, token="0123456789", + user=self.bar_user, + token="0123456789", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) response = self.client.get(reverse("oauth2_provider:authorized-token-list")) @@ -102,10 +107,11 @@ def test_list_view_shows_correct_user_token(self): """ self.client.login(username="bar_user", password="123456") AccessToken.objects.create( - user=self.foo_user, token="1234567890", + user=self.foo_user, + token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) response = self.client.get(reverse("oauth2_provider:authorized-token-list")) @@ -117,15 +123,17 @@ class TestAuthorizedTokenDeleteView(TestAuthorizedTokenViews): """ Tests for the Authorized Token DeleteView """ + def test_delete_view_authorization_required(self): """ Test that the view redirects to login page if user is not logged-in. """ self.token = AccessToken.objects.create( - user=self.foo_user, token="1234567890", + user=self.foo_user, + token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) url = reverse("oauth2_provider:authorized-token-delete", kwargs={"pk": self.token.pk}) @@ -138,10 +146,11 @@ def test_delete_view_works(self): Test that a GET on this view returns 200 if the token belongs to the logged-in user. """ self.token = AccessToken.objects.create( - user=self.foo_user, token="1234567890", + user=self.foo_user, + token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) self.client.login(username="foo_user", password="123456") @@ -154,10 +163,11 @@ def test_delete_view_token_belongs_to_user(self): Test that a 404 is returned when trying to GET this view with someone else"s tokens. """ self.token = AccessToken.objects.create( - user=self.foo_user, token="1234567890", + user=self.foo_user, + token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) self.client.login(username="bar_user", password="123456") @@ -170,10 +180,11 @@ def test_delete_view_post_actually_deletes(self): Test that a POST on this view works if the token belongs to the logged-in user. """ self.token = AccessToken.objects.create( - user=self.foo_user, token="1234567890", + user=self.foo_user, + token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) self.client.login(username="foo_user", password="123456") @@ -187,10 +198,11 @@ def test_delete_view_only_deletes_user_own_token(self): Test that a 404 is returned when trying to POST on this view with someone else"s tokens. """ self.token = AccessToken.objects.create( - user=self.foo_user, token="1234567890", + user=self.foo_user, + token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) self.client.login(username="bar_user", password="123456") diff --git a/tests/test_validators.py b/tests/test_validators.py index 82930a9..0760e02 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -1,10 +1,11 @@ +import pytest from django.core.validators import ValidationError from django.test import TestCase -from oauth2_provider.settings import oauth2_settings from oauth2_provider.validators import RedirectURIValidator +@pytest.mark.usefixtures("oauth2_settings") class TestValidators(TestCase): def test_validate_good_uris(self): validator = RedirectURIValidator(allowed_schemes=["https"]) @@ -37,7 +38,7 @@ def test_validate_custom_uri_scheme(self): def test_validate_bad_uris(self): validator = RedirectURIValidator(allowed_schemes=["https"]) - oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https", "good"] + self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https", "good"] bad_uris = [ "http:/example.com", "HTTP://localhost", diff --git a/tests/urls.py b/tests/urls.py index 16dcf6d..0661a93 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,13 +1,11 @@ -from django.conf.urls import include, url from django.contrib import admin +from django.urls import include, path admin.autodiscover() urlpatterns = [ - url(r"^o/", include("oauth2_provider.urls", namespace="oauth2_provider")), + path("o/", include("oauth2_provider.urls", namespace="oauth2_provider")), + path("admin/", admin.site.urls), ] - - -urlpatterns += [url(r"^admin/", admin.site.urls)] diff --git a/tests/utils.py b/tests/utils.py index ec25905..b7dc200 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,5 @@ import base64 +from unittest import mock def get_basic_auth_header(user, password): @@ -13,3 +14,19 @@ def get_basic_auth_header(user, password): } return auth_headers + + +def spy_on(meth): + """ + Util function to add a spy onto a method of a class. + """ + spy = mock.MagicMock() + + def wrapper(self, *args, **kwargs): + spy(self, *args, **kwargs) + return_value = meth(self, *args, **kwargs) + spy.returned = return_value + return return_value + + wrapper.spy = spy + return wrapper diff --git a/tox.ini b/tox.ini index 7d2de23..8371aab 100644 --- a/tox.ini +++ b/tox.ini @@ -1,67 +1,89 @@ [tox] envlist = - py37-flake8, - py37-docs, - py38-django{30,22,21}, - py37-django{30,22,21}, - py36-django{22,21}, - py35-django{22,21}, - py38-djangomaster, - py37-djangomaster, - py36-djangomaster, + flake8, + docs, + py{36,37,38,39}-dj{32,31,22}, + py{38,39}-djmain, + +[gh-actions] +python = + 3.6: py36 + 3.7: py37 + 3.8: py38, docs, flake8 + 3.9: py39 [pytest] django_find_project = false +addopts = + --cov=oauth2_provider + --cov-report= + --cov-append + -s +markers = + oauth2_settings: Custom OAuth2 settings to use - use with oauth2_settings fixture [testenv] commands = - pytest --cov=oauth2_provider --cov-report= --cov-append {posargs} -s + pytest {posargs} + coverage report + coverage xml setenv = - DJANGO_SETTINGS_MODULE = tests.settings - PYTHONPATH = {toxinidir} - PYTHONWARNINGS = all + DJANGO_SETTINGS_MODULE = tests.settings + PYTHONPATH = {toxinidir} + PYTHONWARNINGS = all deps = - django21: Django>=2.1,<2.2 - django22: Django>=2.2,<3 - django30: Django>=3.0,<3.1 - djangomaster: https://github.com/django/django/archive/master.tar.gz - djangorestframework - oauthlib>=3.0.1 - coverage - pytest - pytest-cov - pytest-django - pytest-xdist - py27: mock - requests - jwcrypto + dj22: Django>=2.2,<3 + dj31: Django>=3.1,<3.2 + dj32: Django>=3.2,<3.3 + djmain: https://github.com/django/django/archive/main.tar.gz + djangorestframework + oauthlib>=3.1.0 + jwcrypto + coverage + pytest + pytest-cov + pytest-django + pytest-xdist + pytest-mock + requests +passenv = + PYTEST_ADDOPTS + +[testenv:py{38,39}-djmain] +ignore_errors = true +ignore_outcome = true -[testenv:py37-docs] -basepython = python +[testenv:{docs,livedocs}] +basepython = python3.8 changedir = docs whitelist_externals = make -commands = make html -deps = sphinx<3 - oauthlib>=3.0.1 - m2r>=0.2.1 - jwcrypto +commands = + docs: make html + livedocs: make livehtml +deps = + sphinx<3 + oauthlib>=3.1.0 + m2r>=0.2.1 + sphinx-rtd-theme + livedocs: sphinx-autobuild + jwcrypto -[testenv:py37-flake8] +[testenv:flake8] +basepython = python3.8 skip_install = True -commands = - flake8 {toxinidir} +commands = flake8 {toxinidir} deps = - flake8 - flake8-isort - flake8-quotes + flake8 + flake8-isort + flake8-quotes + flake8-black [testenv:install] deps = twine setuptools>=39.0 wheel -whitelist_externals= - rm +whitelist_externals = rm commands = rm -rf dist python setup.py sdist bdist_wheel @@ -70,21 +92,26 @@ commands = [coverage:run] source = oauth2_provider -omit = - */migrations/* - oauth2_provider/settings.py +omit = */migrations/* + +[coverage:report] +show_missing = True [flake8] max-line-length = 110 exclude = docs/, oauth2_provider/migrations/, tests/migrations/, .tox/ application-import-names = oauth2_provider inline-quotes = double +extend-ignore = E203, W503 [isort] -balanced_wrapping = True default_section = THIRDPARTY known_first_party = oauth2_provider -line_length = 80 +line_length = 110 lines_after_imports = 2 -multi_line_output = 5 -skip = oauth2_provider/migrations/, .tox/ +multi_line_output = 3 +include_trailing_comma = True +force_grid_wrap = 0 +use_parentheses = True +ensure_newline_before_comments = True +skip = oauth2_provider/migrations/, .tox/, tests/migrations/