diff --git a/lambda/pyproject.toml b/lambda/pyproject.toml index edd6bee6..6a2ecf6c 100644 --- a/lambda/pyproject.toml +++ b/lambda/pyproject.toml @@ -6,7 +6,7 @@ dependencies = [ "aws-lambda-powertools==3.25.0", "boto3==1.42.57", "cryptography==46.0.5", - "dataclasses-json==0.6.7", + "pydantic==2.12.5", "mohawk==1.1.0", "PyJWT==2.11.0", "requests==2.32.5" diff --git a/lambda/src/environment/service_provider.py b/lambda/src/environment/service_provider.py index 36c3249a..d2d91df3 100644 --- a/lambda/src/environment/service_provider.py +++ b/lambda/src/environment/service_provider.py @@ -5,7 +5,11 @@ import boto3 from aws_lambda_powertools.event_handler import CORSConfig, Response +from aws_lambda_powertools.logging import Logger +from src.middlewares.hawk_auth import HawkAuthenticationError, HawkAuthMiddleware, UidMismatchError +from src.middlewares.request_logging import RequestLoggingMiddleware +from src.middlewares.weave_timestamp import WeaveTimestampMiddleware from src.routes.auth.account_attached_clients import AccountAttachedClientsRoute from src.routes.auth.account_create import AccountCreateRoute from src.routes.auth.account_device import AccountDeviceRoute @@ -40,14 +44,7 @@ from src.routes.storage.delete_all import DeleteAllStorageRoute from src.routes.storage.delete_root import DeleteAllRootRoute from src.routes.token.request import GetTokenRoute -from src.services.api_router import ( - ApiRouter, - HawkAuthenticationError, - HawkAuthMiddleware, - RequestLoggingMiddleware, - UidMismatchError, - WeaveTimestampMiddleware, -) +from src.services.api_router import ApiRouter from src.services.auth_account_manager import AuthAccountManager from src.services.channel_service import ChannelService from src.services.device_manager import DeviceManager @@ -88,6 +85,9 @@ def wrapper(event, context, service_provider=None): return wrapper +logger = Logger() + + class ServiceProvider: @cached_property def aws_region(self): # pragma: nocover @@ -158,7 +158,9 @@ def handle_hawk_auth(ex): body=json.dumps({"code": 401, "errno": 110, "message": str(ex)}), ) - return {HawkAuthenticationError: handle_hawk_auth} + return { + HawkAuthenticationError: handle_hawk_auth, + } @cached_property def storage_api_router(self): @@ -186,6 +188,7 @@ def storage_api_router(self): WeaveTimestampMiddleware(), ], exception_handlers=self._storage_exception_handlers, + enable_validation=True, ) # Token API properties @@ -380,6 +383,7 @@ def auth_api_router(self): middlewares=[WeaveTimestampMiddleware()], cors=self.cors_config, exception_handlers=self._auth_exception_handlers, + enable_validation=True, ) @cached_property @@ -396,6 +400,7 @@ def token_api_router(self): ], middlewares=[WeaveTimestampMiddleware()], cors=self.cors_config, + enable_validation=True, ) @cached_property @@ -410,6 +415,7 @@ def profile_api_router(self): ], middlewares=[WeaveTimestampMiddleware()], cors=self.cors_config, + enable_validation=True, ) # HAWK Authorizer properties diff --git a/lambda/src/routes/auth/account_attached_clients.py b/lambda/src/routes/auth/account_attached_clients.py index c1511f09..9f8f2415 100644 --- a/lambda/src/routes/auth/account_attached_clients.py +++ b/lambda/src/routes/auth/account_attached_clients.py @@ -7,7 +7,7 @@ from src.services.device_manager import DeviceManager from src.shared.base_route import BaseRoute -from src.shared.utils import json_dumps +from src.shared.models import AttachedClientOutput, ClientListAdapter class AccountAttachedClientsRoute(BaseRoute): @@ -33,26 +33,20 @@ def handle(self, event) -> Response: devices = self._device_manager.get_devices(uid) clients = [] for d in devices: - clients.append( - { - "clientId": None, - "deviceId": d.get("id"), - "sessionTokenId": d.get("sessionTokenId"), - "refreshTokenId": None, - "isCurrentSession": d.get("sessionTokenId") == session_token_id, - "deviceType": d.get("type"), - "name": d.get("name"), - "createdTime": d.get("createdAt"), - "lastAccessTime": d.get("lastAccessTime"), - "scope": None, - "location": {}, - "userAgent": "", - "os": None, - } + client = AttachedClientOutput( + device_id=d.get("id"), + session_token_id=d.get("sessionTokenId"), + is_current_session=d.get("sessionTokenId") == session_token_id, + device_type=d.get("type"), + name=d.get("name"), + created_time=d.get("createdAt"), + last_access_time=d.get("lastAccessTime"), + user_agent="", ) + clients.append(client) return Response( status_code=200, content_type="application/json", - body=json_dumps(clients), + body=ClientListAdapter.dump_json(clients, by_alias=True).decode(), ) diff --git a/lambda/src/routes/auth/account_create.py b/lambda/src/routes/auth/account_create.py index 93b82512..ad678b34 100644 --- a/lambda/src/routes/auth/account_create.py +++ b/lambda/src/routes/auth/account_create.py @@ -11,6 +11,7 @@ from src.services.fxa_token_manager import FxATokenManager from src.services.oidc_validator import OIDCValidator from src.shared.base_route import BaseRoute +from src.shared.models import AccountCreateOutput BEARER_PATTERN = re.compile(r"^Bearer\s+(.+)$", re.IGNORECASE) AUTH_PW_PATTERN = re.compile(r"^[0-9a-f]{64}$") @@ -102,17 +103,16 @@ def handle(self, event) -> Response: session_token = self._token_manager.create_session_token(uid) key_fetch_token = self._token_manager.create_key_fetch_token(uid) + result = AccountCreateOutput( + uid=uid, + session_token=session_token.hex(), + key_fetch_token=key_fetch_token.hex(), + verified=True, + ) return Response( status_code=200, content_type="application/json", - body=json.dumps( - { - "uid": uid, - "sessionToken": session_token.hex(), - "keyFetchToken": key_fetch_token.hex(), - "verified": True, - } - ), + body=result.model_dump_json(by_alias=True), ) @staticmethod diff --git a/lambda/src/routes/auth/account_device.py b/lambda/src/routes/auth/account_device.py index 489a3a32..6928bd27 100644 --- a/lambda/src/routes/auth/account_device.py +++ b/lambda/src/routes/auth/account_device.py @@ -8,7 +8,7 @@ from src.services.device_manager import DeviceManager from src.shared.base_route import BaseRoute -from src.shared.utils import json_dumps +from src.shared.models import DeviceOutput class AccountDeviceRoute(BaseRoute): @@ -32,9 +32,9 @@ def handle(self, event) -> Response: session_token_id = event["requestContext"].get("hawk_token_id", "") body = json.loads(event.body or "{}") device = self._device_manager.upsert_device(uid, session_token_id, body) - + result = DeviceOutput.model_validate(device) return Response( status_code=200, content_type="application/json", - body=json_dumps(device), + body=result.model_dump_json(by_alias=True), ) diff --git a/lambda/src/routes/auth/account_devices.py b/lambda/src/routes/auth/account_devices.py index ac0b0092..c69671d3 100644 --- a/lambda/src/routes/auth/account_devices.py +++ b/lambda/src/routes/auth/account_devices.py @@ -7,7 +7,7 @@ from src.services.device_manager import DeviceManager from src.shared.base_route import BaseRoute -from src.shared.utils import json_dumps +from src.shared.models import DeviceListAdapter, DeviceOutput class AccountDevicesRoute(BaseRoute): @@ -35,11 +35,14 @@ def handle(self, event) -> Response: filter_idle = int(filter_ts) if filter_ts else None devices = self._device_manager.get_devices(uid, filter_idle) + results = [] for d in devices: - d["isCurrentDevice"] = d.get("sessionTokenId") == session_token_id + device = DeviceOutput.model_validate(d) + device.is_current_device = d.get("sessionTokenId") == session_token_id + results.append(device) return Response( status_code=200, content_type="application/json", - body=json_dumps(devices), + body=DeviceListAdapter.dump_json(results, by_alias=True).decode(), ) diff --git a/lambda/src/routes/auth/account_keys.py b/lambda/src/routes/auth/account_keys.py index 603f92df..214a615e 100644 --- a/lambda/src/routes/auth/account_keys.py +++ b/lambda/src/routes/auth/account_keys.py @@ -8,6 +8,7 @@ from src.services.fxa_crypto import derive_key_request_key, encrypt_key_bundle from src.services.fxa_token_manager import KEY_FETCH_TOKEN_INFO, FxATokenManager from src.shared.base_route import BaseRoute +from src.shared.models import AccountKeysOutput from src.shared.utils import extract_hawk_request_params @@ -58,10 +59,11 @@ def handle(self, event) -> Response: # Encrypt key bundle bundle = encrypt_key_bundle(key_request_key, k_a, wrap_kb) + result = AccountKeysOutput(bundle=bundle.hex()) return Response( status_code=200, content_type="application/json", - body=json.dumps({"bundle": bundle.hex()}), + body=result.model_dump_json(), ) @staticmethod diff --git a/lambda/src/routes/auth/account_login.py b/lambda/src/routes/auth/account_login.py index b678180b..fc1f3fe7 100644 --- a/lambda/src/routes/auth/account_login.py +++ b/lambda/src/routes/auth/account_login.py @@ -9,6 +9,7 @@ from src.services.fxa_crypto import constant_time_compare, derive_verify_hash from src.services.fxa_token_manager import FxATokenManager from src.shared.base_route import BaseRoute +from src.shared.models import AccountLoginOutput AUTH_PW_PATTERN = re.compile(r"^[0-9a-f]{64}$") @@ -74,22 +75,22 @@ def handle(self, event) -> Response: # Create session token (always) session_token = self._token_manager.create_session_token(uid) - result: dict = { - "uid": uid, - "sessionToken": session_token.hex(), - "verified": True, - } + result = AccountLoginOutput( + uid=uid, + session_token=session_token.hex(), + verified=True, + ) # Create key-fetch token if keys=true params = event.query_string_parameters or {} if params.get("keys") == "true": key_fetch_token = self._token_manager.create_key_fetch_token(uid) - result["keyFetchToken"] = key_fetch_token.hex() + result.key_fetch_token = key_fetch_token.hex() return Response( status_code=200, content_type="application/json", - body=json.dumps(result), + body=result.model_dump_json(by_alias=True, exclude_none=True), ) @staticmethod diff --git a/lambda/src/routes/auth/account_status.py b/lambda/src/routes/auth/account_status.py index 9a162292..aaa0585f 100644 --- a/lambda/src/routes/auth/account_status.py +++ b/lambda/src/routes/auth/account_status.py @@ -6,6 +6,7 @@ from src.services.auth_account_manager import AuthAccountManager from src.shared.base_route import BaseRoute +from src.shared.models import AccountStatusOutput class AccountStatusRoute(BaseRoute): @@ -30,8 +31,9 @@ def handle(self, event) -> Response: ) account = self._account_manager.get_account_by_email(email) + result = AccountStatusOutput(exists=account is not None) return Response( status_code=200, content_type="application/json", - body=json.dumps({"exists": account is not None}), + body=result.model_dump_json(), ) diff --git a/lambda/src/routes/auth/oauth_authorization.py b/lambda/src/routes/auth/oauth_authorization.py index 26dce762..2f30c8ab 100644 --- a/lambda/src/routes/auth/oauth_authorization.py +++ b/lambda/src/routes/auth/oauth_authorization.py @@ -8,6 +8,7 @@ from src.services.oauth_code_manager import OAuthCodeManager from src.shared.base_route import BaseRoute +from src.shared.models import OAuthAuthorizationOutput ALLOWED_REDIRECT_URIS = { "urn:ietf:wg:oauth:2.0:oob", @@ -74,16 +75,15 @@ def handle(self, event) -> Response: keys_jwe=keys_jwe, ) + result = OAuthAuthorizationOutput( + code=code, + state=state, + redirect=redirect_uri, + ) return Response( status_code=200, content_type="application/json", - body=json.dumps( - { - "code": code, - "state": state, - "redirect": redirect_uri, - } - ), + body=result.model_dump_json(), ) @staticmethod diff --git a/lambda/src/routes/auth/oauth_token.py b/lambda/src/routes/auth/oauth_token.py index a2f76600..3967fefc 100644 --- a/lambda/src/routes/auth/oauth_token.py +++ b/lambda/src/routes/auth/oauth_token.py @@ -11,6 +11,7 @@ from src.services.jwt_service import JWTService from src.services.oauth_code_manager import OAuthCodeManager from src.shared.base_route import BaseRoute +from src.shared.models import OAuthTokenOutput from src.shared.utils import extract_hawk_request_params DEFAULT_TTL = 900 # 15 minutes @@ -111,23 +112,22 @@ def _handle_authorization_code(self, body: dict) -> Response: scope=scope, ) - response_body: dict = { - "access_token": access_token, - "token_type": "bearer", - "expires_in": ttl, - "scope": scope, - "refresh_token": refresh_token, - "auth_at": int(time.time()), - } + result = OAuthTokenOutput( + access_token=access_token, + expires_in=ttl, + scope=scope, + refresh_token=refresh_token, + auth_at=int(time.time()), + ) keys_jwe = code_data.get("keysJwe", "") if keys_jwe: - response_body["keys_jwe"] = keys_jwe + result.keys_jwe = keys_jwe return Response( status_code=200, content_type="application/json", - body=json.dumps(response_body), + body=result.model_dump_json(exclude_none=True), ) def _handle_refresh_token(self, body: dict) -> Response: @@ -173,19 +173,17 @@ def _handle_refresh_token(self, body: dict) -> Response: scope=scope, ) + result = OAuthTokenOutput( + access_token=access_token, + expires_in=ttl, + scope=scope, + refresh_token=new_refresh_token, + auth_at=int(time.time()), + ) return Response( status_code=200, content_type="application/json", - body=json.dumps( - { - "access_token": access_token, - "token_type": "bearer", - "expires_in": ttl, - "scope": scope, - "refresh_token": new_refresh_token, - "auth_at": int(time.time()), - } - ), + body=result.model_dump_json(exclude_none=True), ) def _handle_fxa_credentials(self, event, body: dict) -> Response: @@ -221,18 +219,16 @@ def _handle_fxa_credentials(self, event, body: dict) -> Response: fxa_uid=uid, ) + result = OAuthTokenOutput( + access_token=access_token, + expires_in=ttl, + scope=scope, + auth_at=int(time.time()), + ) return Response( status_code=200, content_type="application/json", - body=json.dumps( - { - "access_token": access_token, - "token_type": "bearer", - "expires_in": ttl, - "scope": scope, - "auth_at": int(time.time()), - } - ), + body=result.model_dump_json(exclude_none=True), ) @staticmethod diff --git a/lambda/src/routes/auth/oidc_exchange.py b/lambda/src/routes/auth/oidc_exchange.py index 5c7fbd2a..a0c67167 100644 --- a/lambda/src/routes/auth/oidc_exchange.py +++ b/lambda/src/routes/auth/oidc_exchange.py @@ -9,6 +9,7 @@ from src.services.auth_account_manager import AuthAccountManager from src.services.oidc_validator import OIDCValidator from src.shared.base_route import BaseRoute +from src.shared.models import OIDCExchangeOutput logger = Logger() @@ -190,14 +191,13 @@ def handle(self, event) -> Response: # 5. Check if account exists account = self._account_manager.get_account_by_email(email) + result = OIDCExchangeOutput( + email=email, + access_token=access_token, + account_exists=account is not None, + ) return Response( status_code=200, content_type="application/json", - body=json.dumps( - { - "email": email, - "access_token": access_token, - "account_exists": account is not None, - } - ), + body=result.model_dump_json(), ) diff --git a/lambda/src/routes/auth/scoped_key_data.py b/lambda/src/routes/auth/scoped_key_data.py index fd862a39..54896c61 100644 --- a/lambda/src/routes/auth/scoped_key_data.py +++ b/lambda/src/routes/auth/scoped_key_data.py @@ -5,10 +5,11 @@ from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler +from pydantic import ValidationError as PydanticValidationError from src.services.auth_account_manager import AuthAccountManager from src.shared.base_route import BaseRoute -from src.shared.utils import json_dumps +from src.shared.models import ScopedKeyDataEntry, ScopedKeyDataInput class ScopedKeyDataRoute(BaseRoute): @@ -30,41 +31,40 @@ def handle_scoped_key_data(): def handle(self, event) -> Response: uid = event["requestContext"]["hawk_uid"] - # Parse body + # Parse and validate body body_str = event.body if not body_str: return self._error(400, 107, "Missing request body") try: - body = json.loads(body_str) - except (json.JSONDecodeError, TypeError): - return self._error(400, 107, "Invalid JSON body") + body_input = ScopedKeyDataInput.model_validate_json(body_str) + except PydanticValidationError: + return self._error(400, 107, "Missing or invalid scope") - scope = body.get("scope") - if not scope: - return self._error(400, 107, "Missing scope") + scope = body_input.scope # Look up account for createdAt and keyRotationSecret account = self._account_manager.get_account_by_uid(uid) if account is None: return self._error(401, 110, "Account not found") - created_at = account.get("createdAt", 0) + created_at = int(account.get("createdAt", 0)) key_rotation_secret = account.get("keyRotationSecret", "00" * 32) # Return key metadata for each scope result = {} for s in scope.split(): - result[s] = { - "identifier": s, - "keyRotationSecret": key_rotation_secret, - "keyRotationTimestamp": created_at, - } + entry = ScopedKeyDataEntry( + identifier=s, + key_rotation_secret=key_rotation_secret, + key_rotation_timestamp=created_at, + ) + result[s] = entry.model_dump(by_alias=True) return Response( status_code=200, content_type="application/json", - body=json_dumps(result), + body=json.dumps(result), ) @staticmethod diff --git a/lambda/src/routes/auth/session_status.py b/lambda/src/routes/auth/session_status.py index 87d8bf7f..9d1a64db 100644 --- a/lambda/src/routes/auth/session_status.py +++ b/lambda/src/routes/auth/session_status.py @@ -1,12 +1,12 @@ """SessionStatus route — GET /v1/session/status""" -import json from typing import Sequence from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler from src.shared.base_route import BaseRoute +from src.shared.models import SessionStatusOutput class SessionStatusRoute(BaseRoute): @@ -23,8 +23,9 @@ def handle_session_status(): def handle(self, event) -> Response: uid = event["requestContext"]["hawk_uid"] + result = SessionStatusOutput(state="verified", uid=uid) return Response( status_code=200, content_type="application/json", - body=json.dumps({"state": "verified", "uid": uid}), + body=result.model_dump_json(), ) diff --git a/lambda/src/routes/bso/delete.py b/lambda/src/routes/bso/delete.py index 51420536..fa7c71e9 100644 --- a/lambda/src/routes/bso/delete.py +++ b/lambda/src/routes/bso/delete.py @@ -1,3 +1,5 @@ +import json + from aws_lambda_powertools import Logger from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response @@ -8,8 +10,12 @@ StorageObjectNotFoundException, ValidationException, ) -from src.shared.models import ValidationError, validate_bso_id, validate_collection_name -from src.shared.utils import json_dumps +from src.shared.models import ( + ModifiedOutput, + ValidationError, + validate_bso_id, + validate_collection_name, +) logger = Logger() @@ -32,7 +38,7 @@ def handle(self, event) -> Response: return Response( status_code=401, content_type="application/json", - body=json_dumps({"error": "Unauthorized"}), + body=json.dumps({"error": "Unauthorized"}), ) path_params = event.path_parameters or {} @@ -49,12 +55,12 @@ def handle(self, event) -> Response: user_id, collection_name, object_id ) - response_body = {"modified": modified_timestamp} + result = ModifiedOutput(modified=modified_timestamp) return Response( status_code=200, content_type="application/json", - body=json_dumps(response_body), + body=result.model_dump_json(), headers={"X-Last-Modified": str(round(modified_timestamp, 2))}, ) @@ -62,24 +68,24 @@ def handle(self, event) -> Response: return Response( status_code=400, content_type="application/json", - body=json_dumps({"error": str(e)}), + body=json.dumps({"error": str(e)}), ) except CollectionNotFoundException as e: return Response( status_code=404, content_type="application/json", - body=json_dumps({"error": str(e)}), + body=json.dumps({"error": str(e)}), ) except StorageObjectNotFoundException as e: return Response( status_code=404, content_type="application/json", - body=json_dumps({"error": str(e)}), + body=json.dumps({"error": str(e)}), ) except Exception as e: logger.error(f"Internal server error: {e}") return Response( status_code=500, content_type="application/json", - body=json_dumps({"error": "Internal server error"}), + body=json.dumps({"error": "Internal server error"}), ) diff --git a/lambda/src/routes/bso/read.py b/lambda/src/routes/bso/read.py index 2c8fc937..d70a04ac 100644 --- a/lambda/src/routes/bso/read.py +++ b/lambda/src/routes/bso/read.py @@ -1,3 +1,5 @@ +import json + from aws_lambda_powertools import Logger from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response @@ -8,8 +10,12 @@ StorageObjectNotFoundException, ValidationException, ) -from src.shared.models import ValidationError, validate_bso_id, validate_collection_name -from src.shared.utils import json_dumps +from src.shared.models import ( + BSOOutput, + ValidationError, + validate_bso_id, + validate_collection_name, +) logger = Logger() @@ -32,7 +38,7 @@ def handle(self, event) -> Response: return Response( status_code=401, content_type="application/json", - body=json_dumps({"error": "Unauthorized"}), + body=json.dumps({"error": "Unauthorized"}), ) path_params = event.path_parameters or {} @@ -54,7 +60,7 @@ def handle(self, event) -> Response: return Response( status_code=400, content_type="application/json", - body=json_dumps( + body=json.dumps( { "error": "Cannot specify both X-If-Modified-Since and X-If-Unmodified-Since" } @@ -73,7 +79,7 @@ def handle(self, event) -> Response: return Response( status_code=400, content_type="application/json", - body=json_dumps({"error": "Invalid X-If-Modified-Since header"}), + body=json.dumps({"error": "Invalid X-If-Modified-Since header"}), ) # Get storage object using storage manager with user_id @@ -93,46 +99,38 @@ def handle(self, event) -> Response: headers={"X-Last-Modified": str(round(modified_timestamp, 2))}, ) - # Convert to dict using dataclass serialization - obj_dict = storage_object.to_dict() - - # TTL is write-only per Mozilla spec - always exclude from response - if "ttl" in obj_dict: # pragma: nocover - del obj_dict["ttl"] - - # Remove None values for optional fields - if obj_dict.get("sortindex") is None: - del obj_dict["sortindex"] + # Convert to Pydantic model (TTL is write-only per Mozilla spec) + bso = BSOOutput.from_bso(storage_object) return Response( status_code=200, content_type="application/json", - body=json_dumps(obj_dict), - headers={"X-Last-Modified": str(obj_dict["modified"])}, + body=bso.model_dump_json(exclude_none=True), + headers={"X-Last-Modified": str(bso.modified)}, ) except ValidationException as e: return Response( status_code=400, content_type="application/json", - body=json_dumps({"error": str(e)}), + body=json.dumps({"error": str(e)}), ) except CollectionNotFoundException as e: return Response( status_code=404, content_type="application/json", - body=json_dumps({"error": str(e)}), + body=json.dumps({"error": str(e)}), ) except StorageObjectNotFoundException as e: return Response( status_code=404, content_type="application/json", - body=json_dumps({"error": str(e)}), + body=json.dumps({"error": str(e)}), ) except Exception as e: logger.error(f"Internal server error: {e}") return Response( status_code=500, content_type="application/json", - body=json_dumps({"error": "Internal server error"}), + body=json.dumps({"error": "Internal server error"}), ) diff --git a/lambda/src/routes/bso/update.py b/lambda/src/routes/bso/update.py index 065956a3..dd00b503 100644 --- a/lambda/src/routes/bso/update.py +++ b/lambda/src/routes/bso/update.py @@ -2,6 +2,7 @@ from aws_lambda_powertools import Logger from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response +from pydantic import ValidationError as PydanticValidationError from src.services.storage_manager import StorageManager from src.shared.base_route import BaseRoute @@ -13,14 +14,12 @@ ValidationException, ) from src.shared.models import ( + BSOInput, ValidationError, validate_bso_id, validate_collection_name, validate_payload_size, - validate_sortindex, - validate_ttl, ) -from src.shared.utils import json_dumps logger = Logger() @@ -43,7 +42,7 @@ def handle(self, event) -> Response: return Response( status_code=401, content_type="application/json", - body=json_dumps({"error": "Unauthorized"}), + body=json.dumps({"error": "Unauthorized"}), ) path_params = event.path_parameters or {} @@ -55,46 +54,34 @@ def handle(self, event) -> Response: except ValidationError as e: raise ValidationException(str(e)) - # Parse BSO fields directly from request body (per SyncStorage API v1.5 spec) - try: - obj_data = json.loads(body) - except (json.JSONDecodeError, TypeError) as e: - raise ValidationException(f"Invalid request body: {e}") - - if not isinstance(obj_data, dict): - raise ValidationException("Invalid request body: expected JSON object") - # Validate BSO ID from path parameter (Requirements 10.2, 10.3) try: validate_bso_id(object_id) except ValidationError as e: raise ValidationException(str(e)) + # Parse and validate BSO fields via Pydantic model + try: + bso_input = BSOInput.model_validate_json(body) + except PydanticValidationError as e: + raise ValidationException(f"Invalid request body: {e}") + except (TypeError, ValueError) as e: # pragma: nocover + raise ValidationException(f"Invalid request body: {e}") + # If body contains id, validate it matches the path parameter - if "id" in obj_data and obj_data["id"] != object_id: + if bso_input.id is not None and bso_input.id != object_id: raise ValidationException("Object ID in body must match path parameter") - # Validate payload size if provided (Requirement 10.1) - payload = obj_data.get("payload") + # Validate payload byte-size for 413 (Requirement 10.1) + payload = bso_input.payload if payload is not None: try: validate_payload_size(payload) except ValidationError as e: raise RequestTooLargeException(str(e)) - # Validate sortindex if provided - sortindex = obj_data.get("sortindex") - try: - validate_sortindex(sortindex) - except ValidationError as e: - raise ValidationException(str(e)) - - # Validate TTL if provided - ttl = obj_data.get("ttl") - try: - validate_ttl(ttl) - except ValidationError as e: - raise ValidationException(str(e)) + sortindex = bso_input.sortindex + ttl = bso_input.ttl # Handle conditional update header (Requirements 5.1-5.3) if_unmodified_since = None @@ -123,7 +110,7 @@ def handle(self, event) -> Response: return Response( status_code=200, content_type="application/json", - body=json_dumps(modified), + body=json.dumps(modified), headers={"X-Last-Modified": str(round(modified, 2))}, ) @@ -131,36 +118,36 @@ def handle(self, event) -> Response: return Response( status_code=400, content_type="application/json", - body=json_dumps({"error": str(e)}), + body=json.dumps({"error": str(e)}), ) except CollectionNotFoundException as e: return Response( status_code=404, content_type="application/json", - body=json_dumps({"error": str(e)}), + body=json.dumps({"error": str(e)}), ) except StorageObjectNotFoundException as e: return Response( status_code=404, content_type="application/json", - body=json_dumps({"error": str(e)}), + body=json.dumps({"error": str(e)}), ) except PreconditionFailedException as e: return Response( status_code=412, content_type="application/json", - body=json_dumps({"error": str(e)}), + body=json.dumps({"error": str(e)}), ) except RequestTooLargeException as e: return Response( status_code=413, content_type="application/json", - body=json_dumps({"error": str(e)}), + body=json.dumps({"error": str(e)}), ) except Exception as e: logger.error(f"Internal server error: {e}") return Response( status_code=500, content_type="application/json", - body=json_dumps({"error": "Internal server error"}), + body=json.dumps({"error": "Internal server error"}), ) diff --git a/lambda/src/routes/collections/create.py b/lambda/src/routes/collections/create.py index 0e5a76c0..3125dcf0 100644 --- a/lambda/src/routes/collections/create.py +++ b/lambda/src/routes/collections/create.py @@ -13,8 +13,12 @@ ServerLimitExceededException, ValidationException, ) -from src.shared.models import BasicStorageObject, ValidationError, validate_collection_name -from src.shared.utils import json_dumps +from src.shared.models import ( + BasicStorageObject, + BatchResultOutput, + ValidationError, + validate_collection_name, +) logger = Logger() @@ -37,7 +41,7 @@ def handle(self, event) -> Response: return Response( status_code=401, content_type="application/json", - body=json_dumps({"error": "Unauthorized"}), + body=json.dumps({"error": "Unauthorized"}), ) path_params = event.path_parameters or {} @@ -83,7 +87,7 @@ def handle(self, event) -> Response: return Response( status_code=412, content_type="application/json", - body=json_dumps({"error": "Precondition failed"}), + body=json.dumps({"error": "Precondition failed"}), ) # Parse objects from request body - support application/json only @@ -131,18 +135,17 @@ def handle(self, event) -> Response: ) # Return Mozilla-compliant response format (Requirement 3.2) - # {"modified": timestamp, "success": [...], "failed": {...}} modified_ts = collection_data.modified.timestamp() - response_body = { - "modified": modified_ts, - "success": batch_result.success, - "failed": batch_result.failed, - } + result = BatchResultOutput( + modified=modified_ts, + success=batch_result.success, + failed=batch_result.failed, + ) return Response( status_code=201, # 201 Created for new collection content_type="application/json", - body=json_dumps(response_body), + body=result.model_dump_json(), headers={"X-Last-Modified": str(round(modified_ts, 2))}, ) @@ -151,32 +154,32 @@ def handle(self, event) -> Response: return Response( status_code=400, content_type="application/json", - body=json_dumps(CODE_SERVER_LIMIT_EXCEEDED), + body=json.dumps(CODE_SERVER_LIMIT_EXCEEDED), ) except ValidationException as e: return Response( status_code=400, content_type="application/json", - body=json_dumps({"error": str(e)}), + body=json.dumps({"error": str(e)}), ) except ConflictException as e: return Response( status_code=409, content_type="application/json", - body=json_dumps({"error": str(e)}), + body=json.dumps({"error": str(e)}), ) except PreconditionFailedException as e: # pragma: nocover return Response( status_code=412, content_type="application/json", - body=json_dumps({"error": str(e)}), + body=json.dumps({"error": str(e)}), ) except Exception as e: logger.error(f"Internal server error: {e}") return Response( status_code=500, content_type="application/json", - body=json_dumps({"error": "Internal server error"}), + body=json.dumps({"error": "Internal server error"}), ) def _check_precondition(self, user_id, collection_name, if_unmodified_since): diff --git a/lambda/src/routes/collections/delete.py b/lambda/src/routes/collections/delete.py index 30106ad3..ad84a58f 100644 --- a/lambda/src/routes/collections/delete.py +++ b/lambda/src/routes/collections/delete.py @@ -1,11 +1,12 @@ +import json + from aws_lambda_powertools import Logger from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response from src.services.storage_manager import StorageManager from src.shared.base_route import BaseRoute from src.shared.exceptions import CollectionNotFoundException, ValidationException -from src.shared.models import ValidationError, validate_collection_name -from src.shared.utils import json_dumps +from src.shared.models import ModifiedOutput, ValidationError, validate_collection_name logger = Logger() @@ -28,7 +29,7 @@ def handle(self, event) -> Response: return Response( status_code=401, content_type="application/json", - body=json_dumps({"error": "Unauthorized"}), + body=json.dumps({"error": "Unauthorized"}), ) path_params = event.path_parameters or {} @@ -54,12 +55,12 @@ def handle(self, event) -> Response: user_id, collection_name ) - response_body = {"modified": modified_timestamp} + result = ModifiedOutput(modified=modified_timestamp) return Response( status_code=200, content_type="application/json", - body=json_dumps(response_body), + body=result.model_dump_json(), headers={"X-Last-Modified": str(round(modified_timestamp, 2))}, ) @@ -67,18 +68,18 @@ def handle(self, event) -> Response: return Response( status_code=400, content_type="application/json", - body=json_dumps({"error": str(e)}), + body=json.dumps({"error": str(e)}), ) except CollectionNotFoundException as e: return Response( status_code=404, content_type="application/json", - body=json_dumps({"error": str(e)}), + body=json.dumps({"error": str(e)}), ) except Exception as e: logger.error(f"Internal server error: {e}") return Response( status_code=500, content_type="application/json", - body=json_dumps({"error": "Internal server error"}), + body=json.dumps({"error": "Internal server error"}), ) diff --git a/lambda/src/routes/collections/list.py b/lambda/src/routes/collections/list.py index 0c17e639..71598e91 100644 --- a/lambda/src/routes/collections/list.py +++ b/lambda/src/routes/collections/list.py @@ -1,9 +1,11 @@ +import json + from aws_lambda_powertools import Logger from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response from src.services.storage_manager import StorageManager from src.shared.base_route import BaseRoute -from src.shared.utils import json_dumps +from src.shared.models import CollectionDataOutput, CollectionsResponse logger = Logger() @@ -26,18 +28,27 @@ def handle(self, event) -> Response: return Response( status_code=401, content_type="application/json", - body=json_dumps({"error": "Unauthorized"}), + body=json.dumps({"error": "Unauthorized"}), ) # Get collections using storage manager collections = self.storage_manager.list_collections(user_id) - response_body = {"collections": [collection.to_dict() for collection in collections]} + collection_models = [ + CollectionDataOutput( + name=c.name, + modified=round(c.modified.timestamp(), 2), + count=c.count, + usage=c.usage, + ) + for c in collections + ] + response = CollectionsResponse(collections=collection_models) return Response( status_code=200, content_type="application/json", - body=json_dumps(response_body), + body=response.model_dump_json(), ) except Exception as e: @@ -45,5 +56,5 @@ def handle(self, event) -> Response: return Response( status_code=500, content_type="application/json", - body=json_dumps({"error": "Internal server error"}), + body=json.dumps({"error": "Internal server error"}), ) diff --git a/lambda/src/routes/collections/read.py b/lambda/src/routes/collections/read.py index 731e4b96..979436fd 100644 --- a/lambda/src/routes/collections/read.py +++ b/lambda/src/routes/collections/read.py @@ -1,11 +1,17 @@ +import json + from aws_lambda_powertools import Logger from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response from src.services.storage_manager import StorageManager from src.shared.base_route import BaseRoute from src.shared.exceptions import ValidationException -from src.shared.models import ValidationError, validate_collection_name -from src.shared.utils import json_dumps +from src.shared.models import ( + BSOListAdapter, + BSOOutput, + ValidationError, + validate_collection_name, +) logger = Logger() @@ -28,7 +34,7 @@ def handle(self, event) -> Response: return Response( status_code=401, content_type="application/json", - body=json_dumps({"error": "Unauthorized"}), + body=json.dumps({"error": "Unauthorized"}), ) path_params = event.path_parameters or {} @@ -51,7 +57,7 @@ def handle(self, event) -> Response: return Response( status_code=400, content_type="application/json", - body=json_dumps( + body=json.dumps( { "error": "Cannot specify both X-If-Modified-Since and X-If-Unmodified-Since" } @@ -68,7 +74,7 @@ def handle(self, event) -> Response: return Response( status_code=400, content_type="application/json", - body=json_dumps( + body=json.dumps( {"error": "X-If-Modified-Since must be a valid positive decimal"} ), ) @@ -110,17 +116,19 @@ def handle(self, event) -> Response: # Determine response format based on 'full' parameter full = self._parse_bool(query_params.get("full", "0")) + items = objects.get("items", []) + response_headers = {"X-Last-Modified": str(last_modified_ts)} + if full: # Return full BSO objects (without TTL field per Requirement 11.4) - response_body = [self._format_object(obj) for obj in objects.get("items", [])] + bso_models = [BSOOutput.from_bso(obj) for obj in items] + response_headers["X-Weave-Records"] = str(len(bso_models)) + body = BSOListAdapter.dump_json(bso_models, exclude_none=True).decode() else: # Return just BSO IDs - response_body = [obj.id for obj in objects.get("items", [])] - - response_headers = {"X-Last-Modified": str(last_modified_ts)} - - # Add X-Weave-Records header indicating total number of records (Requirement 2.14) - response_headers["X-Weave-Records"] = str(len(response_body)) + ids = [obj.id for obj in items] + response_headers["X-Weave-Records"] = str(len(ids)) + body = json.dumps(ids) # Add X-Weave-Next-Offset header if more results available if objects.get("next_offset") is not None: @@ -130,7 +138,7 @@ def handle(self, event) -> Response: return Response( status_code=200, content_type="application/json", - body=json_dumps(response_body), + body=body, headers=response_headers, ) @@ -138,14 +146,14 @@ def handle(self, event) -> Response: return Response( status_code=400, content_type="application/json", - body=json_dumps({"error": str(e)}), + body=json.dumps({"error": str(e)}), ) except Exception as e: logger.error(f"Internal server error: {e}") return Response( status_code=500, content_type="application/json", - body=json_dumps({"error": "Internal server error"}), + body=json.dumps({"error": "Internal server error"}), ) def _parse_timestamp(self, value): @@ -171,16 +179,3 @@ def _parse_bool(self, value): if value is None: return True # pragma: nocover return value.lower() in ("1", "true", "yes") - - def _format_object(self, obj): - """Format storage object for response""" - # Convert to dict using dataclass serialization - obj_dict = obj.to_dict() - - # Remove None values - if obj_dict.get("sortindex") is None: - del obj_dict["sortindex"] - if obj_dict.get("ttl") is None: - del obj_dict["ttl"] - - return obj_dict diff --git a/lambda/src/routes/collections/update.py b/lambda/src/routes/collections/update.py index 70d1110d..d45dcb6b 100644 --- a/lambda/src/routes/collections/update.py +++ b/lambda/src/routes/collections/update.py @@ -11,8 +11,12 @@ PreconditionFailedException, ValidationException, ) -from src.shared.models import BasicStorageObject, ValidationError, validate_collection_name -from src.shared.utils import json_dumps +from src.shared.models import ( + BasicStorageObject, + BatchResultOutput, + ValidationError, + validate_collection_name, +) logger = Logger() @@ -35,7 +39,7 @@ def handle(self, event) -> Response: return Response( status_code=401, content_type="application/json", - body=json_dumps({"error": "Unauthorized"}), + body=json.dumps({"error": "Unauthorized"}), ) path_params = event.path_parameters or {} @@ -93,18 +97,17 @@ def handle(self, event) -> Response: ) # Return Mozilla-compliant response format - # {"modified": timestamp, "success": [...], "failed": {...}} modified_ts = collection_data.modified.timestamp() - response_body = { - "modified": modified_ts, - "success": batch_result.success, - "failed": batch_result.failed, - } + result = BatchResultOutput( + modified=modified_ts, + success=batch_result.success, + failed=batch_result.failed, + ) return Response( status_code=200, content_type="application/json", - body=json_dumps(response_body), + body=result.model_dump_json(), headers={"X-Last-Modified": str(round(modified_ts, 2))}, ) @@ -112,24 +115,24 @@ def handle(self, event) -> Response: return Response( status_code=400, content_type="application/json", - body=json_dumps({"error": str(e)}), + body=json.dumps({"error": str(e)}), ) except CollectionNotFoundException as e: return Response( status_code=404, content_type="application/json", - body=json_dumps({"error": str(e)}), + body=json.dumps({"error": str(e)}), ) except PreconditionFailedException as e: return Response( status_code=412, content_type="application/json", - body=json_dumps({"error": str(e)}), + body=json.dumps({"error": str(e)}), ) except Exception as e: logger.error(f"Internal server error: {e}") return Response( status_code=500, content_type="application/json", - body=json_dumps({"error": "Internal server error"}), + body=json.dumps({"error": "Internal server error"}), ) diff --git a/lambda/src/routes/info/read_collections.py b/lambda/src/routes/info/read_collections.py index 977b8163..aa39589b 100644 --- a/lambda/src/routes/info/read_collections.py +++ b/lambda/src/routes/info/read_collections.py @@ -1,9 +1,10 @@ +import json + from aws_lambda_powertools import Logger from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response from src.services.storage_manager import StorageManager from src.shared.base_route import BaseRoute -from src.shared.utils import datetime_encoder, json_dumps logger = Logger() @@ -31,7 +32,7 @@ def handle(self, event) -> Response: return Response( status_code=401, content_type="application/json", - body=json_dumps({"error": "Unauthorized"}), + body=json.dumps({"error": "Unauthorized"}), ) # Get collections using storage manager @@ -39,13 +40,14 @@ def handle(self, event) -> Response: # Mozilla format: object mapping collection names to timestamps response_body = { - collection.name: datetime_encoder(collection.modified) for collection in collections + collection.name: round(collection.modified.timestamp(), 2) + for collection in collections } return Response( status_code=200, content_type="application/json", - body=json_dumps(response_body), + body=json.dumps(response_body), ) except Exception as e: @@ -53,5 +55,5 @@ def handle(self, event) -> Response: return Response( status_code=500, content_type="application/json", - body=json_dumps({"error": "Internal server error"}), + body=json.dumps({"error": "Internal server error"}), ) diff --git a/lambda/src/routes/info/read_configuration.py b/lambda/src/routes/info/read_configuration.py index adb0cc98..f10581a0 100644 --- a/lambda/src/routes/info/read_configuration.py +++ b/lambda/src/routes/info/read_configuration.py @@ -2,7 +2,7 @@ from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response from src.shared.base_route import BaseRoute -from src.shared.utils import json_dumps +from src.shared.models import ConfigurationOutput logger = Logger() @@ -51,29 +51,26 @@ def handle(self, event) -> Response: - max_total_bytes: Maximum combined payload size for batched uploads (optional) """ try: - response_body = { - "max_request_bytes": self.max_request_bytes, - "max_post_records": self.max_post_records, - "max_post_bytes": self.max_post_bytes, - "max_record_payload_bytes": self.max_record_payload_bytes, - } - - # Include optional limits only if configured - if self.max_total_records is not None: - response_body["max_total_records"] = self.max_total_records - if self.max_total_bytes is not None: - response_body["max_total_bytes"] = self.max_total_bytes - + result = ConfigurationOutput( + max_request_bytes=self.max_request_bytes, + max_post_records=self.max_post_records, + max_post_bytes=self.max_post_bytes, + max_record_payload_bytes=self.max_record_payload_bytes, + max_total_records=self.max_total_records, + max_total_bytes=self.max_total_bytes, + ) return Response( status_code=200, content_type="application/json", - body=json_dumps(response_body), + body=result.model_dump_json(exclude_none=True), ) except Exception as e: # pragma: nocover + import json + logger.error(f"Internal server error: {e}") return Response( status_code=500, content_type="application/json", - body=json_dumps({"error": "Internal server error"}), + body=json.dumps({"error": "Internal server error"}), ) diff --git a/lambda/src/routes/info/read_counts.py b/lambda/src/routes/info/read_counts.py index e57c4afd..b0f3fa1d 100644 --- a/lambda/src/routes/info/read_counts.py +++ b/lambda/src/routes/info/read_counts.py @@ -1,9 +1,10 @@ +import json + from aws_lambda_powertools import Logger from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response from src.services.storage_manager import StorageManager from src.shared.base_route import BaseRoute -from src.shared.utils import json_dumps logger = Logger() @@ -31,7 +32,7 @@ def handle(self, event) -> Response: return Response( status_code=401, content_type="application/json", - body=json_dumps({"error": "Unauthorized"}), + body=json.dumps({"error": "Unauthorized"}), ) # Get collections using storage manager @@ -43,7 +44,7 @@ def handle(self, event) -> Response: return Response( status_code=200, content_type="application/json", - body=json_dumps(response_body), + body=json.dumps(response_body), ) except Exception as e: @@ -51,5 +52,5 @@ def handle(self, event) -> Response: return Response( status_code=500, content_type="application/json", - body=json_dumps({"error": "Internal server error"}), + body=json.dumps({"error": "Internal server error"}), ) diff --git a/lambda/src/routes/info/read_quota.py b/lambda/src/routes/info/read_quota.py index d2aa5ac7..b8192245 100644 --- a/lambda/src/routes/info/read_quota.py +++ b/lambda/src/routes/info/read_quota.py @@ -1,9 +1,10 @@ +import json + from aws_lambda_powertools import Logger from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response from src.services.storage_manager import StorageManager from src.shared.base_route import BaseRoute -from src.shared.utils import json_dumps logger = Logger() @@ -36,7 +37,7 @@ def handle(self, event) -> Response: return Response( status_code=401, content_type="application/json", - body=json_dumps({"error": "Unauthorized"}), + body=json.dumps({"error": "Unauthorized"}), ) # Get collections using storage manager to calculate current usage @@ -52,7 +53,7 @@ def handle(self, event) -> Response: return Response( status_code=200, content_type="application/json", - body=json_dumps(response_body), + body=json.dumps(response_body), ) except Exception as e: @@ -60,5 +61,5 @@ def handle(self, event) -> Response: return Response( status_code=500, content_type="application/json", - body=json_dumps({"error": "Internal server error"}), + body=json.dumps({"error": "Internal server error"}), ) diff --git a/lambda/src/routes/info/read_usage.py b/lambda/src/routes/info/read_usage.py index aab586dd..58cdfdce 100644 --- a/lambda/src/routes/info/read_usage.py +++ b/lambda/src/routes/info/read_usage.py @@ -1,9 +1,10 @@ +import json + from aws_lambda_powertools import Logger from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response from src.services.storage_manager import StorageManager from src.shared.base_route import BaseRoute -from src.shared.utils import json_dumps logger = Logger() @@ -31,7 +32,7 @@ def handle(self, event) -> Response: return Response( status_code=401, content_type="application/json", - body=json_dumps({"error": "Unauthorized"}), + body=json.dumps({"error": "Unauthorized"}), ) # Get collections using storage manager @@ -43,7 +44,7 @@ def handle(self, event) -> Response: return Response( status_code=200, content_type="application/json", - body=json_dumps(response_body), + body=json.dumps(response_body), ) except Exception as e: @@ -51,5 +52,5 @@ def handle(self, event) -> Response: return Response( status_code=500, content_type="application/json", - body=json_dumps({"error": "Internal server error"}), + body=json.dumps({"error": "Internal server error"}), ) diff --git a/lambda/src/routes/profile/get_profile.py b/lambda/src/routes/profile/get_profile.py index 6fa7e3a8..1b4e3262 100644 --- a/lambda/src/routes/profile/get_profile.py +++ b/lambda/src/routes/profile/get_profile.py @@ -8,6 +8,7 @@ from src.services.jwt_verifier import JWTVerifier from src.shared.base_route import BaseRoute from src.shared.exceptions import InvalidTokenError +from src.shared.models import ProfileOutput class GetProfileRoute(BaseRoute): @@ -53,19 +54,17 @@ def handle(self, event) -> Response: return self._error(401, 110, "Account not found") uid = account["uid"] + result = ProfileOutput( + uid=uid, + email=account["email"], + avatar="", + avatar_default=True, + sub=uid, + ) return Response( status_code=200, content_type="application/json", - body=json.dumps( - { - "uid": uid, - "email": account["email"], - "locale": "en-US", - "avatar": "", - "avatarDefault": True, - "sub": uid, - } - ), + body=result.model_dump_json(by_alias=True), ) @staticmethod diff --git a/lambda/src/routes/storage/delete_all.py b/lambda/src/routes/storage/delete_all.py index f1cac161..14d28a0c 100644 --- a/lambda/src/routes/storage/delete_all.py +++ b/lambda/src/routes/storage/delete_all.py @@ -1,9 +1,11 @@ +import json + from aws_lambda_powertools import Logger from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response from src.services.storage_manager import StorageManager from src.shared.base_route import BaseRoute -from src.shared.utils import json_dumps +from src.shared.models import ModifiedOutput logger = Logger() @@ -26,16 +28,17 @@ def handle(self, event) -> Response: return Response( status_code=401, content_type="application/json", - body=json_dumps({"error": "Unauthorized"}), + body=json.dumps({"error": "Unauthorized"}), ) # Delete all collections and BSOs for the authenticated user modified_timestamp = self.storage_manager.delete_all_storage(user_id) + result = ModifiedOutput(modified=modified_timestamp) return Response( status_code=200, content_type="application/json", - body=json_dumps({"modified": modified_timestamp}), + body=result.model_dump_json(), headers={"X-Last-Modified": str(round(modified_timestamp, 2))}, ) @@ -44,5 +47,5 @@ def handle(self, event) -> Response: return Response( status_code=500, content_type="application/json", - body=json_dumps({"error": "Internal server error"}), + body=json.dumps({"error": "Internal server error"}), ) diff --git a/lambda/src/routes/storage/delete_root.py b/lambda/src/routes/storage/delete_root.py index 264f24ec..6eaebf61 100644 --- a/lambda/src/routes/storage/delete_root.py +++ b/lambda/src/routes/storage/delete_root.py @@ -1,9 +1,11 @@ +import json + from aws_lambda_powertools import Logger from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response from src.services.storage_manager import StorageManager from src.shared.base_route import BaseRoute -from src.shared.utils import json_dumps +from src.shared.models import ModifiedOutput logger = Logger() @@ -26,18 +28,17 @@ def handle(self, event) -> Response: return Response( status_code=401, content_type="application/json", - body=json_dumps({"error": "Unauthorized"}), + body=json.dumps({"error": "Unauthorized"}), ) # Delete all collections and BSOs for the authenticated user modified_timestamp = self.storage_manager.delete_all_storage(user_id) - response_body = {"modified": modified_timestamp} - + result = ModifiedOutput(modified=modified_timestamp) return Response( status_code=200, content_type="application/json", - body=json_dumps(response_body), + body=result.model_dump_json(), headers={"X-Last-Modified": str(round(modified_timestamp, 2))}, ) @@ -46,5 +47,5 @@ def handle(self, event) -> Response: return Response( status_code=500, content_type="application/json", - body=json_dumps({"error": "Internal server error"}), + body=json.dumps({"error": "Internal server error"}), ) diff --git a/lambda/src/routes/token/request.py b/lambda/src/routes/token/request.py index becc5264..240559a4 100644 --- a/lambda/src/routes/token/request.py +++ b/lambda/src/routes/token/request.py @@ -1,5 +1,6 @@ """RequestToken route for Firefox Sync Token Server""" +import json import re from dataclasses import asdict @@ -19,7 +20,8 @@ ServiceUnavailableError, ValidationException, ) -from src.shared.utils import get_weave_timestamp, json_dumps +from src.shared.models import TokenOutput +from src.shared.utils import get_weave_timestamp logger = Logger("token-server") @@ -167,10 +169,11 @@ def handle(self, event) -> Response: }, ) + result = TokenOutput.model_validate(asdict(token_response)) return Response( status_code=200, content_type="application/json", - body=json_dumps(asdict(token_response)), + body=result.model_dump_json(), headers={"X-Timestamp": str(int(float(get_weave_timestamp())))}, ) @@ -334,6 +337,6 @@ def _error_response( return Response( status_code=status_code, content_type="application/json", - body=json_dumps(body), + body=json.dumps(body), headers=headers, ) diff --git a/lambda/src/services/api_router.py b/lambda/src/services/api_router.py index b2962ee1..a034dd42 100644 --- a/lambda/src/services/api_router.py +++ b/lambda/src/services/api_router.py @@ -1,23 +1,16 @@ +import json from typing import Any, Sequence -from aws_lambda_powertools.event_handler import APIGatewayRestResolver, CORSConfig +from aws_lambda_powertools.event_handler import APIGatewayRestResolver, CORSConfig, Response from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler +from aws_lambda_powertools.event_handler.openapi.exceptions import ( + RequestValidationError, + ResponseValidationError, +) from aws_lambda_powertools.utilities.typing import LambdaContext -from src.middlewares.hawk_auth import HawkAuthenticationError, HawkAuthMiddleware, UidMismatchError -from src.middlewares.request_logging import RequestLoggingMiddleware -from src.middlewares.weave_timestamp import WeaveTimestampMiddleware from src.shared.base_route import BaseRoute -__all__ = [ - "ApiRouter", - "HawkAuthMiddleware", - "HawkAuthenticationError", - "RequestLoggingMiddleware", - "UidMismatchError", - "WeaveTimestampMiddleware", -] - class ApiRouter: def __init__( @@ -26,8 +19,9 @@ def __init__( middlewares: Sequence[BaseMiddlewareHandler[Any]], cors: CORSConfig | None = None, exception_handlers: dict[type[Exception], Any] | None = None, + enable_validation: bool = False, ): - self.app = APIGatewayRestResolver(cors=cors) + self.app = APIGatewayRestResolver(cors=cors, enable_validation=enable_validation) self._routes = routes self._middlewares = middlewares @@ -39,6 +33,22 @@ def _register_exception_handlers(self, handlers: dict): for exc_type, handler_fn in handlers.items(): self.app.exception_handler(exc_type)(handler_fn) + @self.app.exception_handler(RequestValidationError) + def _handle_request_validation(ex: RequestValidationError): # pragma: nocover + return Response( + status_code=422, + content_type="application/json", + body=json.dumps({"error": "Validation error", "details": ex.errors()}), + ) + + @self.app.exception_handler(ResponseValidationError) + def _handle_response_validation(ex: ResponseValidationError): # pragma: nocover + return Response( + status_code=500, + content_type="application/json", + body=json.dumps({"error": "Internal server error"}), + ) + def _register_middleware(self): """Register middleware handlers""" self.app.use(middlewares=self._middlewares) # type: ignore diff --git a/lambda/src/services/storage_manager.py b/lambda/src/services/storage_manager.py index 029992a1..da4e7db1 100644 --- a/lambda/src/services/storage_manager.py +++ b/lambda/src/services/storage_manager.py @@ -51,30 +51,6 @@ def _object_sk(self, object_id: str) -> str: """Generate sort key for storage object""" return f"OBJECT#{object_id}" - def _encode_basic_storage_object( - self, user_id: str, collection_name: str, obj: BasicStorageObject - ) -> dict: - """Encode BasicStorageObject to DynamoDB format""" - obj_data = obj.to_dict() - obj_data[_PK] = self._collection_pk(user_id, collection_name) - obj_data[_SK] = self._object_sk(obj.id) - - # Add DynamoDB TTL attribute if ttl is set (Requirement 11.1-11.4) - if obj.ttl is not None: - # Calculate expiry as current_time + ttl (in seconds) - current_time = int(datetime.now(tz=timezone.utc).timestamp()) - obj_data["expiry"] = current_time + obj.ttl - - return obj_data - - def _encode_collection_data(self, user_id: str, collection_data: CollectionData) -> dict: - """Encode CollectionData to DynamoDB format""" - col_data = collection_data.to_dict() - col_data[_PK] = self._collection_pk(user_id, collection_data.name) - col_data[_SK] = self._metadata_sk() - col_data["user_id"] = user_id # GSI partition key for efficient user queries - return col_data - def _batch_get_existing_objects( self, pk: str, object_ids: list[str] ) -> dict[str, BasicStorageObject]: @@ -107,7 +83,7 @@ def _batch_get_existing_objects( ) for item in response.get("Responses", {}).get(self.table.name, []): - obj = BasicStorageObject.from_dict(item) + obj = BasicStorageObject.model_validate(item) result[obj.id] = obj # Handle unprocessed keys with retry @@ -117,7 +93,7 @@ def _batch_get_existing_objects( RequestItems={self.table.name: unprocessed} ) for item in response.get("Responses", {}).get(self.table.name, []): - obj = BasicStorageObject.from_dict(item) + obj = BasicStorageObject.model_validate(item) result[obj.id] = obj unprocessed = response.get("UnprocessedKeys", {}).get(self.table.name) @@ -148,7 +124,7 @@ def get_collection(self, user_id: str, collection_name: str) -> CollectionData: raise CollectionNotFoundException(f"Collection '{collection_name}' not found") item = response["Item"] - return CollectionData.from_dict(item) + return CollectionData.model_validate(item) except ClientError as e: if e.response["Error"]["Code"] == "ResourceNotFoundException": raise CollectionNotFoundException(f"Collection '{collection_name}' not found") @@ -184,7 +160,7 @@ def get_storage_object( ) item = response["Item"] - return BasicStorageObject.from_dict(item) + return BasicStorageObject.model_validate(item) except ClientError as e: if e.response["Error"]["Code"] == "ResourceNotFoundException": raise StorageObjectNotFoundException(f"Object '{object_id}' not found") @@ -233,7 +209,7 @@ def _write_batch_objects( sortindex=obj.sortindex, ttl=obj.ttl, ) - obj_item = self._encode_basic_storage_object(user_id, collection_name, updated_obj) + obj_item = updated_obj.to_item(user_id, collection_name) items_to_write.append((obj.id, obj_item, obj_delta, is_new_bso)) except Exception as e: # pragma: nocover failed[obj.id] = [str(e)] @@ -358,7 +334,7 @@ def create_or_update_collection( collection_data = CollectionData( name=collection_name, modified=modified, count=new_count, usage=new_usage ) - metadata_item = self._encode_collection_data(user_id, collection_data) + metadata_item = collection_data.to_item(user_id) self.table.put_item(Item=metadata_item) collection_data = CollectionData( @@ -521,7 +497,7 @@ def list_collections(self, user_id: str) -> List[CollectionData]: ) for item in response.get("Items", []): - collection = CollectionData.from_dict(item) + collection = CollectionData.model_validate(item) collections.append(collection) # Handle pagination if there are more results @@ -533,7 +509,7 @@ def list_collections(self, user_id: str) -> List[CollectionData]: ExclusiveStartKey=response["LastEvaluatedKey"], ) for item in response.get("Items", []): - collection = CollectionData.from_dict(item) + collection = CollectionData.model_validate(item) collections.append(collection) return collections @@ -618,14 +594,14 @@ def get_collection_objects( query_kwargs["ProjectionExpression"] = "SK, modified, sortindex" response = self.table.query(**query_kwargs) - items = [BasicStorageObject.from_dict(item) for item in response.get("Items", [])] + items = [BasicStorageObject.model_validate(item) for item in response.get("Items", [])] # Handle pagination from DynamoDB while "LastEvaluatedKey" in response: query_kwargs["ExclusiveStartKey"] = response["LastEvaluatedKey"] response = self.table.query(**query_kwargs) items.extend( - BasicStorageObject.from_dict(item) for item in response.get("Items", []) + BasicStorageObject.model_validate(item) for item in response.get("Items", []) ) # Apply newer/older filters for BatchGetItem path (not pushed to DynamoDB) @@ -744,11 +720,7 @@ def update_storage_object( sortindex=sortindex, ttl=ttl, ) - self.table.put_item( - Item=self._encode_basic_storage_object( - user_id=user_id, collection_name=collection_name, obj=obj - ) - ) + self.table.put_item(Item=obj.to_item(user_id, collection_name)) # Upsert collection metadata so list_collections reflects this write. # DynamoDB update_item creates the item if it doesn't exist, and ADD @@ -860,7 +832,7 @@ def delete_collection_objects( count=collection.count, usage=collection.usage, ) - metadata_item = self._encode_collection_data(user_id, collection_data) + metadata_item = collection_data.to_item(user_id) self.table.put_item(Item=metadata_item) return modified diff --git a/lambda/src/services/user_manager.py b/lambda/src/services/user_manager.py index db8db3fa..627b60d9 100644 --- a/lambda/src/services/user_manager.py +++ b/lambda/src/services/user_manager.py @@ -1,13 +1,13 @@ """User manager for DynamoDB operations on token server users""" from datetime import datetime, timezone +from decimal import Decimal from typing import List, Optional from botocore.exceptions import ClientError from src.shared.exceptions import InvalidClientStateError, ServiceUnavailableError from src.shared.user import UserRecord -from src.shared.utils import float_to_decimal _PK = "PK" PK_PREFIX = "USER" @@ -36,11 +36,6 @@ def _user_pk(self, user_id: str) -> str: """ return f"{PK_PREFIX}#{user_id}" - def _encode_user_record(self, user_record: UserRecord) -> dict: - encoded = user_record.to_dict() - encoded[_PK] = f"{PK_PREFIX}#{user_record.user_id}" - return encoded - def create_user(self, user_id: str, client_state: str = "") -> UserRecord: """Create a new user record with generation 0 @@ -72,7 +67,7 @@ def create_user(self, user_id: str, client_state: str = "") -> UserRecord: ) self.table.put_item( - Item=self._encode_user_record(user_record=user_record), + Item=user_record.to_item(), ConditionExpression="attribute_not_exists(PK)", ) @@ -118,7 +113,7 @@ def get_user(self, user_id: str) -> Optional[UserRecord]: item["client_state"] = "" if "client_state_history" not in item: item["client_state_history"] = [] - return UserRecord.from_dict(item) + return UserRecord.model_validate(item) except ClientError as e: if e.response["Error"]["Code"] in ( @@ -199,7 +194,7 @@ def update_user_client_state( ":inc": 1, ":client_state": client_state, ":new_history": new_history, - ":updated_at": float_to_decimal(current_time.timestamp()), + ":updated_at": Decimal(str(current_time.timestamp())), }, ReturnValues="ALL_NEW", ) @@ -210,7 +205,7 @@ def update_user_client_state( # Ensure client_state_history is present (for legacy records) if "client_state_history" not in updated_item: # pragma: nocover updated_item["client_state_history"] = [] - return UserRecord.from_dict(updated_item) + return UserRecord.model_validate(updated_item) except ClientError as e: if e.response["Error"]["Code"] in ( @@ -298,7 +293,7 @@ def increment_generation(self, user_id: str) -> int: UpdateExpression="SET generation = generation + :inc, updated_at = :updated_at", ExpressionAttributeValues={ ":inc": 1, - ":updated_at": float_to_decimal(current_time.timestamp()), + ":updated_at": Decimal(str(current_time.timestamp())), }, ReturnValues="ALL_NEW", ) diff --git a/lambda/src/shared/models.py b/lambda/src/shared/models.py index f38c36d5..8ad7b590 100644 --- a/lambda/src/shared/models.py +++ b/lambda/src/shared/models.py @@ -1,10 +1,9 @@ -from dataclasses import dataclass, field from datetime import datetime, timezone +from decimal import Decimal from typing import Dict, List, Optional -from dataclasses_json import DataClassJsonMixin, config - -from src.shared.utils import datetime_decoder, datetime_encoder +from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, model_validator +from pydantic.alias_generators import to_camel class ValidationError(Exception): @@ -13,34 +12,83 @@ class ValidationError(Exception): pass -@dataclass -class BasicStorageObject(DataClassJsonMixin): - """Basic Storage Object model""" +class DynamoModel(BaseModel): + """Base for models stored in DynamoDB. + + Handles: + - Decimal/float → datetime coercion on read (model_validator) + - datetime → Decimal serialization on write (model_dump override) + - float → Decimal conversion for all numeric fields + """ + + model_config = ConfigDict(extra="ignore") + + @model_validator(mode="before") + @classmethod + def _coerce_timestamps(cls, data): + """Coerce Decimal/float values to datetime for annotated datetime fields.""" + if not isinstance(data, dict): + return data + for field_name, field_info in cls.model_fields.items(): + if field_info.annotation is datetime and field_name in data: + v = data[field_name] + if not isinstance(v, datetime): + data[field_name] = datetime.fromtimestamp(float(v), tz=timezone.utc) + return data + + def _to_dynamodb_dict(self) -> dict: + """Serialize to a DynamoDB-compatible dict (datetime/float → Decimal).""" + data = super().model_dump() + for k, v in data.items(): + if isinstance(v, datetime): + data[k] = Decimal(str(v.timestamp())) + elif isinstance(v, float): + data[k] = Decimal(str(v)) + return data + + +class BasicStorageObject(DynamoModel): + """Basic Storage Object — internal model for StorageManager ↔ DynamoDB.""" id: str - payload: str - modified: datetime = field(metadata=config(encoder=datetime_encoder, decoder=datetime_decoder)) + payload: str = "" + modified: datetime sortindex: Optional[int] = None ttl: Optional[int] = None + def to_item(self, user_id: str, collection_name: str) -> dict: + """Produce a complete DynamoDB item with PK/SK and optional TTL expiry.""" + item = self._to_dynamodb_dict() + item["PK"] = f"USER#{user_id}#COLLECTION#{collection_name}" + item["SK"] = f"OBJECT#{self.id}" + if self.ttl is not None: + item["expiry"] = int(datetime.now(tz=timezone.utc).timestamp()) + self.ttl + return item + -@dataclass -class CollectionData(DataClassJsonMixin): - """Collection metadata model""" +class CollectionData(DynamoModel): + """Collection metadata — internal model for StorageManager ↔ DynamoDB.""" name: str - modified: datetime = field(metadata=config(encoder=datetime_encoder, decoder=datetime_decoder)) - count: int - usage: int + modified: datetime + count: int = 0 + usage: int = 0 + + def to_item(self, user_id: str) -> dict: + """Produce a complete DynamoDB item with PK/SK and GSI key.""" + item = self._to_dynamodb_dict() + item["PK"] = f"USER#{user_id}#COLLECTION#{self.name}" + item["SK"] = "METADATA" + item["user_id"] = user_id + return item -@dataclass -class BatchResult(DataClassJsonMixin): - """Batch operation result model""" +class BatchResult(BaseModel): + """Batch operation result.""" success: List[str] failed: Dict[str, List[str]] - modified: datetime = field(metadata=config(encoder=datetime_encoder, decoder=datetime_decoder)) + modified: datetime def get_current_timestamp() -> float: @@ -79,51 +127,6 @@ def validate_payload_size(payload: str) -> None: ) -def validate_sortindex(sortindex: Optional[int]) -> None: - """ - Validate BSO sortindex. - - Args: - sortindex: The BSO sortindex value - - Raises: - ValidationError: If sortindex is not an integer or exceeds 9 digits - """ - if sortindex is None: - return - - if not isinstance(sortindex, int): - raise ValidationError(f"Sortindex must be an integer, got {type(sortindex).__name__}") - - if sortindex > MAX_SORTINDEX or sortindex < MIN_SORTINDEX: - raise ValidationError( - f"Sortindex {sortindex} exceeds maximum 9 digits (range: {MIN_SORTINDEX} to {MAX_SORTINDEX})" - ) - - -def validate_ttl(ttl: Optional[int]) -> None: - """ - Validate BSO TTL (Time-To-Live). - - Args: - ttl: The BSO TTL value in seconds - - Raises: - ValidationError: If TTL is not a positive integer or exceeds 9 digits - """ - if ttl is None: - return - - if not isinstance(ttl, int): - raise ValidationError(f"TTL must be an integer, got {type(ttl).__name__}") - - if ttl <= 0: - raise ValidationError(f"TTL must be a positive integer, got {ttl}") - - if ttl > MAX_TTL: - raise ValidationError(f"TTL {ttl} exceeds maximum 9 digits (max: {MAX_TTL})") - - def validate_bso_id(bso_id: str) -> None: """ Validate BSO ID. @@ -170,3 +173,243 @@ def validate_collection_name(collection_name: str) -> None: f"Collection name contains invalid character: {repr(char)}. " f"Only alphanumeric, underscore, hyphen, and period are allowed." ) + + +# --------------------------------------------------------------------------- +# Pydantic v2 models (new — coexist with dataclass models above) +# --------------------------------------------------------------------------- + + +class CamelModel(BaseModel): + """Base for FxA API models - camelCase on the wire, snake_case internally.""" + + model_config = ConfigDict( + alias_generator=to_camel, + populate_by_name=True, + ) + + +# --- Storage models (snake_case per Mozilla SyncStorage spec) --- + + +class BSOOutput(BaseModel): + id: str + payload: str + modified: float + sortindex: Optional[int] = None + + @classmethod + def from_bso(cls, bso: "BasicStorageObject") -> "BSOOutput": + return cls( + id=bso.id, + payload=bso.payload, + modified=round(bso.modified.timestamp(), 2), + sortindex=bso.sortindex, + ) + + +class BSOInput(BaseModel): + id: Optional[str] = None + payload: Optional[str] = None + sortindex: Optional[int] = Field(default=None, ge=-999999999, le=999999999) + ttl: Optional[int] = Field(default=None, gt=0, le=999999999) + + +class BatchResultOutput(BaseModel): + success: list[str] + failed: dict[str, list[str]] + modified: float + + +class CollectionDataOutput(BaseModel): + name: str + modified: float + count: int + usage: int + + +class ModifiedOutput(BaseModel): + modified: float + + +class CollectionsResponse(BaseModel): + collections: list[CollectionDataOutput] + + +BSOListAdapter = TypeAdapter(list[BSOOutput]) + + +# --- Device / Auth models (camelCase on the wire) --- + + +class DeviceInput(CamelModel): + id: Optional[str] = None + name: Optional[str] = None + type: Optional[str] = None + push_callback: Optional[str] = None + push_public_key: Optional[str] = None + push_auth_key: Optional[str] = None + available_commands: Optional[dict] = None + + +class DeviceOutput(CamelModel): + id: str + name: Optional[str] = None + type: Optional[str] = None + push_callback: Optional[str] = None + push_public_key: Optional[str] = None + push_auth_key: Optional[str] = None + push_endpoint_expired: bool = False + available_commands: dict = {} + session_token_id: Optional[str] = None + is_current_device: bool = False + created_at: Optional[int] = None + last_access_time: Optional[int] = None + + +DeviceListAdapter = TypeAdapter(list[DeviceOutput]) + + +class AttachedClientOutput(CamelModel): + client_id: Optional[str] = None + device_id: Optional[str] = None + session_token_id: Optional[str] = None + refresh_token_id: Optional[str] = None + is_current_session: bool = False + device_type: Optional[str] = None + name: Optional[str] = None + created_time: Optional[int] = None + last_access_time: Optional[int] = None + scope: Optional[list[str]] = None + location: dict = {} + user_agent: Optional[str] = None + os: Optional[str] = None + + +ClientListAdapter = TypeAdapter(list[AttachedClientOutput]) + + +class AccountCreateInput(CamelModel): + email: str + auth_pw: str = Field(min_length=64, max_length=64) + + +class AccountCreateOutput(CamelModel): + uid: str + session_token: str + key_fetch_token: str + verified: bool + + +class AccountLoginInput(CamelModel): + email: str + auth_pw: str = Field(min_length=64, max_length=64) + + +class AccountLoginOutput(CamelModel): + uid: str + session_token: str + key_fetch_token: Optional[str] = None + verified: bool + + +class AccountStatusOutput(BaseModel): + exists: bool + + +class AccountKeysOutput(BaseModel): + bundle: str + + +class SessionStatusOutput(BaseModel): + state: str + uid: str + + +class ScopedKeyDataInput(BaseModel): + scope: str + + +class ScopedKeyDataEntry(CamelModel): + identifier: str + key_rotation_secret: str + key_rotation_timestamp: int + + +class OAuthAuthorizationInput(CamelModel): + client_id: str + scope: str + state: str + redirect_uri: str = "urn:ietf:wg:oauth:2.0:oob" + code_challenge: Optional[str] = None + code_challenge_method: str = "S256" + keys_jwe: Optional[str] = None + + +class OAuthAuthorizationOutput(BaseModel): + code: str + state: str + redirect: str + + +class OAuthTokenInput(CamelModel): + grant_type: str + code: Optional[str] = None + code_verifier: Optional[str] = None + refresh_token: Optional[str] = None + scope: Optional[str] = None + client_id: Optional[str] = None + ttl: Optional[int] = None + + +class OAuthTokenOutput(CamelModel): + access_token: str + token_type: str = "bearer" + expires_in: int + scope: str + refresh_token: Optional[str] = None + auth_at: int + keys_jwe: Optional[str] = None + + +class OAuthDestroyInput(BaseModel): + token: str + + +class OIDCExchangeInput(CamelModel): + code: str + code_verifier: str + redirect_uri: str + + +class OIDCExchangeOutput(CamelModel): + email: str + access_token: str + account_exists: bool + + +class TokenOutput(BaseModel): + id: str + key: str + api_endpoint: str + uid: int + duration: int + hashalg: str + + +class ProfileOutput(CamelModel): + uid: str + email: str + locale: str = "en-US" + avatar: str + avatar_default: bool + sub: str + + +class ConfigurationOutput(BaseModel): + max_request_bytes: int + max_post_records: int + max_post_bytes: int + max_record_payload_bytes: int + max_total_records: Optional[int] = None + max_total_bytes: Optional[int] = None diff --git a/lambda/src/shared/oidc.py b/lambda/src/shared/oidc.py index 82714d40..54e548c3 100644 --- a/lambda/src/shared/oidc.py +++ b/lambda/src/shared/oidc.py @@ -3,11 +3,9 @@ from dataclasses import dataclass from typing import Optional -from dataclasses_json import DataClassJsonMixin - @dataclass -class OIDCTokenClaims(DataClassJsonMixin): +class OIDCTokenClaims: """ OIDC token claims extracted from validated token @@ -31,7 +29,7 @@ class OIDCTokenClaims(DataClassJsonMixin): @dataclass -class OIDCProviderConfig(DataClassJsonMixin): +class OIDCProviderConfig: """ OIDC provider configuration from .well-known/openid-configuration @@ -51,7 +49,7 @@ class OIDCProviderConfig(DataClassJsonMixin): @dataclass -class ErrorDetail(DataClassJsonMixin): +class ErrorDetail: """ Error detail for Firefox Sync protocol error responses diff --git a/lambda/src/shared/token.py b/lambda/src/shared/token.py index 684b97f9..386ebb62 100644 --- a/lambda/src/shared/token.py +++ b/lambda/src/shared/token.py @@ -2,11 +2,9 @@ from dataclasses import dataclass -from dataclasses_json import DataClassJsonMixin - @dataclass -class TokenResponse(DataClassJsonMixin): +class TokenResponse: """ Token response returned to Firefox Sync clients diff --git a/lambda/src/shared/user.py b/lambda/src/shared/user.py index e3601223..74b63bc8 100644 --- a/lambda/src/shared/user.py +++ b/lambda/src/shared/user.py @@ -1,16 +1,12 @@ """User record data model""" -from dataclasses import dataclass, field from datetime import datetime from typing import List -from dataclasses_json import DataClassJsonMixin, config +from src.shared.models import DynamoModel -from src.shared.utils import datetime_decoder, datetime_encoder - -@dataclass -class UserRecord(DataClassJsonMixin): +class UserRecord(DynamoModel): """ User record stored in DynamoDB @@ -28,10 +24,12 @@ class UserRecord(DataClassJsonMixin): user_id: str generation: int client_state: str - created_at: datetime = field( - metadata=config(encoder=datetime_encoder, decoder=datetime_decoder) - ) - updated_at: datetime = field( - metadata=config(encoder=datetime_encoder, decoder=datetime_decoder) - ) - client_state_history: List[str] = field(default_factory=list) + created_at: datetime + updated_at: datetime + client_state_history: List[str] = [] + + def to_item(self) -> dict: + """Produce a complete DynamoDB item with PK.""" + item = self._to_dynamodb_dict() + item["PK"] = f"USER#{self.user_id}" + return item diff --git a/lambda/src/shared/utils.py b/lambda/src/shared/utils.py index cc605e80..849fe961 100644 --- a/lambda/src/shared/utils.py +++ b/lambda/src/shared/utils.py @@ -1,29 +1,4 @@ -import json from datetime import datetime, timezone -from decimal import Decimal - - -def datetime_encoder(dt: datetime) -> Decimal: - """Convert datetime to Unix timestamp (Decimal) for DynamoDB serialization""" - # Treat naive datetime as UTC - if dt.tzinfo is None: - dt = dt.replace(tzinfo=timezone.utc) # pragma: nocover - return Decimal(str(dt.timestamp())) - - -def datetime_decoder(timestamp: float) -> datetime: - """Convert Unix timestamp (float/Decimal) to datetime for deserialization""" - return datetime.fromtimestamp(float(timestamp), tz=timezone.utc) - - -def float_to_decimal(value: float) -> Decimal: - """Convert float to Decimal for DynamoDB serialization""" - return Decimal(str(value)) - - -def decimal_to_float(value: Decimal) -> float: - """Convert Decimal to float for deserialization""" - return float(value) def get_weave_timestamp() -> str: @@ -39,20 +14,6 @@ def get_weave_timestamp() -> str: return f"{datetime.now(timezone.utc).timestamp():.2f}" -class DecimalEncoder(json.JSONEncoder): - """JSON encoder that handles Decimal objects by converting them to float""" - - def default(self, obj): - if isinstance(obj, Decimal): - return float(obj) - return super().default(obj) - - -def json_dumps(obj, **kwargs) -> str: - """JSON dumps wrapper that handles Decimal objects""" - return json.dumps(obj, cls=DecimalEncoder, **kwargs) - - def extract_hawk_request_params(event) -> tuple[str, str, str, int]: """Extract (method, path, host, port) for Hawk MAC verification. diff --git a/lambda/tests/routes/test_bso_routes.py b/lambda/tests/routes/test_bso_routes.py index 6eedd5e3..bbdc83bf 100644 --- a/lambda/tests/routes/test_bso_routes.py +++ b/lambda/tests/routes/test_bso_routes.py @@ -1048,7 +1048,7 @@ def test_handle_payload_too_large(self, mock_storage_manager): assert response.status_code == 413 assert response.body is not None body = json.loads(response.body) - assert "Payload size" in body["error"] + assert "Payload size" in body["error"] or "payload" in body["error"].lower() def test_handle_bso_id_too_long(self, mock_storage_manager): """Test 400 Bad Request for BSO ID exceeding 64 characters""" @@ -1103,7 +1103,7 @@ def test_handle_bso_id_non_printable_ascii(self, mock_storage_manager): assert "non-printable ASCII" in body["error"] def test_handle_sortindex_invalid(self, mock_storage_manager): - """Test 400 Bad Request for invalid sortindex""" + """Test 400 Bad Request for invalid sortindex (Pydantic rejects non-int)""" route = UpdateBSORoute(mock_storage_manager) event = APIGatewayProxyEvent( @@ -1124,7 +1124,7 @@ def test_handle_sortindex_invalid(self, mock_storage_manager): assert response.status_code == 400 assert response.body is not None body = json.loads(response.body) - assert "Sortindex must be an integer" in body["error"] + assert "sortindex" in body["error"] def test_handle_sortindex_exceeds_max(self, mock_storage_manager): """Test 400 Bad Request for sortindex exceeding 9 digits""" @@ -1148,10 +1148,10 @@ def test_handle_sortindex_exceeds_max(self, mock_storage_manager): assert response.status_code == 400 assert response.body is not None body = json.loads(response.body) - assert "Sortindex" in body["error"] and "exceeds" in body["error"] + assert "sortindex" in body["error"] def test_handle_ttl_invalid(self, mock_storage_manager): - """Test 400 Bad Request for invalid TTL""" + """Test 400 Bad Request for invalid TTL (Pydantic rejects non-int)""" route = UpdateBSORoute(mock_storage_manager) event = APIGatewayProxyEvent( @@ -1172,10 +1172,10 @@ def test_handle_ttl_invalid(self, mock_storage_manager): assert response.status_code == 400 assert response.body is not None body = json.loads(response.body) - assert "TTL must be an integer" in body["error"] + assert "ttl" in body["error"] def test_handle_ttl_negative(self, mock_storage_manager): - """Test 400 Bad Request for negative TTL""" + """Test 400 Bad Request for negative TTL (Pydantic enforces gt=0)""" route = UpdateBSORoute(mock_storage_manager) event = APIGatewayProxyEvent( @@ -1196,10 +1196,10 @@ def test_handle_ttl_negative(self, mock_storage_manager): assert response.status_code == 400 assert response.body is not None body = json.loads(response.body) - assert "TTL must be a positive integer" in body["error"] + assert "ttl" in body["error"] def test_handle_ttl_exceeds_max(self, mock_storage_manager): - """Test 400 Bad Request for TTL exceeding 9 digits""" + """Test 400 Bad Request for TTL exceeding 9 digits (Pydantic enforces le=999999999)""" route = UpdateBSORoute(mock_storage_manager) event = APIGatewayProxyEvent( @@ -1220,7 +1220,59 @@ def test_handle_ttl_exceeds_max(self, mock_storage_manager): assert response.status_code == 400 assert response.body is not None body = json.loads(response.body) - assert "TTL" in body["error"] and "exceeds" in body["error"] + assert "ttl" in body["error"] + + def test_handle_none_body(self, mock_storage_manager): + """Test 400 Bad Request when body is None""" + route = UpdateBSORoute(mock_storage_manager) + + event = APIGatewayProxyEvent( + { + "pathParameters": { + "uid": "12345", + "collectionName": "bookmarks", + "objectId": "item123", + }, + "body": None, + "headers": {}, + "requestContext": {"hawk_uid": "test-user-123"}, + } + ) + + response = route.handle(event) + + assert response.status_code == 400 + assert response.body is not None + body = json.loads(response.body) + assert "Invalid request body" in body["error"] + + def test_handle_payload_too_large_multibyte(self, mock_storage_manager): + """Test 413 for payload within char limit but exceeding byte limit (multi-byte)""" + route = UpdateBSORoute(mock_storage_manager) + + # Each char is 2 bytes in UTF-8; total chars = 200000 (within 262144 char limit) + # but total bytes = 400000 (exceeds 256 * 1024 = 262144 byte limit) + multibyte_payload = "\u00e9" * 200000 + + event = APIGatewayProxyEvent( + { + "pathParameters": { + "uid": "12345", + "collectionName": "bookmarks", + "objectId": "item123", + }, + "body": json.dumps({"id": "item123", "payload": multibyte_payload}), + "headers": {}, + "requestContext": {"hawk_uid": "test-user-123"}, + } + ) + + response = route.handle(event) + + assert response.status_code == 413 + assert response.body is not None + body = json.loads(response.body) + assert "Payload size" in body["error"] def test_handle_no_payload(self, mock_storage_manager): """Test successful update without payload (partial update)""" diff --git a/lambda/tests/services/test_api_router.py b/lambda/tests/services/test_api_router.py deleted file mode 100644 index 808da341..00000000 --- a/lambda/tests/services/test_api_router.py +++ /dev/null @@ -1,563 +0,0 @@ -"""Tests for ApiRouter""" - -from unittest.mock import MagicMock, patch - -from aws_lambda_powertools.event_handler import Response - -from src.services.api_router import ( - ApiRouter, - HawkAuthenticationError, - HawkAuthMiddleware, - RequestLoggingMiddleware, - UidMismatchError, - WeaveTimestampMiddleware, -) -from src.services.token_generator import TokenGenerator - - -def test_api_router_initialization(): - """Test that ApiRouter initializes with routes""" - mock_route1 = MagicMock() - mock_route2 = MagicMock() - routes = [mock_route1, mock_route2] - - with patch("src.services.api_router.APIGatewayRestResolver") as mock_resolver_class: - mock_resolver_instance = MagicMock() - mock_resolver_class.return_value = mock_resolver_instance - - router = ApiRouter(routes=routes, middlewares=[]) # type: ignore[arg-type] - - # Verify resolver was created - mock_resolver_class.assert_called_once() - - # Verify routes were registered - mock_route1.bind.assert_called_once_with(mock_resolver_instance) - mock_route2.bind.assert_called_once_with(mock_resolver_instance) - - assert router.app == mock_resolver_instance - assert router._routes == routes - - -def test_api_router_handler_calls_resolver(): - """Test that handler method calls the resolver's resolve method""" - event = {"httpMethod": "GET", "path": "/test"} - context = MagicMock() - - with patch("src.services.api_router.APIGatewayRestResolver") as mock_resolver_class: - mock_resolver_instance = MagicMock() - mock_resolver_class.return_value = mock_resolver_instance - - router = ApiRouter(routes=[], middlewares=[]) - router.handler(event, context) - - # Verify resolver.resolve was called with event and context - mock_resolver_instance.resolve.assert_called_once_with(event=event, context=context) - - -def test_api_router_handler_with_different_responses(): - """Test middleware adds X-Weave-Timestamp to various responses""" - with patch("src.middlewares.weave_timestamp.get_weave_timestamp") as mock_timestamp: - mock_timestamp.return_value = "1702345678.12" - - middleware = WeaveTimestampMiddleware() - mock_app = MagicMock() - - test_cases = [ - Response(status_code=200, body='{"result": "success"}'), - Response(status_code=404, body='{"error": "Not found"}'), - Response(status_code=500, body='{"error": "Internal error"}'), - Response( - status_code=201, - body='{"created": true}', - headers={"X-Custom": "value"}, - ), - ] - - for response in test_cases: - # Mock next_middleware to return the response - mock_next = MagicMock(return_value=response) - - result = middleware.handler(mock_app, mock_next) - - # Verify X-Weave-Timestamp header was added - assert "X-Weave-Timestamp" in result.headers - assert result.headers["X-Weave-Timestamp"] == "1702345678.12" - - # Verify existing headers are preserved - if response.headers: - for key, value in response.headers.items(): - if key != "X-Weave-Timestamp": - assert result.headers[key] == value - - -def test_api_router_empty_routes(): - """Test ApiRouter with no routes""" - with patch("src.services.api_router.APIGatewayRestResolver") as mock_resolver_class: - mock_resolver_instance = MagicMock() - mock_resolver_class.return_value = mock_resolver_instance - - router = ApiRouter(routes=[], middlewares=[]) - - # Verify resolver was created - mock_resolver_class.assert_called_once() - assert router._routes == [] - - -def test_api_router_handler_passes_context(): - """Test that handler passes context correctly to resolver""" - event = {"test": "event"} - context = MagicMock() - context.function_name = "test-function" - context.aws_request_id = "test-request-id" - - with patch("src.services.api_router.APIGatewayRestResolver") as mock_resolver_class: - mock_resolver_instance = MagicMock() - mock_resolver_class.return_value = mock_resolver_instance - - router = ApiRouter(routes=[], middlewares=[]) - router.handler(event, context) - - # Verify context was passed through - call_kwargs = mock_resolver_instance.resolve.call_args[1] - assert call_kwargs["context"] == context - assert call_kwargs["event"] == event - - -def test_api_router_adds_x_weave_timestamp_to_all_responses(): - """Test that X-Weave-Timestamp header is added to all responses (Requirements 9.1-9.4)""" - with patch("src.middlewares.weave_timestamp.get_weave_timestamp") as mock_timestamp: - mock_timestamp.return_value = "1702345678.12" - - middleware = WeaveTimestampMiddleware() - mock_app = MagicMock() - mock_response = Response(status_code=200, body='{"test": "data"}') - mock_next = MagicMock(return_value=mock_response) - - result = middleware.handler(mock_app, mock_next) - - # Verify X-Weave-Timestamp header is present - assert "X-Weave-Timestamp" in result.headers - assert result.headers["X-Weave-Timestamp"] == "1702345678.12" - - -def test_weave_timestamp_middleware_preserves_existing_headers(): - """Test that middleware preserves existing response headers""" - with patch("src.middlewares.weave_timestamp.get_weave_timestamp") as mock_timestamp: - mock_timestamp.return_value = "1702345678.12" - - middleware = WeaveTimestampMiddleware() - mock_app = MagicMock() - mock_response = Response( - status_code=200, - body='{"test": "data"}', - headers={"X-Custom": "value", "Content-Type": "application/json"}, - ) - mock_next = MagicMock(return_value=mock_response) - - result = middleware.handler(mock_app, mock_next) - - # Verify all headers are present - assert result.headers["X-Weave-Timestamp"] == "1702345678.12" - assert result.headers["X-Custom"] == "value" - assert result.headers["Content-Type"] == "application/json" - - -def test_api_router_registers_middleware(): - """Test that ApiRouter registers both RequestLoggingMiddleware and WeaveTimestampMiddleware""" - with patch("src.services.api_router.APIGatewayRestResolver") as mock_resolver_class: - mock_resolver_instance = MagicMock() - mock_resolver_class.return_value = mock_resolver_instance - - request_logging = RequestLoggingMiddleware() - weave_timestamp = WeaveTimestampMiddleware() - middlewares = [request_logging, weave_timestamp] - - ApiRouter(routes=[], middlewares=middlewares) - - # Verify middleware was registered - mock_resolver_instance.use.assert_called_once() - registered_middlewares = mock_resolver_instance.use.call_args[1]["middlewares"] - assert len(registered_middlewares) == 2 - assert registered_middlewares[0] is request_logging - assert registered_middlewares[1] is weave_timestamp - - -def test_weave_timestamp_middleware_calls_next(): - """Test that middleware calls the next handler in the chain""" - with patch("src.middlewares.weave_timestamp.get_weave_timestamp") as mock_timestamp: - mock_timestamp.return_value = "1702345678.12" - - middleware = WeaveTimestampMiddleware() - mock_app = MagicMock() - mock_response = Response(status_code=200, body='{"test": "data"}') - mock_next = MagicMock(return_value=mock_response) - - middleware.handler(mock_app, mock_next) - - # Verify next middleware was called - mock_next.assert_called_once_with(mock_app) - - -def test_request_logging_middleware_logs_request_and_response(): - """Test that RequestLoggingMiddleware logs request and response information (Requirements 14.1-14.4)""" - with patch("src.middlewares.request_logging.logger") as mock_logger: - middleware = RequestLoggingMiddleware() - mock_app = MagicMock() - mock_app.current_event = { - "httpMethod": "GET", - "path": "/1.5/12345/storage/bookmarks", - "requestContext": {"hawk_uid": "user123"}, - } - mock_response = Response(status_code=200, body='{"test": "data"}') - mock_next = MagicMock(return_value=mock_response) - - result = middleware.handler(mock_app, mock_next) - - # Verify request received was logged - assert mock_logger.info.call_count == 2 - first_call = mock_logger.info.call_args_list[0] - assert first_call[0][0] == "Request received" - assert first_call[1]["extra"]["method"] == "GET" - assert first_call[1]["extra"]["path"] == "/1.5/12345/storage/bookmarks" - assert first_call[1]["extra"]["user_id"] == "user123" - - # Verify request completed was logged - second_call = mock_logger.info.call_args_list[1] - assert second_call[0][0] == "Request completed" - assert second_call[1]["extra"]["method"] == "GET" - assert second_call[1]["extra"]["path"] == "/1.5/12345/storage/bookmarks" - assert second_call[1]["extra"]["user_id"] == "user123" - assert second_call[1]["extra"]["status_code"] == 200 - assert "duration_ms" in second_call[1]["extra"] - - # Verify response is returned - assert result == mock_response - - -def test_request_logging_middleware_logs_anonymous_user(): - """Test that RequestLoggingMiddleware logs 'anonymous' when hawk_uid is not present""" - with patch("src.middlewares.request_logging.logger") as mock_logger: - middleware = RequestLoggingMiddleware() - mock_app = MagicMock() - mock_app.current_event = { - "httpMethod": "GET", - "path": "/1.5/12345/info/configuration", - "requestContext": {}, - } - mock_response = Response(status_code=200, body='{"test": "data"}') - mock_next = MagicMock(return_value=mock_response) - - middleware.handler(mock_app, mock_next) - - # Verify anonymous user was logged - first_call = mock_logger.info.call_args_list[0] - assert first_call[1]["extra"]["user_id"] == "anonymous" - - -def test_request_logging_middleware_logs_errors(): - """Test that RequestLoggingMiddleware logs errors with stack trace (Requirements 14.2)""" - with patch("src.middlewares.request_logging.logger") as mock_logger: - middleware = RequestLoggingMiddleware() - mock_app = MagicMock() - mock_app.current_event = { - "httpMethod": "POST", - "path": "/1.5/12345/storage/bookmarks", - "requestContext": {"hawk_uid": "user123"}, - } - test_exception = ValueError("Test error") - mock_next = MagicMock(side_effect=test_exception) - - # Verify exception is re-raised - try: - middleware.handler(mock_app, mock_next) - assert False, "Expected exception to be raised" - except ValueError: - pass - - # Verify request received was logged - assert mock_logger.info.call_count == 1 - - # Verify error was logged - mock_logger.error.assert_called_once() - error_call = mock_logger.error.call_args - assert error_call[0][0] == "Request failed" - assert error_call[1]["extra"]["method"] == "POST" - assert error_call[1]["extra"]["path"] == "/1.5/12345/storage/bookmarks" - assert error_call[1]["extra"]["user_id"] == "user123" - assert error_call[1]["extra"]["error_type"] == "ValueError" - assert error_call[1]["extra"]["error_message"] == "Test error" - assert "duration_ms" in error_call[1]["extra"] - assert error_call[1]["exc_info"] is True # Stack trace included - - -def test_request_logging_middleware_never_logs_payloads(): - """Test that RequestLoggingMiddleware never logs BSO payloads (Requirements 14.4)""" - with patch("src.middlewares.request_logging.logger") as mock_logger: - middleware = RequestLoggingMiddleware() - mock_app = MagicMock() - # Event with body containing sensitive data - mock_app.current_event = { - "httpMethod": "PUT", - "path": "/1.5/12345/storage/bookmarks/abc123", - "body": '{"payload": "sensitive encrypted data", "sortindex": 100}', - "requestContext": {"hawk_uid": "user123"}, - } - mock_response = Response(status_code=200, body='{"modified": 1702345678.12}') - mock_next = MagicMock(return_value=mock_response) - - middleware.handler(mock_app, mock_next) - - # Verify no log call contains the sensitive payload - for call in mock_logger.info.call_args_list: - log_message = str(call) - assert "sensitive encrypted data" not in log_message - assert "body" not in str(call[1].get("extra", {})) - - -class TestHawkAuthMiddlewareUidValidation: - """Tests for HawkAuthMiddleware UID validation (replaces UidValidationMiddleware)""" - - def test_matching_uid_passes_through(self): - """Test that a matching uid allows the request to proceed""" - hawk_service = MagicMock() - user_id = "test-user-123" - generation = 0 - expected_uid = str(TokenGenerator.generate_uid(user_id, generation)) - - from src.services.hawk_service import HawkCredentials - - creds = HawkCredentials( - user_id=user_id, - generation=generation, - expiry=9999999999, - hawk_id="hawkid123", - ) - hawk_service.validate.return_value = creds - middleware = HawkAuthMiddleware(hawk_service=hawk_service) - - mock_app = MagicMock() - event = MagicMock() - event.headers = {"Authorization": 'Hawk id="hawkid123"'} - event.http_method = "GET" - event.path = f"/1.5/{expected_uid}/storage/bookmarks" - event.query_string_parameters = None - event.request_context.domain_name = "storage.example.com" - raw_event: dict = { - "requestContext": {}, - "pathParameters": {"uid": expected_uid}, - } - event.__getitem__ = lambda self, key: raw_event[key] - event.__setitem__ = lambda self, key, val: raw_event.__setitem__(key, val) - event.get = lambda key, default=None: raw_event.get(key, default) - mock_app.current_event = event - - mock_response = Response(status_code=200, body='{"ok": true}') - mock_next = MagicMock(return_value=mock_response) - - result = middleware.handler(mock_app, mock_next) - - mock_next.assert_called_once_with(mock_app) - assert result.status_code == 200 - - def test_mismatched_uid_raises_uid_mismatch_error(self): - """Test that a mismatched uid raises UidMismatchError""" - hawk_service = MagicMock() - user_id = "test-user-123" - generation = 0 - - from src.services.hawk_service import HawkCredentials - - creds = HawkCredentials( - user_id=user_id, - generation=generation, - expiry=9999999999, - hawk_id="hawkid123", - ) - hawk_service.validate.return_value = creds - middleware = HawkAuthMiddleware(hawk_service=hawk_service) - - mock_app = MagicMock() - event = MagicMock() - event.headers = {"Authorization": 'Hawk id="hawkid123"'} - event.http_method = "GET" - event.path = "/1.5/wrong-uid/storage/bookmarks" - event.query_string_parameters = None - event.request_context.domain_name = "storage.example.com" - raw_event: dict = { - "requestContext": {}, - "pathParameters": {"uid": "wrong-uid"}, - } - event.__getitem__ = lambda self, key: raw_event[key] - event.__setitem__ = lambda self, key, val: raw_event.__setitem__(key, val) - event.get = lambda key, default=None: raw_event.get(key, default) - mock_app.current_event = event - - mock_next = MagicMock() - - import pytest - - with pytest.raises(UidMismatchError): - middleware.handler(mock_app, mock_next) - - mock_next.assert_not_called() - - def test_missing_uid_skips_validation(self): - """Test that missing uid path parameter skips UID validation""" - hawk_service = MagicMock() - user_id = "test-user-123" - - from src.services.hawk_service import HawkCredentials - - creds = HawkCredentials( - user_id=user_id, - generation=0, - expiry=9999999999, - hawk_id="hawkid123", - ) - hawk_service.validate.return_value = creds - middleware = HawkAuthMiddleware(hawk_service=hawk_service) - - mock_app = MagicMock() - event = MagicMock() - event.headers = {"Authorization": 'Hawk id="hawkid123"'} - event.http_method = "GET" - event.path = "/test" - event.query_string_parameters = None - event.request_context.domain_name = "storage.example.com" - raw_event: dict = { - "requestContext": {}, - "pathParameters": {}, - } - event.__getitem__ = lambda self, key: raw_event[key] - event.__setitem__ = lambda self, key, val: raw_event.__setitem__(key, val) - event.get = lambda key, default=None: raw_event.get(key, default) - mock_app.current_event = event - - mock_response = Response(status_code=200, body='{"ok": true}') - mock_next = MagicMock(return_value=mock_response) - - result = middleware.handler(mock_app, mock_next) - - mock_next.assert_called_once_with(mock_app) - assert result.status_code == 200 - - def test_null_path_parameters_skips_validation(self): - """Test that null pathParameters skips UID validation""" - hawk_service = MagicMock() - user_id = "test-user-123" - - from src.services.hawk_service import HawkCredentials - - creds = HawkCredentials( - user_id=user_id, - generation=0, - expiry=9999999999, - hawk_id="hawkid123", - ) - hawk_service.validate.return_value = creds - middleware = HawkAuthMiddleware(hawk_service=hawk_service) - - mock_app = MagicMock() - event = MagicMock() - event.headers = {"Authorization": 'Hawk id="hawkid123"'} - event.http_method = "GET" - event.path = "/test" - event.query_string_parameters = None - event.request_context.domain_name = "storage.example.com" - raw_event: dict = { - "requestContext": {}, - "pathParameters": None, - } - event.__getitem__ = lambda self, key: raw_event[key] - event.__setitem__ = lambda self, key, val: raw_event.__setitem__(key, val) - event.get = lambda key, default=None: raw_event.get(key, default) - mock_app.current_event = event - - mock_response = Response(status_code=200, body='{"ok": true}') - mock_next = MagicMock(return_value=mock_response) - - result = middleware.handler(mock_app, mock_next) - - mock_next.assert_called_once_with(mock_app) - assert result.status_code == 200 - - def test_different_generation_produces_different_uid(self): - """Test that different generation values produce different expected uids""" - hawk_service = MagicMock() - user_id = "test-user-123" - gen0_uid = str(TokenGenerator.generate_uid(user_id, 0)) - gen1_uid = str(TokenGenerator.generate_uid(user_id, 1)) - - from src.services.hawk_service import HawkCredentials - - # Request with gen0 uid but gen1 in credentials should fail - creds = HawkCredentials( - user_id=user_id, - generation=1, - expiry=9999999999, - hawk_id="hawkid123", - ) - hawk_service.validate.return_value = creds - middleware = HawkAuthMiddleware(hawk_service=hawk_service) - - mock_app = MagicMock() - event = MagicMock() - event.headers = {"Authorization": 'Hawk id="hawkid123"'} - event.http_method = "GET" - event.path = f"/1.5/{gen0_uid}/storage/bookmarks" - event.query_string_parameters = None - event.request_context.domain_name = "storage.example.com" - raw_event: dict = { - "requestContext": {}, - "pathParameters": {"uid": gen0_uid}, - } - event.__getitem__ = lambda self, key: raw_event[key] - event.__setitem__ = lambda self, key, val: raw_event.__setitem__(key, val) - event.get = lambda key, default=None: raw_event.get(key, default) - mock_app.current_event = event - - mock_next = MagicMock() - - import pytest - - with pytest.raises(UidMismatchError): - middleware.handler(mock_app, mock_next) - - mock_next.assert_not_called() - - # But gen1 uid with gen1 in credentials should pass - raw_event["pathParameters"]["uid"] = gen1_uid - mock_response = Response(status_code=200, body='{"ok": true}') - mock_next_pass = MagicMock(return_value=mock_response) - - result = middleware.handler(mock_app, mock_next_pass) - - mock_next_pass.assert_called_once_with(mock_app) - assert result.status_code == 200 - - -def test_api_router_registers_exception_handlers(): - """Test that ApiRouter registers exception handlers""" - with patch("src.services.api_router.APIGatewayRestResolver") as mock_resolver_class: - mock_resolver_instance = MagicMock() - mock_resolver_class.return_value = mock_resolver_instance - - def handle_hawk_auth(ex): - return Response(status_code=401, body='{"error": "Unauthorized"}') - - def handle_uid_mismatch(ex): - return Response(status_code=403, body='{"error": "uid mismatch"}') - - exception_handlers = { - HawkAuthenticationError: handle_hawk_auth, - UidMismatchError: handle_uid_mismatch, - } - - ApiRouter( - routes=[], - middlewares=[], - exception_handlers=exception_handlers, - ) - - # Verify exception_handler was called for each handler - assert mock_resolver_instance.exception_handler.call_count == 2 diff --git a/lambda/tests/services/test_user_manager.py b/lambda/tests/services/test_user_manager.py index 863f2231..85b5990c 100644 --- a/lambda/tests/services/test_user_manager.py +++ b/lambda/tests/services/test_user_manager.py @@ -115,7 +115,7 @@ def test_get_or_create_user_dynamodb_unavailable( ) with pytest.raises(ServiceUnavailableError) as exc_info: - user_manager.get_or_create_user(123456789) + user_manager.get_or_create_user("123456789") assert "DynamoDB unavailable" in str(exc_info.value.message) @@ -171,7 +171,7 @@ def test_increment_generation_dynamodb_unavailable( ) with pytest.raises(ServiceUnavailableError) as exc_info: - user_manager.increment_generation(123456789) + user_manager.increment_generation("123456789") assert "DynamoDB unavailable" in str(exc_info.value.message) @@ -269,7 +269,7 @@ def test_validate_generation_dynamodb_unavailable( ) with pytest.raises(ServiceUnavailableError) as exc_info: - user_manager.validate_generation(123456789, 0) + user_manager.validate_generation("123456789", 0) assert "DynamoDB unavailable" in str(exc_info.value.message) @@ -346,7 +346,7 @@ def test_get_or_create_user_unexpected_error( ) with pytest.raises(ClientError) as exc_info: - user_manager.get_or_create_user(123456789) + user_manager.get_or_create_user("123456789") assert exc_info.value.response["Error"]["Code"] == "ValidationException" @@ -364,7 +364,7 @@ def test_get_user_unexpected_error( ) with pytest.raises(ClientError) as exc_info: - user_manager.get_user(123456789) + user_manager.get_user("123456789") assert exc_info.value.response["Error"]["Code"] == "ValidationException" @@ -382,7 +382,7 @@ def test_increment_generation_unexpected_error( ) with pytest.raises(ClientError) as exc_info: - user_manager.increment_generation(123456789) + user_manager.increment_generation("123456789") assert exc_info.value.response["Error"]["Code"] == "ValidationException" diff --git a/lambda/tests/shared/test_models.py b/lambda/tests/shared/test_models.py index 3580c595..ddda2817 100644 --- a/lambda/tests/shared/test_models.py +++ b/lambda/tests/shared/test_models.py @@ -1,18 +1,24 @@ +from decimal import Decimal + import pytest +from pydantic import ValidationError as PydanticValidationError from src.shared.models import ( MAX_BSO_ID_LENGTH, MAX_COLLECTION_NAME_LENGTH, MAX_PAYLOAD_BYTES, - MAX_SORTINDEX, - MAX_TTL, - MIN_SORTINDEX, + AccountCreateInput, + BasicStorageObject, + BatchResultOutput, + BSOInput, + BSOOutput, + CollectionDataOutput, + DeviceOutput, + ModifiedOutput, ValidationError, validate_bso_id, validate_collection_name, validate_payload_size, - validate_sortindex, - validate_ttl, ) @@ -38,112 +44,42 @@ def test_empty_payload(self): validate_payload_size("") # Should not raise -class TestValidateSortindex: - def test_valid_sortindex(self): - """Valid sortindex should not raise exception""" - validate_sortindex(100) # Should not raise - - def test_none_sortindex(self): - """None sortindex should be valid""" - validate_sortindex(None) # Should not raise - - def test_sortindex_at_max(self): - """Sortindex at max value should be valid""" - validate_sortindex(MAX_SORTINDEX) # Should not raise - - def test_sortindex_at_min(self): - """Sortindex at min value should be valid""" - validate_sortindex(MIN_SORTINDEX) # Should not raise - - def test_sortindex_exceeds_max(self): - """Sortindex exceeding max should raise ValidationError""" - with pytest.raises(ValidationError, match="exceeds maximum 9 digits"): - validate_sortindex(MAX_SORTINDEX + 1) - - def test_sortindex_below_min(self): - """Sortindex below min should raise ValidationError""" - with pytest.raises(ValidationError, match="exceeds maximum 9 digits"): - validate_sortindex(MIN_SORTINDEX - 1) - - def test_sortindex_not_integer(self): - """Non-integer sortindex should raise ValidationError""" - with pytest.raises(ValidationError, match="must be an integer"): - validate_sortindex("100") # type: ignore - - -class TestValidateTTL: - def test_valid_ttl(self): - """Valid TTL should not raise exception""" - validate_ttl(3600) # Should not raise - - def test_none_ttl(self): - """None TTL should be valid""" - validate_ttl(None) # Should not raise - - def test_ttl_at_max(self): - """TTL at max value should be valid""" - validate_ttl(MAX_TTL) # Should not raise - - def test_ttl_exceeds_max(self): - """TTL exceeding max should raise ValidationError""" - with pytest.raises(ValidationError, match="exceeds maximum 9 digits"): - validate_ttl(MAX_TTL + 1) - - def test_ttl_zero(self): - """TTL of zero should raise ValidationError""" - with pytest.raises(ValidationError, match="must be a positive integer"): - validate_ttl(0) - - def test_ttl_negative(self): - """Negative TTL should raise ValidationError""" - with pytest.raises(ValidationError, match="must be a positive integer"): - validate_ttl(-100) - - def test_ttl_not_integer(self): - """Non-integer TTL should raise ValidationError""" - with pytest.raises(ValidationError, match="must be an integer"): - validate_ttl("3600") # type: ignore - - class TestValidateBSOId: def test_valid_bso_id(self): """Valid BSO ID should not raise exception""" - validate_bso_id("valid-bso-id-123") # Should not raise + validate_bso_id("valid-bso-id") # Should not raise def test_bso_id_at_max_length(self): - """BSO ID at max length should be valid""" - bso_id = "a" * MAX_BSO_ID_LENGTH - validate_bso_id(bso_id) # Should not raise + """BSO ID at exactly max length should be valid""" + validate_bso_id("a" * MAX_BSO_ID_LENGTH) # Should not raise def test_bso_id_exceeds_max_length(self): """BSO ID exceeding max length should raise ValidationError""" - bso_id = "a" * (MAX_BSO_ID_LENGTH + 1) - with pytest.raises(ValidationError, match="exceeds maximum .* characters"): - validate_bso_id(bso_id) + with pytest.raises(ValidationError, match="BSO ID length .* exceeds maximum"): + validate_bso_id("a" * (MAX_BSO_ID_LENGTH + 1)) - def test_bso_id_with_non_printable_ascii(self): + def test_bso_id_with_special_chars(self): + """BSO ID with printable ASCII characters should be valid""" + validate_bso_id("valid-bso-id_123.test") # Should not raise + + def test_bso_id_with_non_printable_chars(self): """BSO ID with non-printable ASCII should raise ValidationError""" - bso_id = "test\x00id" # Contains null character - with pytest.raises(ValidationError, match="non-printable ASCII character"): - validate_bso_id(bso_id) + with pytest.raises(ValidationError, match="non-printable ASCII"): + validate_bso_id("invalid\x00id") - def test_bso_id_with_tab(self): + def test_bso_id_with_tab_char(self): """BSO ID with tab character should raise ValidationError""" - bso_id = "test\tid" - with pytest.raises(ValidationError, match="non-printable ASCII character"): - validate_bso_id(bso_id) + with pytest.raises(ValidationError, match="non-printable ASCII"): + validate_bso_id("invalid\tid") - def test_bso_id_with_newline(self): - """BSO ID with newline should raise ValidationError""" - bso_id = "test\nid" - with pytest.raises(ValidationError, match="non-printable ASCII character"): - validate_bso_id(bso_id) + def test_bso_id_with_del_char(self): + """BSO ID with DEL character (0x7F) should raise ValidationError""" + with pytest.raises(ValidationError, match="non-printable ASCII"): + validate_bso_id("invalid\x7fid") - def test_bso_id_with_all_printable_ascii(self): - """BSO ID with all printable ASCII characters should be valid""" - # Printable ASCII: 0x20 (space) to 0x7E (~) - bso_id = "abc123 !@#$%^&*()_+-=[]{}|;:',.<>?/~" - validate_bso_id(bso_id) # Should not raise + def test_empty_bso_id(self): + """Empty BSO ID should be valid (length check passes)""" + validate_bso_id("") # Should not raise class TestValidateCollectionName: @@ -151,44 +87,204 @@ def test_valid_collection_name(self): """Valid collection name should not raise exception""" validate_collection_name("bookmarks") # Should not raise - def test_collection_name_with_underscore(self): - """Collection name with underscore should be valid""" - validate_collection_name("my_collection") # Should not raise - - def test_collection_name_with_hyphen(self): - """Collection name with hyphen should be valid""" - validate_collection_name("my-collection") # Should not raise + def test_collection_name_with_special_chars(self): + """Collection name with allowed special characters""" + validate_collection_name("my-collection_1.0") # Should not raise - def test_collection_name_with_period(self): - """Collection name with period should be valid""" - validate_collection_name("my.collection") # Should not raise + def test_collection_name_with_invalid_chars(self): + """Collection name with invalid characters should raise ValidationError""" + with pytest.raises(ValidationError, match="invalid character"): + validate_collection_name("invalid collection!") - def test_collection_name_with_mixed_chars(self): - """Collection name with mixed valid characters should be valid""" - validate_collection_name("My_Collection-123.test") # Should not raise + def test_collection_name_with_space(self): + """Collection name with space should raise ValidationError""" + with pytest.raises(ValidationError, match="invalid character"): + validate_collection_name("invalid name") def test_collection_name_at_max_length(self): - """Collection name at max length should be valid""" - name = "a" * MAX_COLLECTION_NAME_LENGTH - validate_collection_name(name) # Should not raise + """Collection name at exactly max length should be valid""" + validate_collection_name("a" * MAX_COLLECTION_NAME_LENGTH) # Should not raise def test_collection_name_exceeds_max_length(self): """Collection name exceeding max length should raise ValidationError""" - name = "a" * (MAX_COLLECTION_NAME_LENGTH + 1) - with pytest.raises(ValidationError, match="exceeds maximum .* characters"): - validate_collection_name(name) + with pytest.raises(ValidationError, match="Collection name length .* exceeds maximum"): + validate_collection_name("a" * (MAX_COLLECTION_NAME_LENGTH + 1)) + + +class TestBSOInput: + def test_all_fields_optional(self): + bso = BSOInput() + assert bso.id is None + assert bso.payload is None + assert bso.sortindex is None + assert bso.ttl is None + + def test_sortindex_at_bounds(self): + BSOInput(sortindex=999999999) + BSOInput(sortindex=-999999999) + + def test_sortindex_out_of_range(self): + with pytest.raises(PydanticValidationError): + BSOInput(sortindex=1000000000) + with pytest.raises(PydanticValidationError): + BSOInput(sortindex=-1000000000) + + def test_ttl_must_be_positive(self): + with pytest.raises(PydanticValidationError): + BSOInput(ttl=0) + with pytest.raises(PydanticValidationError): + BSOInput(ttl=-1) - def test_collection_name_with_space(self): - """Collection name with space should raise ValidationError""" - with pytest.raises(ValidationError, match="invalid character"): - validate_collection_name("my collection") - - def test_collection_name_with_special_char(self): - """Collection name with special character should raise ValidationError""" - with pytest.raises(ValidationError, match="invalid character"): - validate_collection_name("my@collection") + def test_ttl_at_max(self): + BSOInput(ttl=999999999) - def test_collection_name_with_slash(self): - """Collection name with slash should raise ValidationError""" - with pytest.raises(ValidationError, match="invalid character"): - validate_collection_name("my/collection") + def test_ttl_exceeds_max(self): + with pytest.raises(PydanticValidationError): + BSOInput(ttl=1000000000) + + def test_payload_accepts_large_string(self): + """Payload validation is byte-based (validate_payload_size), not char-based.""" + BSOInput(payload="a" * 262144) # no Pydantic char limit + + +class TestBSOOutputFromBso: + def test_from_bso(self): + from datetime import datetime, timezone + + bso = BasicStorageObject( + id="item1", + payload="data", + modified=datetime(2024, 1, 1, tzinfo=timezone.utc), + sortindex=100, + ) + output = BSOOutput.from_bso(bso) + assert output.id == "item1" + assert output.payload == "data" + assert output.modified == round(bso.modified.timestamp(), 2) + assert output.sortindex == 100 + + +class TestCamelModelAliasing: + def test_device_output_serializes_to_camel(self): + dev = DeviceOutput( + id="d1", + name="My Phone", + type="mobile", + push_callback="https://push.example.com", + created_at=1000, + last_access_time=2000, + ) + d = dev.model_dump(by_alias=True) + assert "pushCallback" in d + assert "pushPublicKey" in d + assert "createdAt" in d + assert "lastAccessTime" in d + # snake_case keys should NOT appear when by_alias=True + assert "push_callback" not in d + + def test_device_output_accepts_camel_input(self): + dev = DeviceOutput.model_validate( + { + "id": "d1", + "name": "Phone", + "type": "mobile", + "pushCallback": "https://push", + "createdAt": 100, + "lastAccessTime": 200, + } + ) + assert dev.push_callback == "https://push" + assert dev.created_at == 100 + + def test_device_output_accepts_snake_input(self): + dev = DeviceOutput( + id="d1", + name="Phone", + type="mobile", + push_callback="https://push", + created_at=100, + last_access_time=200, + ) + assert dev.push_callback == "https://push" + + +class TestBatchResultOutput: + def test_basic_creation(self): + br = BatchResultOutput( + success=["a", "b"], + failed={"c": ["error"]}, + modified=1.23, + ) + assert br.success == ["a", "b"] + assert br.failed == {"c": ["error"]} + assert br.modified == 1.23 + + +class TestCollectionDataOutput: + def test_basic_creation(self): + cd = CollectionDataOutput(name="bookmarks", modified=1.0, count=5, usage=1024) + assert cd.name == "bookmarks" + assert cd.count == 5 + + +class TestModifiedOutput: + def test_basic_creation(self): + m = ModifiedOutput(modified=1.23) + assert m.modified == 1.23 + + +class TestAccountCreateInput: + def test_valid(self): + pw = "a" * 64 + a = AccountCreateInput(email="user@example.com", auth_pw=pw) + assert a.auth_pw == pw + + def test_auth_pw_too_short(self): + with pytest.raises(PydanticValidationError): + AccountCreateInput(email="user@example.com", auth_pw="short") + + def test_auth_pw_too_long(self): + with pytest.raises(PydanticValidationError): + AccountCreateInput(email="user@example.com", auth_pw="a" * 65) + + +class TestDynamoModel: + def test_coerce_timestamps_with_non_dict_data(self): + """Cover the early-return branch when data is not a dict.""" + from src.shared.models import DynamoModel + + class Dummy(DynamoModel): + pass + + # Call the validator directly with a non-dict value + result = Dummy._coerce_timestamps("not-a-dict") # type: ignore[operator] + assert result == "not-a-dict" + + def test__to_dynamodb_dict_converts_float_to_decimal(self): + """Cover the float -> Decimal branch in _to_dynamodb_dict.""" + from src.shared.models import DynamoModel + + class FloatModel(DynamoModel): + value: float + + m = FloatModel(value=3.14) + dumped = m._to_dynamodb_dict() + assert isinstance(dumped["value"], Decimal) + assert dumped["value"] == Decimal("3.14") + + +class TestDeviceOutputDecimalFields: + def test_decimal_fields_convert_to_int(self): + dev = DeviceOutput.model_validate( + { + "id": "d1", + "name": "Phone", + "type": "mobile", + "created_at": Decimal("1000"), + "last_access_time": Decimal("2000"), + } + ) + assert isinstance(dev.created_at, int) + assert dev.created_at == 1000 + assert isinstance(dev.last_access_time, int) + assert dev.last_access_time == 2000 diff --git a/lambda/tests/shared/test_oidc.py b/lambda/tests/shared/test_oidc.py index d86038da..52f824df 100644 --- a/lambda/tests/shared/test_oidc.py +++ b/lambda/tests/shared/test_oidc.py @@ -1,6 +1,6 @@ """Tests for OIDC-related models""" -import json +from dataclasses import asdict from src.shared.oidc import ErrorDetail, OIDCProviderConfig, OIDCTokenClaims @@ -9,7 +9,6 @@ class TestOIDCTokenClaims: """Tests for OIDCTokenClaims model""" def test_creation_with_all_fields(self): - """Test creating OIDCTokenClaims with all fields including email""" claims = OIDCTokenClaims( sub="user123", iss="https://auth.example.com", @@ -27,7 +26,6 @@ def test_creation_with_all_fields(self): assert claims.email == "user@example.com" def test_creation_without_email(self): - """Test creating OIDCTokenClaims without optional email field""" claims = OIDCTokenClaims( sub="user456", iss="https://auth.example.com", @@ -39,59 +37,7 @@ def test_creation_without_email(self): assert claims.sub == "user456" assert claims.email is None - def test_to_json(self): - """Test serialization to JSON""" - claims = OIDCTokenClaims( - sub="user789", - iss="https://auth.example.com", - aud="client-id-789", - exp=1234567890, - iat=1234567800, - email="user789@example.com", - ) - - json_str = claims.to_json() # type: ignore[attr-defined] - data = json.loads(json_str) - - assert data["sub"] == "user789" - assert data["iss"] == "https://auth.example.com" - assert data["aud"] == "client-id-789" - assert data["exp"] == 1234567890 - assert data["iat"] == 1234567800 - assert data["email"] == "user789@example.com" - - def test_from_json(self): - """Test deserialization from JSON""" - json_str = '{"sub": "user999", "iss": "https://auth.example.com", "aud": "client-id", "exp": 1234567890, "iat": 1234567800, "email": "user999@example.com"}' - claims = OIDCTokenClaims.from_json(json_str) # type: ignore[attr-defined] - - assert claims.sub == "user999" - assert claims.iss == "https://auth.example.com" - assert claims.email == "user999@example.com" - - def test_round_trip_serialization(self): - """Test that serialization and deserialization are inverses""" - original = OIDCTokenClaims( - sub="roundtrip", - iss="https://auth.example.com", - aud="client-id", - exp=1234567890, - iat=1234567800, - email="roundtrip@example.com", - ) - - json_str = original.to_json() # type: ignore[attr-defined] - restored = OIDCTokenClaims.from_json(json_str) # type: ignore[attr-defined] - - assert restored.sub == original.sub - assert restored.iss == original.iss - assert restored.aud == original.aud - assert restored.exp == original.exp - assert restored.iat == original.iat - assert restored.email == original.email - def test_exp_greater_than_iat(self): - """Test that expiry is after issued at time""" claims = OIDCTokenClaims( sub="user", iss="https://auth.example.com", @@ -102,8 +48,7 @@ def test_exp_greater_than_iat(self): assert claims.exp > claims.iat - def test_to_dict(self): - """Test conversion to dictionary""" + def test_asdict(self): claims = OIDCTokenClaims( sub="dictuser", iss="https://auth.example.com", @@ -112,7 +57,7 @@ def test_to_dict(self): iat=1234567800, ) - data = claims.to_dict() # type: ignore[attr-defined] + data = asdict(claims) assert isinstance(data, dict) assert data["sub"] == "dictuser" @@ -123,7 +68,6 @@ class TestOIDCProviderConfig: """Tests for OIDCProviderConfig model""" def test_creation_with_all_fields(self): - """Test creating OIDCProviderConfig with all fields""" config = OIDCProviderConfig( issuer="https://auth.example.com", jwks_uri="https://auth.example.com/jwks", @@ -138,71 +82,11 @@ def test_creation_with_all_fields(self): assert config.token_endpoint == "https://auth.example.com/token" assert config.userinfo_endpoint == "https://auth.example.com/userinfo" - def test_to_json(self): - """Test serialization to JSON""" - config = OIDCProviderConfig( - issuer="https://auth.example.com", - jwks_uri="https://auth.example.com/jwks", - authorization_endpoint="https://auth.example.com/authorize", - token_endpoint="https://auth.example.com/token", - userinfo_endpoint="https://auth.example.com/userinfo", - ) - - json_str = config.to_json() # type: ignore[attr-defined] - data = json.loads(json_str) - - assert data["issuer"] == "https://auth.example.com" - assert data["jwks_uri"] == "https://auth.example.com/jwks" - assert data["authorization_endpoint"] == "https://auth.example.com/authorize" - - def test_from_json(self): - """Test deserialization from JSON""" - json_str = '{"issuer": "https://auth.example.com", "jwks_uri": "https://auth.example.com/jwks", "authorization_endpoint": "https://auth.example.com/authorize", "token_endpoint": "https://auth.example.com/token", "userinfo_endpoint": "https://auth.example.com/userinfo"}' - config = OIDCProviderConfig.from_json(json_str) # type: ignore[attr-defined] - - assert config.issuer == "https://auth.example.com" - assert config.jwks_uri == "https://auth.example.com/jwks" - - def test_round_trip_serialization(self): - """Test that serialization and deserialization are inverses""" - original = OIDCProviderConfig( - issuer="https://auth.example.com", - jwks_uri="https://auth.example.com/jwks", - authorization_endpoint="https://auth.example.com/authorize", - token_endpoint="https://auth.example.com/token", - userinfo_endpoint="https://auth.example.com/userinfo", - ) - - json_str = original.to_json() # type: ignore[attr-defined] - restored = OIDCProviderConfig.from_json(json_str) # type: ignore[attr-defined] - - assert restored.issuer == original.issuer - assert restored.jwks_uri == original.jwks_uri - assert restored.authorization_endpoint == original.authorization_endpoint - assert restored.token_endpoint == original.token_endpoint - assert restored.userinfo_endpoint == original.userinfo_endpoint - - def test_from_dict(self): - """Test creation from dictionary""" - data = { - "issuer": "https://auth.example.com", - "jwks_uri": "https://auth.example.com/jwks", - "authorization_endpoint": "https://auth.example.com/authorize", - "token_endpoint": "https://auth.example.com/token", - "userinfo_endpoint": "https://auth.example.com/userinfo", - } - - config = OIDCProviderConfig.from_dict(data) # type: ignore[attr-defined] - - assert config.issuer == "https://auth.example.com" - assert config.jwks_uri == "https://auth.example.com/jwks" - class TestErrorDetail: """Tests for ErrorDetail model""" def test_creation_with_all_fields(self): - """Test creating ErrorDetail with all fields""" error = ErrorDetail( location="header", name="Authorization", @@ -214,78 +98,20 @@ def test_creation_with_all_fields(self): assert error.description == "Missing authorization header" def test_creation_with_body_location(self): - """Test creating ErrorDetail with body location""" error = ErrorDetail(location="body", name="email", description="Invalid email format") - assert error.location == "body" assert error.name == "email" def test_creation_with_query_location(self): - """Test creating ErrorDetail with query location""" error = ErrorDetail(location="query", name="limit", description="Limit must be positive") - assert error.location == "query" - def test_to_json(self): - """Test serialization to JSON""" - error = ErrorDetail( - location="header", - name="Content-Type", - description="Invalid content type", - ) - - json_str = error.to_json() # type: ignore[attr-defined] - data = json.loads(json_str) - - assert data["location"] == "header" - assert data["name"] == "Content-Type" - assert data["description"] == "Invalid content type" - - def test_from_json(self): - """Test deserialization from JSON""" - json_str = '{"location": "body", "name": "password", "description": "Password too short"}' - error = ErrorDetail.from_json(json_str) # type: ignore[attr-defined] - - assert error.location == "body" - assert error.name == "password" - assert error.description == "Password too short" - - def test_round_trip_serialization(self): - """Test that serialization and deserialization are inverses""" - original = ErrorDetail( - location="query", - name="page", - description="Page number must be an integer", - ) - - json_str = original.to_json() # type: ignore[attr-defined] - restored = ErrorDetail.from_json(json_str) # type: ignore[attr-defined] - - assert restored.location == original.location - assert restored.name == original.name - assert restored.description == original.description - - def test_to_dict(self): - """Test conversion to dictionary""" + def test_asdict(self): error = ErrorDetail(location="header", name="Accept", description="Unsupported media type") - data = error.to_dict() # type: ignore[attr-defined] + data = asdict(error) assert isinstance(data, dict) assert data["location"] == "header" assert data["name"] == "Accept" assert data["description"] == "Unsupported media type" - - def test_from_dict(self): - """Test creation from dictionary""" - data = { - "location": "body", - "name": "username", - "description": "Username already exists", - } - - error = ErrorDetail.from_dict(data) # type: ignore[attr-defined] - - assert error.location == "body" - assert error.name == "username" - assert error.description == "Username already exists" diff --git a/lambda/tests/shared/test_token.py b/lambda/tests/shared/test_token.py index 0c78f592..17c97eb8 100644 --- a/lambda/tests/shared/test_token.py +++ b/lambda/tests/shared/test_token.py @@ -1,6 +1,6 @@ """Tests for TokenResponse model""" -import json +from dataclasses import asdict from src.shared.token import TokenResponse @@ -9,7 +9,6 @@ class TestTokenResponse: """Tests for TokenResponse model""" def test_creation_with_all_fields(self): - """Test creating TokenResponse with all fields""" token = TokenResponse( id="hawk_id_base64", key="hawk_key_hex_64_chars", @@ -27,7 +26,6 @@ def test_creation_with_all_fields(self): assert token.hashalg == "sha256" def test_duration_is_300_seconds(self): - """Test that duration is 300 seconds (5 minutes)""" token = TokenResponse( id="test_id", key="test_key", @@ -36,11 +34,9 @@ def test_duration_is_300_seconds(self): duration=300, hashalg="sha256", ) - assert token.duration == 300 def test_hashalg_is_sha256(self): - """Test that hash algorithm is sha256""" token = TokenResponse( id="test_id", key="test_key", @@ -49,11 +45,9 @@ def test_hashalg_is_sha256(self): duration=300, hashalg="sha256", ) - assert token.hashalg == "sha256" def test_api_endpoint_format(self): - """Test that api_endpoint follows expected format""" token = TokenResponse( id="test_id", key="test_key", @@ -62,67 +56,11 @@ def test_api_endpoint_format(self): duration=300, hashalg="sha256", ) - assert token.api_endpoint.startswith("https://") assert "/1.5/" in token.api_endpoint assert token.api_endpoint.endswith("user456") - def test_to_json(self): - """Test serialization to JSON""" - token = TokenResponse( - id="json_test_id", - key="json_test_key", - api_endpoint="https://sync.example.com/1.5/jsonuser", - uid=789, - duration=300, - hashalg="sha256", - ) - - json_str = token.to_json() # type: ignore[attr-defined] - data = json.loads(json_str) - - assert data["id"] == "json_test_id" - assert data["key"] == "json_test_key" - assert data["api_endpoint"] == "https://sync.example.com/1.5/jsonuser" - assert data["uid"] == 789 - assert data["duration"] == 300 - assert data["hashalg"] == "sha256" - - def test_from_json(self): - """Test deserialization from JSON""" - json_str = '{"id": "from_json_id", "key": "from_json_key", "api_endpoint": "https://sync.example.com/1.5/fromjson", "uid": 111, "duration": 300, "hashalg": "sha256"}' - token = TokenResponse.from_json(json_str) # type: ignore[attr-defined] - - assert token.id == "from_json_id" - assert token.key == "from_json_key" - assert token.api_endpoint == "https://sync.example.com/1.5/fromjson" - assert token.uid == 111 - assert token.duration == 300 - assert token.hashalg == "sha256" - - def test_round_trip_serialization(self): - """Test that serialization and deserialization are inverses""" - original = TokenResponse( - id="roundtrip_id", - key="roundtrip_key_hex", - api_endpoint="https://sync.example.com/1.5/roundtrip", - uid=999999, - duration=300, - hashalg="sha256", - ) - - json_str = original.to_json() # type: ignore[attr-defined] - restored = TokenResponse.from_json(json_str) # type: ignore[attr-defined] - - assert restored.id == original.id - assert restored.key == original.key - assert restored.api_endpoint == original.api_endpoint - assert restored.uid == original.uid - assert restored.duration == original.duration - assert restored.hashalg == original.hashalg - - def test_to_dict(self): - """Test conversion to dictionary""" + def test_asdict(self): token = TokenResponse( id="dict_id", key="dict_key", @@ -132,7 +70,7 @@ def test_to_dict(self): hashalg="sha256", ) - data = token.to_dict() # type: ignore[attr-defined] + data = asdict(token) assert isinstance(data, dict) assert data["id"] == "dict_id" @@ -142,28 +80,7 @@ def test_to_dict(self): assert data["duration"] == 300 assert data["hashalg"] == "sha256" - def test_from_dict(self): - """Test creation from dictionary""" - data = { - "id": "fromdict_id", - "key": "fromdict_key", - "api_endpoint": "https://sync.example.com/1.5/fromdictuser", - "uid": 777, - "duration": 300, - "hashalg": "sha256", - } - - token = TokenResponse.from_dict(data) # type: ignore[attr-defined] - - assert token.id == "fromdict_id" - assert token.key == "fromdict_key" - assert token.api_endpoint == "https://sync.example.com/1.5/fromdictuser" - assert token.uid == 777 - assert token.duration == 300 - assert token.hashalg == "sha256" - def test_uid_is_numeric(self): - """Test that uid is a numeric value""" token = TokenResponse( id="test_id", key="test_key", @@ -172,12 +89,10 @@ def test_uid_is_numeric(self): duration=300, hashalg="sha256", ) - assert isinstance(token.uid, int) assert token.uid > 0 def test_different_uids_for_different_users(self): - """Test that different tokens can have different uids""" token1 = TokenResponse( id="id1", key="key1", @@ -186,7 +101,6 @@ def test_different_uids_for_different_users(self): duration=300, hashalg="sha256", ) - token2 = TokenResponse( id="id2", key="key2", @@ -195,5 +109,4 @@ def test_different_uids_for_different_users(self): duration=300, hashalg="sha256", ) - assert token1.uid != token2.uid diff --git a/lambda/tests/shared/test_utils.py b/lambda/tests/shared/test_utils.py index 14279ba1..5f3e2b1f 100644 --- a/lambda/tests/shared/test_utils.py +++ b/lambda/tests/shared/test_utils.py @@ -1,58 +1,15 @@ """Tests for shared utility functions""" import re -from datetime import datetime, timezone -from decimal import Decimal -import pytest from aws_lambda_powertools.utilities.data_classes import APIGatewayProxyEvent from src.shared.utils import ( - DecimalEncoder, - datetime_decoder, - datetime_encoder, - decimal_to_float, extract_hawk_request_params, - float_to_decimal, get_weave_timestamp, - json_dumps, ) -class TestDatetimeEncoderDecoder: - """Test datetime encoding and decoding""" - - def test_datetime_encoder_returns_decimal(self): - """Test that datetime_encoder returns a Decimal""" - dt = datetime(2009, 2, 13, 23, 31, 30, tzinfo=timezone.utc) - result = datetime_encoder(dt) - assert isinstance(result, Decimal) - assert result == Decimal("1234567890.0") - - def test_datetime_decoder_returns_datetime(self): - """Test that datetime_decoder returns a datetime""" - timestamp = 1234567890.0 - result = datetime_decoder(timestamp) - assert isinstance(result, datetime) - assert result == datetime(2009, 2, 13, 23, 31, 30, tzinfo=timezone.utc) - - -class TestFloatDecimalConverters: - """Test float to Decimal conversion utilities""" - - def test_float_to_decimal(self): - """Test converting float to Decimal""" - result = float_to_decimal(123.45) - assert isinstance(result, Decimal) - assert result == Decimal("123.45") - - def test_decimal_to_float(self): - """Test converting Decimal to float""" - result = decimal_to_float(Decimal("123.45")) - assert isinstance(result, float) - assert result == 123.45 - - class TestWeaveTimestamp: """Test Weave timestamp generation (Requirements 9.1, 9.2)""" @@ -80,59 +37,6 @@ def test_get_weave_timestamp_precision(self): assert len(parts[1]) == 2, "Timestamp should have exactly 2 decimal places" -class TestDecimalEncoder: - """Test the custom JSON encoder for Decimal""" - - def test_decimal_encoder_handles_decimal(self): - """Test that DecimalEncoder converts Decimal to float""" - encoder = DecimalEncoder() - result = encoder.default(Decimal("123.45")) - assert isinstance(result, float) - assert result == 123.45 - - def test_decimal_encoder_raises_for_unsupported_type(self): - """Test that DecimalEncoder raises TypeError for unsupported types""" - encoder = DecimalEncoder() - with pytest.raises(TypeError): - encoder.default(object()) - - -class TestJsonDumps: - """Test the json_dumps wrapper function""" - - def test_json_dumps_handles_decimal(self): - """Test that json_dumps can serialize Decimal objects""" - data = {"value": Decimal("123.45"), "nested": {"amount": Decimal("67.89")}} - result = json_dumps(data) - assert isinstance(result, str) - assert "123.45" in result - assert "67.89" in result - - def test_json_dumps_handles_regular_types(self): - """Test that json_dumps works with regular types""" - data = {"string": "test", "number": 42, "boolean": True, "null": None} - result = json_dumps(data) - assert isinstance(result, str) - assert "test" in result - assert "42" in result - assert "true" in result - assert "null" in result - - def test_json_dumps_with_mixed_types(self): - """Test json_dumps with mixed Decimal and regular types""" - data = { - "decimal_value": Decimal("99.99"), - "int_value": 10, - "str_value": "hello", - "list_value": [1, 2, Decimal("3.14")], - } - result = json_dumps(data) - assert "99.99" in result - assert "10" in result - assert "hello" in result - assert "3.14" in result - - class TestExtractHawkRequestParams: """Test extract_hawk_request_params helper""" diff --git a/lambda/uv.lock b/lambda/uv.lock index 8dc54372..9d4c9229 100644 --- a/lambda/uv.lock +++ b/lambda/uv.lock @@ -1,7 +1,16 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.14" +[[package]] +name = "annotated-types" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081, upload-time = "2024-05-20T21:33:25.928Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, +] + [[package]] name = "aws-lambda-powertools" version = "3.25.0" @@ -257,19 +266,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/ef/0c2f4a8e31018a986949d34a01115dd057bf536905dca38897bacd21fac3/cryptography-46.0.5-cp38-abi3-win_amd64.whl", hash = "sha256:556e106ee01aa13484ce9b0239bca667be5004efb0aabbed28d353df86445595", size = 3467050, upload-time = "2026-02-10T19:18:18.899Z" }, ] -[[package]] -name = "dataclasses-json" -version = "0.6.7" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "marshmallow" }, - { name = "typing-inspect" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/64/a4/f71d9cf3a5ac257c993b5ca3f93df5f7fb395c725e7f1e6479d2514173c3/dataclasses_json-0.6.7.tar.gz", hash = "sha256:b6b3e528266ea45b9535223bc53ca645f5208833c29229e847b3f26a1cc55fc0", size = 32227, upload-time = "2024-06-09T16:20:19.103Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c3/be/d0d44e092656fe7a06b55e6103cbce807cdbdee17884a5367c68c9860853/dataclasses_json-0.6.7-py3-none-any.whl", hash = "sha256:0dbf33f26c8d5305befd61b39d2b3414e8a407bedc2834dea9b8d642666fb40a", size = 28686, upload-time = "2024-06-09T16:20:16.715Z" }, -] - [[package]] name = "execnet" version = "2.1.2" @@ -287,8 +283,8 @@ dependencies = [ { name = "aws-lambda-powertools" }, { name = "boto3" }, { name = "cryptography" }, - { name = "dataclasses-json" }, { name = "mohawk" }, + { name = "pydantic" }, { name = "pyjwt" }, { name = "requests" }, ] @@ -313,8 +309,8 @@ requires-dist = [ { name = "aws-lambda-powertools", specifier = "==3.25.0" }, { name = "boto3", specifier = "==1.42.57" }, { name = "cryptography", specifier = "==46.0.5" }, - { name = "dataclasses-json", specifier = "==0.6.7" }, { name = "mohawk", specifier = "==1.1.0" }, + { name = "pydantic", specifier = "==2.12.5" }, { name = "pyjwt", specifier = "==2.11.0" }, { name = "requests", specifier = "==2.32.5" }, ] @@ -441,18 +437,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b2/c8/d148e041732d631fc76036f8b30fae4e77b027a1e95b7a84bb522481a940/librt-0.8.1-cp314-cp314t-win_arm64.whl", hash = "sha256:bf512a71a23504ed08103a13c941f763db13fb11177beb3d9244c98c29fb4a61", size = 48755, upload-time = "2026-02-17T16:12:47.943Z" }, ] -[[package]] -name = "marshmallow" -version = "3.26.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "packaging" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/55/79/de6c16cc902f4fc372236926b0ce2ab7845268dcc30fb2fbb7f71b418631/marshmallow-3.26.2.tar.gz", hash = "sha256:bbe2adb5a03e6e3571b573f42527c6fe926e17467833660bebd11593ab8dfd57", size = 222095, upload-time = "2025-12-22T06:53:53.309Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/be/2f/5108cb3ee4ba6501748c4908b908e55f42a5b66245b4cfe0c99326e1ef6e/marshmallow-3.26.2-py3-none-any.whl", hash = "sha256:013fa8a3c4c276c24d26d84ce934dc964e2aa794345a0f8c7e5a7191482c8a73", size = 50964, upload-time = "2025-12-22T06:53:51.801Z" }, -] - [[package]] name = "mccabe" version = "0.7.0" @@ -558,6 +542,60 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/c3/44f3fbbfa403ea2a7c779186dc20772604442dde72947e7d01069cbe98e3/pycparser-3.0-py3-none-any.whl", hash = "sha256:b727414169a36b7d524c1c3e31839a521725078d7b2ff038656844266160a992", size = 48172, upload-time = "2026-01-21T14:26:50.693Z" }, ] +[[package]] +name = "pydantic" +version = "2.12.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-types" }, + { name = "pydantic-core" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/69/44/36f1a6e523abc58ae5f928898e4aca2e0ea509b5aa6f6f392a5d882be928/pydantic-2.12.5.tar.gz", hash = "sha256:4d351024c75c0f085a9febbb665ce8c0c6ec5d30e903bdb6394b7ede26aebb49", size = 821591, upload-time = "2025-11-26T15:11:46.471Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/87/b70ad306ebb6f9b585f114d0ac2137d792b48be34d732d60e597c2f8465a/pydantic-2.12.5-py3-none-any.whl", hash = "sha256:e561593fccf61e8a20fc46dfc2dfe075b8be7d0188df33f221ad1f0139180f9d", size = 463580, upload-time = "2025-11-26T15:11:44.605Z" }, +] + +[[package]] +name = "pydantic-core" +version = "2.41.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/71/70/23b021c950c2addd24ec408e9ab05d59b035b39d97cdc1130e1bce647bb6/pydantic_core-2.41.5.tar.gz", hash = "sha256:08daa51ea16ad373ffd5e7606252cc32f07bc72b28284b6bc9c6df804816476e", size = 460952, upload-time = "2025-11-04T13:43:49.098Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ea/28/46b7c5c9635ae96ea0fbb779e271a38129df2550f763937659ee6c5dbc65/pydantic_core-2.41.5-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:3f37a19d7ebcdd20b96485056ba9e8b304e27d9904d233d7b1015db320e51f0a", size = 2119622, upload-time = "2025-11-04T13:40:56.68Z" }, + { url = "https://files.pythonhosted.org/packages/74/1a/145646e5687e8d9a1e8d09acb278c8535ebe9e972e1f162ed338a622f193/pydantic_core-2.41.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:1d1d9764366c73f996edd17abb6d9d7649a7eb690006ab6adbda117717099b14", size = 1891725, upload-time = "2025-11-04T13:40:58.807Z" }, + { url = "https://files.pythonhosted.org/packages/23/04/e89c29e267b8060b40dca97bfc64a19b2a3cf99018167ea1677d96368273/pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25e1c2af0fce638d5f1988b686f3b3ea8cd7de5f244ca147c777769e798a9cd1", size = 1915040, upload-time = "2025-11-04T13:41:00.853Z" }, + { url = "https://files.pythonhosted.org/packages/84/a3/15a82ac7bd97992a82257f777b3583d3e84bdb06ba6858f745daa2ec8a85/pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:506d766a8727beef16b7adaeb8ee6217c64fc813646b424d0804d67c16eddb66", size = 2063691, upload-time = "2025-11-04T13:41:03.504Z" }, + { url = "https://files.pythonhosted.org/packages/74/9b/0046701313c6ef08c0c1cf0e028c67c770a4e1275ca73131563c5f2a310a/pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4819fa52133c9aa3c387b3328f25c1facc356491e6135b459f1de698ff64d869", size = 2213897, upload-time = "2025-11-04T13:41:05.804Z" }, + { url = "https://files.pythonhosted.org/packages/8a/cd/6bac76ecd1b27e75a95ca3a9a559c643b3afcd2dd62086d4b7a32a18b169/pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2b761d210c9ea91feda40d25b4efe82a1707da2ef62901466a42492c028553a2", size = 2333302, upload-time = "2025-11-04T13:41:07.809Z" }, + { url = "https://files.pythonhosted.org/packages/4c/d2/ef2074dc020dd6e109611a8be4449b98cd25e1b9b8a303c2f0fca2f2bcf7/pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22f0fb8c1c583a3b6f24df2470833b40207e907b90c928cc8d3594b76f874375", size = 2064877, upload-time = "2025-11-04T13:41:09.827Z" }, + { url = "https://files.pythonhosted.org/packages/18/66/e9db17a9a763d72f03de903883c057b2592c09509ccfe468187f2a2eef29/pydantic_core-2.41.5-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2782c870e99878c634505236d81e5443092fba820f0373997ff75f90f68cd553", size = 2180680, upload-time = "2025-11-04T13:41:12.379Z" }, + { url = "https://files.pythonhosted.org/packages/d3/9e/3ce66cebb929f3ced22be85d4c2399b8e85b622db77dad36b73c5387f8f8/pydantic_core-2.41.5-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:0177272f88ab8312479336e1d777f6b124537d47f2123f89cb37e0accea97f90", size = 2138960, upload-time = "2025-11-04T13:41:14.627Z" }, + { url = "https://files.pythonhosted.org/packages/a6/62/205a998f4327d2079326b01abee48e502ea739d174f0a89295c481a2272e/pydantic_core-2.41.5-cp314-cp314-musllinux_1_1_armv7l.whl", hash = "sha256:63510af5e38f8955b8ee5687740d6ebf7c2a0886d15a6d65c32814613681bc07", size = 2339102, upload-time = "2025-11-04T13:41:16.868Z" }, + { url = "https://files.pythonhosted.org/packages/3c/0d/f05e79471e889d74d3d88f5bd20d0ed189ad94c2423d81ff8d0000aab4ff/pydantic_core-2.41.5-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:e56ba91f47764cc14f1daacd723e3e82d1a89d783f0f5afe9c364b8bb491ccdb", size = 2326039, upload-time = "2025-11-04T13:41:18.934Z" }, + { url = "https://files.pythonhosted.org/packages/ec/e1/e08a6208bb100da7e0c4b288eed624a703f4d129bde2da475721a80cab32/pydantic_core-2.41.5-cp314-cp314-win32.whl", hash = "sha256:aec5cf2fd867b4ff45b9959f8b20ea3993fc93e63c7363fe6851424c8a7e7c23", size = 1995126, upload-time = "2025-11-04T13:41:21.418Z" }, + { url = "https://files.pythonhosted.org/packages/48/5d/56ba7b24e9557f99c9237e29f5c09913c81eeb2f3217e40e922353668092/pydantic_core-2.41.5-cp314-cp314-win_amd64.whl", hash = "sha256:8e7c86f27c585ef37c35e56a96363ab8de4e549a95512445b85c96d3e2f7c1bf", size = 2015489, upload-time = "2025-11-04T13:41:24.076Z" }, + { url = "https://files.pythonhosted.org/packages/4e/bb/f7a190991ec9e3e0ba22e4993d8755bbc4a32925c0b5b42775c03e8148f9/pydantic_core-2.41.5-cp314-cp314-win_arm64.whl", hash = "sha256:e672ba74fbc2dc8eea59fb6d4aed6845e6905fc2a8afe93175d94a83ba2a01a0", size = 1977288, upload-time = "2025-11-04T13:41:26.33Z" }, + { url = "https://files.pythonhosted.org/packages/92/ed/77542d0c51538e32e15afe7899d79efce4b81eee631d99850edc2f5e9349/pydantic_core-2.41.5-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:8566def80554c3faa0e65ac30ab0932b9e3a5cd7f8323764303d468e5c37595a", size = 2120255, upload-time = "2025-11-04T13:41:28.569Z" }, + { url = "https://files.pythonhosted.org/packages/bb/3d/6913dde84d5be21e284439676168b28d8bbba5600d838b9dca99de0fad71/pydantic_core-2.41.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:b80aa5095cd3109962a298ce14110ae16b8c1aece8b72f9dafe81cf597ad80b3", size = 1863760, upload-time = "2025-11-04T13:41:31.055Z" }, + { url = "https://files.pythonhosted.org/packages/5a/f0/e5e6b99d4191da102f2b0eb9687aaa7f5bea5d9964071a84effc3e40f997/pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3006c3dd9ba34b0c094c544c6006cc79e87d8612999f1a5d43b769b89181f23c", size = 1878092, upload-time = "2025-11-04T13:41:33.21Z" }, + { url = "https://files.pythonhosted.org/packages/71/48/36fb760642d568925953bcc8116455513d6e34c4beaa37544118c36aba6d/pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:72f6c8b11857a856bcfa48c86f5368439f74453563f951e473514579d44aa612", size = 2053385, upload-time = "2025-11-04T13:41:35.508Z" }, + { url = "https://files.pythonhosted.org/packages/20/25/92dc684dd8eb75a234bc1c764b4210cf2646479d54b47bf46061657292a8/pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5cb1b2f9742240e4bb26b652a5aeb840aa4b417c7748b6f8387927bc6e45e40d", size = 2218832, upload-time = "2025-11-04T13:41:37.732Z" }, + { url = "https://files.pythonhosted.org/packages/e2/09/f53e0b05023d3e30357d82eb35835d0f6340ca344720a4599cd663dca599/pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bd3d54f38609ff308209bd43acea66061494157703364ae40c951f83ba99a1a9", size = 2327585, upload-time = "2025-11-04T13:41:40Z" }, + { url = "https://files.pythonhosted.org/packages/aa/4e/2ae1aa85d6af35a39b236b1b1641de73f5a6ac4d5a7509f77b814885760c/pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ff4321e56e879ee8d2a879501c8e469414d948f4aba74a2d4593184eb326660", size = 2041078, upload-time = "2025-11-04T13:41:42.323Z" }, + { url = "https://files.pythonhosted.org/packages/cd/13/2e215f17f0ef326fc72afe94776edb77525142c693767fc347ed6288728d/pydantic_core-2.41.5-cp314-cp314t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d0d2568a8c11bf8225044aa94409e21da0cb09dcdafe9ecd10250b2baad531a9", size = 2173914, upload-time = "2025-11-04T13:41:45.221Z" }, + { url = "https://files.pythonhosted.org/packages/02/7a/f999a6dcbcd0e5660bc348a3991c8915ce6599f4f2c6ac22f01d7a10816c/pydantic_core-2.41.5-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:a39455728aabd58ceabb03c90e12f71fd30fa69615760a075b9fec596456ccc3", size = 2129560, upload-time = "2025-11-04T13:41:47.474Z" }, + { url = "https://files.pythonhosted.org/packages/3a/b1/6c990ac65e3b4c079a4fb9f5b05f5b013afa0f4ed6780a3dd236d2cbdc64/pydantic_core-2.41.5-cp314-cp314t-musllinux_1_1_armv7l.whl", hash = "sha256:239edca560d05757817c13dc17c50766136d21f7cd0fac50295499ae24f90fdf", size = 2329244, upload-time = "2025-11-04T13:41:49.992Z" }, + { url = "https://files.pythonhosted.org/packages/d9/02/3c562f3a51afd4d88fff8dffb1771b30cfdfd79befd9883ee094f5b6c0d8/pydantic_core-2.41.5-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:2a5e06546e19f24c6a96a129142a75cee553cc018ffee48a460059b1185f4470", size = 2331955, upload-time = "2025-11-04T13:41:54.079Z" }, + { url = "https://files.pythonhosted.org/packages/5c/96/5fb7d8c3c17bc8c62fdb031c47d77a1af698f1d7a406b0f79aaa1338f9ad/pydantic_core-2.41.5-cp314-cp314t-win32.whl", hash = "sha256:b4ececa40ac28afa90871c2cc2b9ffd2ff0bf749380fbdf57d165fd23da353aa", size = 1988906, upload-time = "2025-11-04T13:41:56.606Z" }, + { url = "https://files.pythonhosted.org/packages/22/ed/182129d83032702912c2e2d8bbe33c036f342cc735737064668585dac28f/pydantic_core-2.41.5-cp314-cp314t-win_amd64.whl", hash = "sha256:80aa89cad80b32a912a65332f64a4450ed00966111b6615ca6816153d3585a8c", size = 1981607, upload-time = "2025-11-04T13:41:58.889Z" }, + { url = "https://files.pythonhosted.org/packages/9f/ed/068e41660b832bb0b1aa5b58011dea2a3fe0ba7861ff38c4d4904c1c1a99/pydantic_core-2.41.5-cp314-cp314t-win_arm64.whl", hash = "sha256:35b44f37a3199f771c3eaa53051bc8a70cd7b54f333531c59e29fd4db5d15008", size = 1974769, upload-time = "2025-11-04T13:42:01.186Z" }, +] + [[package]] name = "pyflakes" version = "3.4.0" @@ -757,16 +795,15 @@ wheels = [ ] [[package]] -name = "typing-inspect" -version = "0.9.0" +name = "typing-inspection" +version = "0.4.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "mypy-extensions" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/dc/74/1789779d91f1961fa9438e9a8710cdae6bd138c80d7303996933d117264a/typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78", size = 13825, upload-time = "2023-05-24T20:25:47.612Z" } +sdist = { url = "https://files.pythonhosted.org/packages/55/e3/70399cb7dd41c10ac53367ae42139cf4b1ca5f36bb3dc6c9d33acdb43655/typing_inspection-0.4.2.tar.gz", hash = "sha256:ba561c48a67c5958007083d386c3295464928b01faa735ab8547c5692e87f464", size = 75949, upload-time = "2025-10-01T02:14:41.687Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/65/f3/107a22063bf27bdccf2024833d3445f4eea42b2e598abfbd46f6a63b6cb0/typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f", size = 8827, upload-time = "2023-05-24T20:25:45.287Z" }, + { url = "https://files.pythonhosted.org/packages/dc/9b/47798a6c91d8bdb567fe2698fe81e0c6b7cb7ef4d13da4114b41d239f65d/typing_inspection-0.4.2-py3-none-any.whl", hash = "sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7", size = 14611, upload-time = "2025-10-01T02:14:40.154Z" }, ] [[package]]