Skip to content

Commit 2ee2a42

Browse files
committed
Refactor to pass OAuth and token decode methods consistently
1 parent 12d9255 commit 2ee2a42

File tree

6 files changed

+24
-15
lines changed

6 files changed

+24
-15
lines changed

example_configs/external_service/custom.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any, Self
2+
13
import numpy
24

35
from tiled.adapters.array import ArrayAdapter
@@ -55,7 +57,7 @@ def __init__(self, base_url, metadata=None):
5557
self.client = MockClient(base_url)
5658
self.metadata = metadata
5759

58-
def with_session_state(self, state):
60+
def with_session_state(self, state: dict[str, Any]) -> Self:
5961
return AuthenticatedAdapter(self.client, state["token"], metadata=self.metadata)
6062

6163

tiled/authn_database/connection_pool.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from fastapi import Depends
22
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
33

4-
from ..server.settings import get_settings
4+
from ..server.settings import Settings, get_settings
55
from ..utils import ensure_specified_sql_driver
66

77
# A given process probably only has one of these at a time, but we
@@ -31,7 +31,7 @@ async def close_database_connection_pool(database_settings):
3131
await engine.dispose()
3232

3333

34-
async def get_database_engine(settings=Depends(get_settings)):
34+
async def get_database_engine(settings: Settings = Depends(get_settings)):
3535
# Special case for single-user mode
3636
if settings.database_uri is None:
3737
return None

tiled/server/app.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,18 +411,20 @@ async def unhandled_exception_handler(
411411
token_decoder, authenticators, oauth2_scheme
412412
)
413413
get_current_principal = current_principal_getter(
414+
token_decoder,
414415
authenticators,
415416
oauth2_scheme,
416-
token_decoder,
417417
)
418418

419419
# And add this authentication_router itself to the app.
420420
app.include_router(authentication_router, prefix="/api/v1/auth")
421+
get_session_state = session_state_getter(token_decoder, oauth2_scheme)
421422

422423
else:
423424
get_current_principal = get_current_principal_from_api_key
424425

425-
get_session_state = session_state_getter(token_decoder)
426+
def get_session_state():
427+
return None
426428

427429
app.include_router(
428430
get_router(

tiled/server/authentication.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -270,12 +270,13 @@ async def token_from_request(
270270

271271
def session_state_getter(
272272
token_decoder: Callable[[str], Awaitable[Optional[dict[str, Any]]]],
273+
oauth2_scheme: OAuth2,
273274
):
274-
async def get_session_state(
275-
decoded_access_token: Optional[dict[str, Any]] = Depends(token_decoder)
276-
):
277-
if decoded_access_token:
278-
return decoded_access_token.get("state")
275+
async def get_session_state(access_token: Optional[str] = Depends(oauth2_scheme)):
276+
if access_token:
277+
decoded_access_token = await token_decoder(access_token)
278+
if decoded_access_token:
279+
return decoded_access_token.get("state")
279280

280281
return get_session_state
281282

tiled/server/dependencies.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Tuple, Union
1+
from typing import Any, Callable, Mapping, Optional, Tuple, Union
22

33
from fastapi import Depends, HTTPException, Query, Request, Security
44
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
@@ -13,13 +13,17 @@
1313
SLICE_REGEX = rf"^{DIM_REGEX}(?:,{DIM_REGEX})*$"
1414

1515

16-
def SecureEntryBuilder(get_current_principal, tree, get_session_state):
16+
def SecureEntryBuilder(
17+
tree: Mapping[str, Any],
18+
get_current_principal: Callable[..., Optional[str]],
19+
get_session_state: Callable[..., Optional[dict[str, Any]]],
20+
):
1721
def SecureEntry(scopes, structure_families=None):
1822
async def inner(
1923
path: str,
2024
request: Request,
21-
principal: str = Depends(get_current_principal),
22-
session_state: dict = Depends(get_session_state),
25+
principal: Optional[str] = Depends(get_current_principal),
26+
session_state: Optional[dict[str, Any]] = Depends(get_session_state),
2327
):
2428
"""
2529
Obtain a node in the tree from its path.

tiled/server/router.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def get_router(
7777
validation_registry: ValidationRegistry,
7878
) -> APIRouter:
7979
router = APIRouter()
80-
SecureEntry = SecureEntryBuilder(get_current_principal, tree, get_session_state)
80+
SecureEntry = SecureEntryBuilder(tree, get_current_principal, get_session_state)
8181

8282
@router.get("/", response_model=About)
8383
async def about(

0 commit comments

Comments
 (0)