Skip to content

Commit

Permalink
Add error_type extension to Strawberry Permission classes & Drop Pyth…
Browse files Browse the repository at this point in the history
…on 3.9/3.10 (#64)

* Drop support for python 3.9 and 3.10, run pyupgrade

* orchestrator-core#649 Add "error_type" error_extensions to strawberry Permission classes

* Bump version to 2.1.0
  • Loading branch information
Mark90 authored Jul 8, 2024
1 parent 7b8020a commit cb592da
Show file tree
Hide file tree
Showing 10 changed files with 80 additions and 60 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 2.0.2
current_version = 2.1.0
commit = False
tag = False
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
Expand Down
10 changes: 6 additions & 4 deletions .github/workflows/test-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
python-version: ['3.11', '3.12']
pydantic-version: ['1.*', '2.*']
fail-fast: false
steps:
Expand All @@ -28,10 +28,12 @@ jobs:
flit install --deps develop
pip install -U "pydantic==${{ matrix.pydantic-version }}"
pip install pydantic_settings || true
- name: Lint
- name: Check formatting
run: |
black . --check
ruff .
black --check .
- name: Lint with ruff
run: |
ruff check .
- name: MyPy
run: |
mypy -p oauth2_lib
Expand Down
22 changes: 11 additions & 11 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
repos:
- repo: https://github.com/asottile/pyupgrade
rev: v3.8.0
rev: v3.16.0
hooks:
- id: pyupgrade
args:
- --py39-plus
- --py311-plus
- --keep-runtime-typing
- repo: https://github.com/psf/black
rev: 23.3.0
rev: 24.4.2
hooks:
- id: black
language_version: python3.9
language_version: python3.11
- repo: https://github.com/asottile/blacken-docs
rev: 1.14.0
rev: 1.18.0
hooks:
- id: blacken-docs
additional_dependencies: [black==22.10.0]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.6.0
hooks:
- id: trailing-whitespace
exclude: .bumpversion.cfg
Expand All @@ -31,15 +31,15 @@ repos:
- id: detect-private-key
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.0.275
rev: v0.5.1
hooks:
- id: ruff
args: [ --fix, --exit-non-zero-on-fix, --show-fixes ]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.4.1
rev: v1.10.1
hooks:
- id: mypy
language_version: python3.9
language_version: python3.11
additional_dependencies: [pydantic<2.0.0, strawberry-graphql]
args:
- --no-warn-unused-ignores
Expand All @@ -53,10 +53,10 @@ repos:
- id: python-check-mock-methods
- id: rst-backticks
- repo: https://github.com/shellcheck-py/shellcheck-py
rev: v0.9.0.5
rev: v0.10.0.1
hooks:
- id: shellcheck
- repo: https://github.com/andreoliwa/nitpick
rev: v0.33.2
rev: v0.35.0
hooks:
- id: nitpick
2 changes: 1 addition & 1 deletion oauth2_lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@

"""This is the SURF Oauth2 module that interfaces with the oauth2 setup."""

__version__ = "2.0.2"
__version__ = "2.1.0"
4 changes: 2 additions & 2 deletions oauth2_lib/async_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# limitations under the License.
from asyncio import new_event_loop
from http import HTTPStatus
from typing import Any, Union
from typing import Any

import structlog
from authlib.integrations.base_client import BaseOAuth
Expand Down Expand Up @@ -61,7 +61,7 @@ class FubarApiClient(AuthMixin, fubar_client.ApiClient)
"""

_token: Union[dict, None]
_token: dict | None

def __init__(
self,
Expand Down
32 changes: 16 additions & 16 deletions oauth2_lib/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
# limitations under the License.
import ssl
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Mapping
from collections.abc import Awaitable, Callable, Mapping
from http import HTTPStatus
from json import JSONDecodeError
from typing import Any, Callable, Optional, Union, cast
from typing import Any, Optional, cast

from fastapi import HTTPException
from fastapi.requests import Request
Expand Down Expand Up @@ -93,8 +93,8 @@ class OIDCConfig(BaseModel):
authorization_endpoint: str
token_endpoint: str
userinfo_endpoint: str
introspect_endpoint: Optional[str] = None
introspection_endpoint: Optional[str] = None
introspect_endpoint: str | None = None
introspection_endpoint: str | None = None
jwks_uri: str
response_types_supported: list[str]
response_modes_supported: list[str]
Expand Down Expand Up @@ -126,7 +126,7 @@ class Authentication(ABC):
"""

@abstractmethod
async def authenticate(self, request: HTTPConnection, token: Optional[str] = None) -> Optional[dict]:
async def authenticate(self, request: HTTPConnection, token: str | None = None) -> dict | None:
"""Authenticate the user."""
pass

Expand All @@ -138,7 +138,7 @@ class IdTokenExtractor(ABC):
"""

@abstractmethod
async def extract(self, request: Request) -> Optional[str]:
async def extract(self, request: Request) -> str | None:
pass


Expand All @@ -148,7 +148,7 @@ class HttpBearerExtractor(IdTokenExtractor):
Specifically designed for HTTP Authorization header token extraction.
"""

async def extract(self, request: Request) -> Optional[str]:
async def extract(self, request: Request) -> str | None:
http_bearer = HTTPBearer(auto_error=False)
credential = await http_bearer(request)

Expand All @@ -168,7 +168,7 @@ def __init__(
resource_server_id: str,
resource_server_secret: str,
oidc_user_model_cls: type[OIDCUserModel],
id_token_extractor: Optional[IdTokenExtractor] = None,
id_token_extractor: IdTokenExtractor | None = None,
):
if not id_token_extractor:
self.id_token_extractor = HttpBearerExtractor()
Expand All @@ -179,9 +179,9 @@ def __init__(
self.resource_server_secret = resource_server_secret
self.user_model_cls = oidc_user_model_cls

self.openid_config: Optional[OIDCConfig] = None
self.openid_config: OIDCConfig | None = None

async def authenticate(self, request: HTTPConnection, token: Optional[str] = None) -> Optional[OIDCUserModel]:
async def authenticate(self, request: HTTPConnection, token: str | None = None) -> OIDCUserModel | None:
"""Return the OIDC user from OIDC introspect endpoint.
This is used as a security module in Fastapi projects
Expand Down Expand Up @@ -263,7 +263,7 @@ class Authorization(ABC):
"""

@abstractmethod
async def authorize(self, request: HTTPConnection, user: OIDCUserModel) -> Optional[bool]:
async def authorize(self, request: HTTPConnection, user: OIDCUserModel) -> bool | None:
pass


Expand All @@ -274,7 +274,7 @@ class GraphqlAuthorization(ABC):
"""

@abstractmethod
async def authorize(self, request: RequestPath, user: OIDCUserModel) -> Optional[bool]:
async def authorize(self, request: RequestPath, user: OIDCUserModel) -> bool | None:
pass


Expand All @@ -284,7 +284,7 @@ class OPAMixin:
Supports getting and evaluating OPA policy decisions.
"""

def __init__(self, opa_url: str, auto_error: bool = True, opa_kwargs: Union[Mapping[str, Any], None] = None):
def __init__(self, opa_url: str, auto_error: bool = True, opa_kwargs: Mapping[str, Any] | None = None):
self.opa_url = opa_url
self.auto_error = auto_error
self.opa_kwargs = opa_kwargs
Expand Down Expand Up @@ -324,7 +324,7 @@ class OPAAuthorization(Authorization, OPAMixin):
Uses OAUTH2 settings and request information to authorize actions.
"""

async def authorize(self, request: HTTPConnection, user_info: OIDCUserModel) -> Optional[bool]:
async def authorize(self, request: HTTPConnection, user_info: OIDCUserModel) -> bool | None:
if not (oauth2lib_settings.OAUTH2_ACTIVE and oauth2lib_settings.OAUTH2_AUTHORIZATION_ACTIVE):
return None

Expand Down Expand Up @@ -376,11 +376,11 @@ class GraphQLOPAAuthorization(GraphqlAuthorization, OPAMixin):
Customizable to handle partial results without raising HTTP 403.
"""

def __init__(self, opa_url: str, auto_error: bool = False, opa_kwargs: Union[Mapping[str, Any], None] = None):
def __init__(self, opa_url: str, auto_error: bool = False, opa_kwargs: Mapping[str, Any] | None = None):
# By default don't raise HTTP 403 because partial results are preferred
super().__init__(opa_url, auto_error, opa_kwargs)

async def authorize(self, request: RequestPath, user_info: OIDCUserModel) -> Optional[bool]:
async def authorize(self, request: RequestPath, user_info: OIDCUserModel) -> bool | None:
if not (oauth2lib_settings.OAUTH2_ACTIVE and oauth2lib_settings.OAUTH2_AUTHORIZATION_ACTIVE):
return None

Expand Down
39 changes: 27 additions & 12 deletions oauth2_lib/strawberry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Union
from collections.abc import Callable
from enum import StrEnum, auto
from typing import Any

import asyncstdlib
import strawberry
Expand Down Expand Up @@ -39,7 +41,7 @@ def __init__(
super().__init__()

@asyncstdlib.cached_property
async def get_current_user(self) -> Optional[OIDCUserModel]:
async def get_current_user(self) -> OIDCUserModel | None:
"""Retrieve the OIDCUserModel once per graphql request.
Note:
Expand Down Expand Up @@ -118,8 +120,16 @@ async def is_authorized(info: OauthInfo, path: str) -> bool:
return authorized


class ErrorType(StrEnum):
"""Subset of the ErrorType enum in nwa-stdlib."""

NOT_AUTHENTICATED = auto()
NOT_AUTHORIZED = auto()


class IsAuthenticatedForQuery(BasePermission):
message = "User is not authenticated"
error_extensions = {"error_type": ErrorType.NOT_AUTHENTICATED}

async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool: # type: ignore
if not oauth2lib_settings.OAUTH2_ACTIVE:
Expand All @@ -135,6 +145,7 @@ async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool:

class IsAuthenticatedForMutation(BasePermission):
message = "User is not authenticated"
error_extensions = {"error_type": ErrorType.NOT_AUTHENTICATED}

async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool: # type: ignore
mutations_active = oauth2lib_settings.OAUTH2_ACTIVE and oauth2lib_settings.MUTATIONS_ENABLED
Expand All @@ -145,6 +156,8 @@ async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool:


class IsAuthorizedForQuery(BasePermission):
error_extensions = {"error_type": ErrorType.NOT_AUTHORIZED}

async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool: # type: ignore
if not (oauth2lib_settings.OAUTH2_ACTIVE and oauth2lib_settings.OAUTH2_AUTHORIZATION_ACTIVE):
logger.debug(
Expand All @@ -163,6 +176,8 @@ async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool:


class IsAuthorizedForMutation(BasePermission):
error_extensions = {"error_type": ErrorType.NOT_AUTHORIZED}

async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool: # type: ignore
mutations_active = (
oauth2lib_settings.OAUTH2_ACTIVE
Expand All @@ -182,9 +197,9 @@ async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool:

def authenticated_field(
description: str,
resolver: Union[StrawberryResolver, Callable, staticmethod, classmethod, None] = None,
deprecation_reason: Union[str, None] = None,
permission_classes: Union[list[type[BasePermission]], None] = None,
resolver: StrawberryResolver | Callable | staticmethod | classmethod | None = None,
deprecation_reason: str | None = None,
permission_classes: list[type[BasePermission]] | None = None,
) -> Any:
permissions = permission_classes if permission_classes else []
return strawberry.field(
Expand All @@ -197,9 +212,9 @@ def authenticated_field(

def authenticated_mutation_field(
description: str,
resolver: Union[StrawberryResolver, Callable, staticmethod, classmethod, None] = None,
deprecation_reason: Union[str, None] = None,
permission_classes: Union[list[type[BasePermission]], None] = None,
resolver: StrawberryResolver | Callable | staticmethod | classmethod | None = None,
deprecation_reason: str | None = None,
permission_classes: list[type[BasePermission]] | None = None,
) -> Any:
permissions = permission_classes if permission_classes else []
return strawberry.field(
Expand All @@ -212,10 +227,10 @@ def authenticated_mutation_field(

def authenticated_federated_field( # type: ignore
description: str,
resolver: Union[StrawberryResolver, Callable, staticmethod, classmethod, None] = None,
deprecation_reason: Union[str, None] = None,
requires: Union[list[str], None] = None,
permission_classes: Union[list[type[BasePermission]], None] = None,
resolver: StrawberryResolver | Callable | staticmethod | classmethod | None = None,
deprecation_reason: str | None = None,
requires: list[str] | None = None,
permission_classes: list[type[BasePermission]] | None = None,
**kwargs,
) -> Any:
permissions = permission_classes if permission_classes else []
Expand Down
19 changes: 12 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ classifiers = [
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.9",
]
requires = [
"requests>=2.19.0",
Expand All @@ -41,7 +39,7 @@ requires = [
"asyncstdlib",
]
description-file = "README.md"
requires-python = ">=3.9,<3.13"
requires-python = ">=3.11,<3.13"

[tool.flit.metadata.urls]
Documentation = "https://workfloworchestrator.org/"
Expand Down Expand Up @@ -102,6 +100,10 @@ exclude = [
"build",
".venv",
]
target-version = "py311"
line-length = 120

[tool.ruff.lint]
ignore = [
"B008",
"D100",
Expand All @@ -118,9 +120,9 @@ ignore = [
"B904",
"N802",
"N801",
"N818"
"N818",
"S113", # HTTPX has a default timeout
]
line-length = 120
select = [
"B",
"C",
Expand All @@ -134,7 +136,10 @@ select = [
"T",
"W",
]
target-version = "py310"

[tool.ruff.pydocstyle]
[tool.ruff.lint.flake8-tidy-imports]
ban-relative-imports = "all"


[tool.ruff.lint.pydocstyle]
convention = "google"
Loading

0 comments on commit cb592da

Please sign in to comment.