Skip to content

Commit

Permalink
Add auto_error parameter and fix authorization bug
Browse files Browse the repository at this point in the history
- Added `auto_error` parameter to the `extract` method in `IdTokenExtractor`.
- Updated `HttpBearerExtractor` to use the `auto_error` parameter.
- Fixed bug where bypassable requests sent `None` to the `authorize` function, causing authorization failures.
- Made `authenticate` method more explicit about `id_token_extractor` and removed unnecessary line that returned `None`.
  • Loading branch information
torkashvandmt committed Jun 6, 2024
1 parent 79097ec commit e0bfb18
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions oauth2_lib/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,18 +137,22 @@ class IdTokenExtractor(ABC):
"""

@abstractmethod
async def extract(self, request: Request) -> Optional[str]:
async def extract(self, request: Request, auto_error: bool = True) -> Optional[str]:
pass


class HttpBearerExtractor(IdTokenExtractor):
"""Extracts bearer tokens using FastAPI's HTTPBearer.
Specifically designed for HTTP Authorization header token extraction.
By default, if an HTTP Bearer token is not provided in the `Authorization` header,
the `extract` method will cancel the request and send an error unless `auto_error`
is set to `False`, allowing optional or multiple authentication methods.
"""

async def extract(self, request: Request) -> Optional[str]:
http_bearer = HTTPBearer(auto_error=True)
async def extract(self, request: Request, auto_error: bool = True) -> Optional[str]:
http_bearer = HTTPBearer(auto_error=auto_error)
credential = await http_bearer(request)

return credential.credentials if credential else None
Expand Down Expand Up @@ -209,13 +213,12 @@ async def authenticate(self, request: HTTPConnection, token: Optional[str] = Non
token_or_extracted_id_token = token
else:
request = cast(Request, request)

if await self.is_bypassable_request(request):
return None

if token is None:
extracted_id_token = await self.id_token_extractor.extract(request)
if not extracted_id_token:
return None
token_or_extracted_id_token = extracted_id_token
token_or_extracted_id_token = await self.id_token_extractor.extract(request, auto_error=True)
else:
token_or_extracted_id_token = token

Expand Down Expand Up @@ -346,7 +349,7 @@ async def authorize(self, request: HTTPConnection, user_info: OIDCUserModel) ->
opa_input = {
"input": {
**(self.opa_kwargs or {}),
**user_info,
**(user_info or {}),
"resource": request.url.path,
"method": request_method,
"arguments": {"path": request.path_params, "query": {**request.query_params}, "json": json},
Expand Down Expand Up @@ -383,7 +386,7 @@ async def authorize(self, request: RequestPath, user_info: OIDCUserModel) -> Opt
opa_input = {
"input": {
**(self.opa_kwargs or {}),
**user_info,
**(user_info or {}),
"resource": request,
"method": "POST",
}
Expand Down

0 comments on commit e0bfb18

Please sign in to comment.