diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 8f3f6f5..a0c0178 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 2.0.1 +current_version = 2.0.2 commit = False tag = False parse = (?P\d+)\.(?P\d+)\.(?P\d+)(\-(?P[a-z]+)(?P\d+))? diff --git a/oauth2_lib/__init__.py b/oauth2_lib/__init__.py index 92b1936..42202ab 100644 --- a/oauth2_lib/__init__.py +++ b/oauth2_lib/__init__.py @@ -13,4 +13,4 @@ """This is the SURF Oauth2 module that interfaces with the oauth2 setup.""" -__version__ = "2.0.1" +__version__ = "2.0.2" diff --git a/oauth2_lib/fastapi.py b/oauth2_lib/fastapi.py index 5b72483..bcf3262 100644 --- a/oauth2_lib/fastapi.py +++ b/oauth2_lib/fastapi.py @@ -23,6 +23,7 @@ from httpx import AsyncClient, NetworkError from pydantic import BaseModel from starlette.requests import ClientDisconnect, HTTPConnection +from starlette.status import HTTP_403_FORBIDDEN from starlette.websockets import WebSocket from structlog import get_logger @@ -148,7 +149,7 @@ class HttpBearerExtractor(IdTokenExtractor): """ async def extract(self, request: Request) -> Optional[str]: - http_bearer = HTTPBearer(auto_error=True) + http_bearer = HTTPBearer(auto_error=False) credential = await http_bearer(request) return credential.credentials if credential else None @@ -209,12 +210,15 @@ async def authenticate(self, request: HTTPConnection, token: Optional[str] = Non token_or_extracted_id_token = token else: request = cast(Request, request) + if await self.is_bypassable_request(request): return None + if token is None: extracted_id_token = await self.id_token_extractor.extract(request) if not extracted_id_token: - return None + raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Not authenticated") + token_or_extracted_id_token = extracted_id_token else: token_or_extracted_id_token = token @@ -346,7 +350,7 @@ async def authorize(self, request: HTTPConnection, user_info: OIDCUserModel) -> opa_input = { "input": { **(self.opa_kwargs or {}), - **user_info, + **(user_info or {}), "resource": request.url.path, "method": request_method, "arguments": {"path": request.path_params, "query": {**request.query_params}, "json": json}, @@ -383,7 +387,7 @@ async def authorize(self, request: RequestPath, user_info: OIDCUserModel) -> Opt opa_input = { "input": { **(self.opa_kwargs or {}), - **user_info, + **(user_info or {}), "resource": request, "method": "POST", } diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index a444d39..ed1eb7c 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -149,13 +149,11 @@ async def test_extract_token_success(): @pytest.mark.asyncio -async def test_extract_token_failure(): +async def test_extract_token_returns_none(): request = mock.MagicMock() request.headers = {} extractor = HttpBearerExtractor() - with pytest.raises(HTTPException) as exc_info: - await extractor.extract(request) - assert exc_info.value.status_code == 403, "Expected HTTP 403 error for missing token" + assert await extractor.extract(request) is None @pytest.mark.asyncio