Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
e80b535
Add Authenticator for when serving OIDC as a proxy
DiamondJoseph Feb 19, 2025
0589387
Refactor Authentication router construction
DiamondJoseph Feb 19, 2025
939cdb3
Invert the creation of API routes
DiamondJoseph Feb 19, 2025
fd7be50
Remove incorrect caches
DiamondJoseph Feb 19, 2025
3743679
Linting, minimising code changes
DiamondJoseph Feb 19, 2025
018d72f
Fix linting for Python 3.9
DiamondJoseph Feb 19, 2025
b50d216
Rename default registries
DiamondJoseph Feb 19, 2025
4b88a7b
Name consistent and move scope handling
DiamondJoseph Feb 19, 2025
ed75f3e
Pass first authenticator to make decode route
DiamondJoseph Feb 19, 2025
471b797
Remove unused method params
DiamondJoseph Feb 19, 2025
c279181
types
DiamondJoseph Feb 19, 2025
a1669a0
Split decode and fetch token
DiamondJoseph Feb 19, 2025
7b20ba8
Refactor to pass OAuth and token decode methods consistently
DiamondJoseph Feb 20, 2025
a0bea99
Await awaitables
DiamondJoseph Feb 20, 2025
f511991
Restore route signature patching with Query registry
DiamondJoseph Feb 20, 2025
2cb5268
nit: reorder args
DiamondJoseph Feb 20, 2025
f6f9a10
Move definition of routes internally and remove use of non 3.9 Typing…
DiamondJoseph Feb 20, 2025
3ca1344
Fix misplaced prefix.
danielballan Mar 6, 2025
765bdae
Update suggested usage in comment.
danielballan Mar 7, 2025
f3cc565
Fix annotations
danielballan Mar 7, 2025
08b153c
Expired token should send 401 to prompt refresh
danielballan Mar 7, 2025
58cf84f
Stash authn database engine on app.state, not global dict.
danielballan Mar 8, 2025
e972d24
More type hints
danielballan Mar 8, 2025
d5cd87b
Oops
danielballan Mar 8, 2025
8378221
Rename decode_token
DiamondJoseph Mar 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion example_configs/external_service/custom.py
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some type hinting for my benefit

Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any, Self

import numpy

from tiled.adapters.array import ArrayAdapter
Expand Down Expand Up @@ -55,7 +57,7 @@ def __init__(self, base_url, metadata=None):
self.client = MockClient(base_url)
self.metadata = metadata

def with_session_state(self, state):
def with_session_state(self, state: dict[str, Any]) -> Self:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From a design perspective, it's just expected to still be an Adapter (?)

return AuthenticatedAdapter(self.client, state["token"], metadata=self.metadata)


Expand Down
2 changes: 1 addition & 1 deletion tiled/_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def poll_enumerate():
# docker run --name tiled-test-postgres -p 5432:5432 -e POSTGRES_PASSWORD=secret -d docker.io/postgres:16
# and set this env var like:
#
# TILED_TEST_POSTGRESQL_URI=postgresql+asyncpg://postgres:secret@localhost:5432
# TILED_TEST_POSTGRESQL_URI=postgresql://postgres:secret@localhost:5432

TILED_TEST_POSTGRESQL_URI = os.getenv("TILED_TEST_POSTGRESQL_URI")

Expand Down
60 changes: 49 additions & 11 deletions tiled/authenticators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import re
import secrets
from collections.abc import Iterable
from typing import Any, Mapping, Optional, cast
from typing import Any, Callable, Mapping, Optional, cast

import httpx
from fastapi import APIRouter, Request
from fastapi.security import OAuth2AuthorizationCodeBearer
from jose import JWTError, jwt
from pydantic import Secret
from starlette.responses import RedirectResponse
Expand Down Expand Up @@ -181,6 +182,16 @@ def authorization_endpoint(self) -> httpx.URL:
cast(str, self._config_from_oidc_url.get("authorization_endpoint"))
)

