23
23
from httpx import AsyncClient , NetworkError
24
24
from pydantic import BaseModel
25
25
from starlette .requests import ClientDisconnect , HTTPConnection
26
+ from starlette .status import HTTP_403_FORBIDDEN
26
27
from starlette .websockets import WebSocket
27
28
from structlog import get_logger
28
29
@@ -137,22 +138,18 @@ class IdTokenExtractor(ABC):
137
138
"""
138
139
139
140
@abstractmethod
140
- async def extract (self , request : Request , auto_error : bool = True ) -> Optional [str ]:
141
+ async def extract (self , request : Request ) -> Optional [str ]:
141
142
pass
142
143
143
144
144
145
class HttpBearerExtractor (IdTokenExtractor ):
145
146
"""Extracts bearer tokens using FastAPI's HTTPBearer.
146
147
147
148
Specifically designed for HTTP Authorization header token extraction.
148
-
149
- By default, if an HTTP Bearer token is not provided in the `Authorization` header,
150
- the `extract` method will cancel the request and send an error unless `auto_error`
151
- is set to `False`, allowing optional or multiple authentication methods.
152
149
"""
153
150
154
- async def extract (self , request : Request , auto_error : bool = True ) -> Optional [str ]:
155
- http_bearer = HTTPBearer (auto_error = auto_error )
151
+ async def extract (self , request : Request ) -> Optional [str ]:
152
+ http_bearer = HTTPBearer (auto_error = False )
156
153
credential = await http_bearer (request )
157
154
158
155
return credential .credentials if credential else None
@@ -218,7 +215,11 @@ async def authenticate(self, request: HTTPConnection, token: Optional[str] = Non
218
215
return None
219
216
220
217
if token is None :
221
- token_or_extracted_id_token = await self .id_token_extractor .extract (request , auto_error = True ) or ""
218
+ extracted_id_token = await self .id_token_extractor .extract (request )
219
+ if not extracted_id_token :
220
+ raise HTTPException (status_code = HTTP_403_FORBIDDEN , detail = "Not authenticated" )
221
+
222
+ token_or_extracted_id_token = extracted_id_token
222
223
else :
223
224
token_or_extracted_id_token = token
224
225
@@ -262,7 +263,7 @@ class Authorization(ABC):
262
263
"""
263
264
264
265
@abstractmethod
265
- async def authorize (self , request : HTTPConnection , user : OIDCUserModel ) -> Optional [bool ]:
266
+ async def authorize (self , request : HTTPConnection , user : Optional [ OIDCUserModel ] = None ) -> Optional [bool ]:
266
267
pass
267
268
268
269
@@ -273,7 +274,7 @@ class GraphqlAuthorization(ABC):
273
274
"""
274
275
275
276
@abstractmethod
276
- async def authorize (self , request : RequestPath , user : OIDCUserModel ) -> Optional [bool ]:
277
+ async def authorize (self , request : RequestPath , user : Optional [ OIDCUserModel ] = None ) -> Optional [bool ]:
277
278
pass
278
279
279
280
@@ -323,7 +324,7 @@ class OPAAuthorization(Authorization, OPAMixin):
323
324
Uses OAUTH2 settings and request information to authorize actions.
324
325
"""
325
326
326
- async def authorize (self , request : HTTPConnection , user_info : OIDCUserModel ) -> Optional [bool ]:
327
+ async def authorize (self , request : HTTPConnection , user_info : Optional [ OIDCUserModel ] = None ) -> Optional [bool ]:
327
328
if not (oauth2lib_settings .OAUTH2_ACTIVE and oauth2lib_settings .OAUTH2_AUTHORIZATION_ACTIVE ):
328
329
return None
329
330
@@ -379,7 +380,7 @@ def __init__(self, opa_url: str, auto_error: bool = False, opa_kwargs: Union[Map
379
380
# By default don't raise HTTP 403 because partial results are preferred
380
381
super ().__init__ (opa_url , auto_error , opa_kwargs )
381
382
382
- async def authorize (self , request : RequestPath , user_info : OIDCUserModel ) -> Optional [bool ]:
383
+ async def authorize (self , request : RequestPath , user_info : Optional [ OIDCUserModel ] = None ) -> Optional [bool ]:
383
384
if not (oauth2lib_settings .OAUTH2_ACTIVE and oauth2lib_settings .OAUTH2_AUTHORIZATION_ACTIVE ):
384
385
return None
385
386
0 commit comments