async def decode_token(self, access_token: str) -> dict[str, Any]:
keys = httpx.get(self.jwks_uri).raise_for_status().json().get("keys", [])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, use a TTL cache to ensure we do not hammer the AuthN provider

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(for the result of httpx.get(self.jwks_uri))

return jwt.decode(
token=access_token,
key=keys,
algorithms=self.id_token_signing_alg_values_supported,
audience=self._audience,
access_token=access_token,
)

async def authenticate(self, request: Request) -> Optional[UserSessionState]:
code = request.query_params["code"]
# A proxy in the middle may make the request into something like
Expand All @@ -199,24 +210,51 @@ async def authenticate(self, request: Request) -> Optional[UserSessionState]:
logger.error("Authentication error: %r", response_body)
return None
response_body = response.json()
id_token = response_body["id_token"]
access_token = response_body["access_token"]
keys = httpx.get(self.jwks_uri).raise_for_status().json().get("keys", [])
try:
verified_body = jwt.decode(
token=id_token,
key=keys,
algorithms=self.id_token_signing_alg_values_supported,
audience=self._audience,
access_token=access_token,
verified_body = await self.decode_token(access_token)
return UserSessionState(verified_body["sub"], {})

except JWTError:
logger.exception(
"Authentication error. Unverified token: %r",
jwt.get_unverified_claims(access_token),
)
return None


class ProxiedOIDCAuthenticator(OIDCAuthenticator):
def __init__(
self,
audience: str,
client_id: str,
client_secret: str,
well_known_uri: str,
confirmation_message: str = "",
):
super().__init__(
audience, client_id, client_secret, well_known_uri, confirmation_message
)
self._oidc_bearer = OAuth2AuthorizationCodeBearer(
authorizationUrl=self.authorization_endpoint, tokenUrl=self.token_endpoint
)

@property
def oauth2_scheme(self) -> Callable[[Request], str]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def oauth2_scheme(self) -> Callable[[Request], str]:
def oauth2_scheme(self) -> OAuth2:

return self._oidc_bearer

async def authenticate(self, request: Request) -> Optional[UserSessionState]:
access_token = self._oidc_bearer(request)
try:
verified_body = await self.decode_token(access_token)
return UserSessionState(verified_body["sub"], {})

except JWTError:
logger.exception(
"Authentication error. Unverified token: %r",
jwt.get_unverified_claims(id_token),
jwt.get_unverified_claims(access_token),
)
return None
return UserSessionState(verified_body["sub"], {})


async def exchange_code(
Expand Down
36 changes: 13 additions & 23 deletions tiled/authn_database/connection_pool.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from typing import Optional

from ..server.settings import get_settings
from ..utils import ensure_specified_sql_driver
from fastapi import Depends, Request
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine

# A given process probably only has one of these at a time, but we
# key on database_settings just case in some testing context or something
# we have two servers running in the same process.
_connection_pools = {}
from ..server.settings import DatabaseSettings
from ..utils import ensure_specified_sql_driver


def open_database_connection_pool(database_settings):
def open_database_connection_pool(database_settings: DatabaseSettings) -> AsyncEngine:
connect_args = {}
kwargs = {} # extra kwargs passed to create_engine
# kwargs["pool_size"] = database_settings.pool_size
Expand All @@ -21,29 +18,22 @@ def open_database_connection_pool(database_settings):
connect_args=connect_args,
**kwargs,
)
_connection_pools[database_settings] = engine
return engine


async def close_database_connection_pool(database_settings):
engine = _connection_pools.pop(database_settings, None)
async def close_database_connection_pool(engine: AsyncEngine) -> None:
if engine is not None:
await engine.dispose()


async def get_database_engine(settings=Depends(get_settings)):
# Special case for single-user mode
if settings.database_uri is None:
return None
try:
return _connection_pools[settings.database_settings]
except KeyError:
raise RuntimeError(
f"Could not find connection pool for {settings.database_settings}"
)
async def get_database_engine(request: Request) -> Optional[AsyncEngine]:
"Return engine if multi-user server, None is single-user server."
return request.app.state.authn_database_engine


async def get_database_session(engine=Depends(get_database_engine)):
async def get_database_session(
engine: AsyncEngine = Depends(get_database_engine),
) -> Optional[AsyncSession]:
# Special case for single-user mode
if engine is None:
yield None
Expand Down
4 changes: 2 additions & 2 deletions tiled/client/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ..adapters.utils import IndexersMixin
from ..iterviews import ItemsView, KeysView, ValuesView
from ..queries import KeyLookup
from ..query_registration import query_registry
from ..query_registration import default_query_registry
from ..structures.core import Spec, StructureFamily
from ..structures.data_source import DataSource
from ..utils import UNCHANGED, OneShotCachedMap, Sentinel, node_repr, safe_json_dump
Expand Down Expand Up @@ -1055,7 +1055,7 @@ def _queries_to_params(*queries):
"Compute GET params from the queries."
params = collections.defaultdict(list)
for query in queries:
name = query_registry.query_type_to_name[type(query)]
name = default_query_registry.query_type_to_name[type(query)]
for field, value in query.encode().items():
if value is not None:
params[f"filter[{name}][condition][{field}]"].append(value)
Expand Down
3 changes: 1 addition & 2 deletions tiled/client/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,8 +454,7 @@ def from_app(
# Extract the API key from the app and set it.
from ..server.settings import get_settings

settings = app.dependency_overrides[get_settings]()
api_key = settings.single_user_api_key or None
api_key = get_settings().single_user_api_key
else:
# This is a multi-user server but no API key was passed,
# so we will leave it as None on the Context.
Expand Down
10 changes: 4 additions & 6 deletions tiled/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,12 @@

from .adapters.mapping import MapAdapter
from .media_type_registration import (
compression_registry as default_compression_registry,
default_compression_registry,
default_serialization_registry,
)
from .media_type_registration import (
serialization_registry as default_serialization_registry,
)
from .query_registration import query_registry as default_query_registry
from .query_registration import default_query_registry
from .utils import import_object, parse, prepend_to_sys_path
from .validation_registration import validation_registry as default_validation_registry
from .validation_registration import default_validation_registry


@cache
Expand Down
16 changes: 8 additions & 8 deletions tiled/media_type_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,21 +197,21 @@ def __call__(self, media_type, encoder, *args, **kwargs):
return self.dispatch(media_type, encoder)(*args, **kwargs)


serialization_registry = SerializationRegistry()
default_serialization_registry = SerializationRegistry()
"Global serialization registry. See Registry for usage examples."

deserialization_registry = SerializationRegistry()
default_deserialization_registry = SerializationRegistry()
"Global deserialization registry. See Registry for usage examples."

compression_registry = CompressionRegistry()
default_compression_registry = CompressionRegistry()
"Global compression registry. See Registry for usage examples."


for media_type in [
"application/json",
"application/x-msgpack",
]:
compression_registry.register(
default_compression_registry.register(
media_type,
"gzip",
lambda buffer: gzip.GzipFile(mode="wb", fileobj=buffer, compresslevel=9),
Expand All @@ -225,7 +225,7 @@ def __call__(self, media_type, encoder, *args, **kwargs):
"text/plain",
"text/html",
]:
compression_registry.register(
default_compression_registry.register(
media_type,
"gzip",
# Use a lower compression level. High compression is extremely slow
Expand Down Expand Up @@ -270,7 +270,7 @@ def close(self):
"text/html",
"text/plain",
]:
compression_registry.register(media_type, "zstd", ZstdBuffer)
default_compression_registry.register(media_type, "zstd", ZstdBuffer)

if modules_available("lz4"):
import lz4
Expand Down Expand Up @@ -326,7 +326,7 @@ def close(self):
"text/html",
"text/plain",
]:
compression_registry.register(media_type, "lz4", LZ4Buffer)
default_compression_registry.register(media_type, "lz4", LZ4Buffer)

if modules_available("blosc2"):
import blosc2
Expand Down Expand Up @@ -355,4 +355,4 @@ def close(self):
pass

for media_type in ["application/octet-stream", APACHE_ARROW_FILE_MIME_TYPE]:
compression_registry.register(media_type, "blosc2", BloscBuffer)
default_compression_registry.register(media_type, "blosc2", BloscBuffer)
4 changes: 2 additions & 2 deletions tiled/query_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def inner(cls):


# Make a global registry.
query_registry = QueryRegistry()
register = query_registry.register
default_query_registry = QueryRegistry()
register = default_query_registry.register
"""Register a new type of query."""


Expand Down
39 changes: 25 additions & 14 deletions tiled/serialization/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@

import numpy

from ..media_type_registration import deserialization_registry, serialization_registry
from ..media_type_registration import (
default_deserialization_registry,
default_serialization_registry,
)
from ..utils import (
SerializationError,
UnsupportedShape,
Expand All @@ -22,13 +25,13 @@ def as_buffer(array, metadata):
return numpy.asarray(array).tobytes()


serialization_registry.register(
default_serialization_registry.register(
"array",
"application/octet-stream",
as_buffer,
)
if modules_available("orjson"):
serialization_registry.register(
default_serialization_registry.register(
"array",
"application/json",
lambda array, metadata: safe_json_dump(array),
Expand All @@ -43,10 +46,12 @@ def serialize_csv(array, metadata):
return file.getvalue().encode()


serialization_registry.register("array", "text/csv", serialize_csv)
serialization_registry.register("array", "text/x-comma-separated-values", serialize_csv)
serialization_registry.register("array", "text/plain", serialize_csv)
deserialization_registry.register(
default_serialization_registry.register("array", "text/csv", serialize_csv)
default_serialization_registry.register(
"array", "text/x-comma-separated-values", serialize_csv
)
default_serialization_registry.register("array", "text/plain", serialize_csv)
default_deserialization_registry.register(
"array",
"application/octet-stream",
lambda buffer, dtype, shape: numpy.frombuffer(buffer, dtype=dtype).reshape(shape),
Expand Down Expand Up @@ -90,10 +95,10 @@ def array_from_buffer_PIL(buffer, format, dtype, shape):
image = Image.open(file, format=format)
return numpy.asarray(image).asdtype(dtype).reshape(shape)

serialization_registry.register(
default_serialization_registry.register(
"array", "image/png", lambda array, metadata: save_to_buffer_PIL(array, "png")
)
deserialization_registry.register(
default_deserialization_registry.register(
"array",
"image/png",
lambda buffer, dtype, shape: array_from_buffer_PIL(buffer, "png", dtype, shape),
Expand All @@ -120,18 +125,24 @@ def save_to_buffer_tifffile(array, metadata):
imwrite(file, normalized_array)
return file.getbuffer()

serialization_registry.register("array", "image/tiff", save_to_buffer_tifffile)
deserialization_registry.register("array", "image/tiff", array_from_buffer_tifffile)
default_serialization_registry.register(
"array", "image/tiff", save_to_buffer_tifffile
)
default_deserialization_registry.register(
"array", "image/tiff", array_from_buffer_tifffile
)


def serialize_html(array, metadata):
"Try to display as image. Fall back to CSV."
try:
png_data = serialization_registry.dispatch("array", "image/png")(
png_data = default_serialization_registry.dispatch("array", "image/png")(
array, metadata
)
except Exception:
csv_data = serialization_registry.dispatch("array", "text/csv")(array, metadata)
csv_data = default_serialization_registry.dispatch("array", "text/csv")(
array, metadata
)
return "<html>" "<body>" f"{csv_data.decode()!s}" "</body>" "</html>"
else:
return (
Expand All @@ -145,4 +156,4 @@ def serialize_html(array, metadata):
)


serialization_registry.register("array", "text/html", serialize_html)
default_serialization_registry.register("array", "text/html", serialize_html)
Loading
Loading