Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lambda/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dependencies = [
"aws-lambda-powertools==3.25.0",
"boto3==1.42.57",
"cryptography==46.0.5",
"dataclasses-json==0.6.7",
"pydantic==2.12.5",
"mohawk==1.1.0",
"PyJWT==2.11.0",
"requests==2.32.5"
Expand Down
24 changes: 15 additions & 9 deletions lambda/src/environment/service_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

import boto3
from aws_lambda_powertools.event_handler import CORSConfig, Response
from aws_lambda_powertools.logging import Logger

from src.middlewares.hawk_auth import HawkAuthenticationError, HawkAuthMiddleware, UidMismatchError
from src.middlewares.request_logging import RequestLoggingMiddleware
from src.middlewares.weave_timestamp import WeaveTimestampMiddleware
from src.routes.auth.account_attached_clients import AccountAttachedClientsRoute
from src.routes.auth.account_create import AccountCreateRoute
from src.routes.auth.account_device import AccountDeviceRoute
Expand Down Expand Up @@ -40,14 +44,7 @@
from src.routes.storage.delete_all import DeleteAllStorageRoute
from src.routes.storage.delete_root import DeleteAllRootRoute
from src.routes.token.request import GetTokenRoute
from src.services.api_router import (
ApiRouter,
HawkAuthenticationError,
HawkAuthMiddleware,
RequestLoggingMiddleware,
UidMismatchError,
WeaveTimestampMiddleware,
)
from src.services.api_router import ApiRouter
from src.services.auth_account_manager import AuthAccountManager
from src.services.channel_service import ChannelService
from src.services.device_manager import DeviceManager
Expand Down Expand Up @@ -88,6 +85,9 @@ def wrapper(event, context, service_provider=None):
return wrapper


logger = Logger()


class ServiceProvider:
@cached_property
def aws_region(self): # pragma: nocover
Expand Down Expand Up @@ -158,7 +158,9 @@ def handle_hawk_auth(ex):
body=json.dumps({"code": 401, "errno": 110, "message": str(ex)}),
)

return {HawkAuthenticationError: handle_hawk_auth}
return {
HawkAuthenticationError: handle_hawk_auth,
}

@cached_property
def storage_api_router(self):
Expand Down Expand Up @@ -186,6 +188,7 @@ def storage_api_router(self):
WeaveTimestampMiddleware(),
],
exception_handlers=self._storage_exception_handlers,
enable_validation=True,
)

# Token API properties
Expand Down Expand Up @@ -380,6 +383,7 @@ def auth_api_router(self):
middlewares=[WeaveTimestampMiddleware()],
cors=self.cors_config,
exception_handlers=self._auth_exception_handlers,
enable_validation=True,
)

@cached_property
Expand All @@ -396,6 +400,7 @@ def token_api_router(self):
],
middlewares=[WeaveTimestampMiddleware()],
cors=self.cors_config,
enable_validation=True,
)

@cached_property
Expand All @@ -410,6 +415,7 @@ def profile_api_router(self):
],
middlewares=[WeaveTimestampMiddleware()],
cors=self.cors_config,
enable_validation=True,
)

# HAWK Authorizer properties
Expand Down
30 changes: 12 additions & 18 deletions lambda/src/routes/auth/account_attached_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from src.services.device_manager import DeviceManager
from src.shared.base_route import BaseRoute
from src.shared.utils import json_dumps
from src.shared.models import AttachedClientOutput, ClientListAdapter


class AccountAttachedClientsRoute(BaseRoute):
Expand All @@ -33,26 +33,20 @@ def handle(self, event) -> Response:
devices = self._device_manager.get_devices(uid)
clients = []
for d in devices:
clients.append(
{
"clientId": None,
"deviceId": d.get("id"),
"sessionTokenId": d.get("sessionTokenId"),
"refreshTokenId": None,
"isCurrentSession": d.get("sessionTokenId") == session_token_id,
"deviceType": d.get("type"),
"name": d.get("name"),
"createdTime": d.get("createdAt"),
"lastAccessTime": d.get("lastAccessTime"),
"scope": None,
"location": {},
"userAgent": "",
"os": None,
}
client = AttachedClientOutput(
device_id=d.get("id"),
session_token_id=d.get("sessionTokenId"),
is_current_session=d.get("sessionTokenId") == session_token_id,
device_type=d.get("type"),
name=d.get("name"),
created_time=d.get("createdAt"),
last_access_time=d.get("lastAccessTime"),
user_agent="",
)
clients.append(client)

return Response(
status_code=200,
content_type="application/json",
body=json_dumps(clients),
body=ClientListAdapter.dump_json(clients, by_alias=True).decode(),
)
16 changes: 8 additions & 8 deletions lambda/src/routes/auth/account_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from src.services.fxa_token_manager import FxATokenManager
from src.services.oidc_validator import OIDCValidator
from src.shared.base_route import BaseRoute
from src.shared.models import AccountCreateOutput

BEARER_PATTERN = re.compile(r"^Bearer\s+(.+)$", re.IGNORECASE)
AUTH_PW_PATTERN = re.compile(r"^[0-9a-f]{64}$")
Expand Down Expand Up @@ -102,17 +103,16 @@ def handle(self, event) -> Response:
session_token = self._token_manager.create_session_token(uid)
key_fetch_token = self._token_manager.create_key_fetch_token(uid)

result = AccountCreateOutput(
uid=uid,
session_token=session_token.hex(),
key_fetch_token=key_fetch_token.hex(),
verified=True,
)
return Response(
status_code=200,
content_type="application/json",
body=json.dumps(
{
"uid": uid,
"sessionToken": session_token.hex(),
"keyFetchToken": key_fetch_token.hex(),
"verified": True,
}
),
body=result.model_dump_json(by_alias=True),
)

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions lambda/src/routes/auth/account_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from src.services.device_manager import DeviceManager
from src.shared.base_route import BaseRoute
from src.shared.utils import json_dumps
from src.shared.models import DeviceOutput


class AccountDeviceRoute(BaseRoute):
Expand All @@ -32,9 +32,9 @@ def handle(self, event) -> Response:
session_token_id = event["requestContext"].get("hawk_token_id", "")
body = json.loads(event.body or "{}")
device = self._device_manager.upsert_device(uid, session_token_id, body)

result = DeviceOutput.model_validate(device)
return Response(
status_code=200,
content_type="application/json",
body=json_dumps(device),
body=result.model_dump_json(by_alias=True),
)
9 changes: 6 additions & 3 deletions lambda/src/routes/auth/account_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from src.services.device_manager import DeviceManager
from src.shared.base_route import BaseRoute
from src.shared.utils import json_dumps
from src.shared.models import DeviceListAdapter, DeviceOutput


class AccountDevicesRoute(BaseRoute):
Expand Down Expand Up @@ -35,11 +35,14 @@ def handle(self, event) -> Response:
filter_idle = int(filter_ts) if filter_ts else None

devices = self._device_manager.get_devices(uid, filter_idle)
results = []
for d in devices:
d["isCurrentDevice"] = d.get("sessionTokenId") == session_token_id
device = DeviceOutput.model_validate(d)
device.is_current_device = d.get("sessionTokenId") == session_token_id
results.append(device)

return Response(
status_code=200,
content_type="application/json",
body=json_dumps(devices),
body=DeviceListAdapter.dump_json(results, by_alias=True).decode(),
)
4 changes: 3 additions & 1 deletion lambda/src/routes/auth/account_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from src.services.fxa_crypto import derive_key_request_key, encrypt_key_bundle
from src.services.fxa_token_manager import KEY_FETCH_TOKEN_INFO, FxATokenManager
from src.shared.base_route import BaseRoute
from src.shared.models import AccountKeysOutput
from src.shared.utils import extract_hawk_request_params


Expand Down Expand Up @@ -58,10 +59,11 @@ def handle(self, event) -> Response:
# Encrypt key bundle
bundle = encrypt_key_bundle(key_request_key, k_a, wrap_kb)

result = AccountKeysOutput(bundle=bundle.hex())
return Response(
status_code=200,
content_type="application/json",
body=json.dumps({"bundle": bundle.hex()}),
body=result.model_dump_json(),
)

@staticmethod
Expand Down
15 changes: 8 additions & 7 deletions lambda/src/routes/auth/account_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from src.services.fxa_crypto import constant_time_compare, derive_verify_hash
from src.services.fxa_token_manager import FxATokenManager
from src.shared.base_route import BaseRoute
from src.shared.models import AccountLoginOutput

AUTH_PW_PATTERN = re.compile(r"^[0-9a-f]{64}$")

Expand Down Expand Up @@ -74,22 +75,22 @@ def handle(self, event) -> Response:
# Create session token (always)
session_token = self._token_manager.create_session_token(uid)

result: dict = {
"uid": uid,
"sessionToken": session_token.hex(),
"verified": True,
}
result = AccountLoginOutput(
uid=uid,
session_token=session_token.hex(),
verified=True,
)

# Create key-fetch token if keys=true
params = event.query_string_parameters or {}
if params.get("keys") == "true":
key_fetch_token = self._token_manager.create_key_fetch_token(uid)
result["keyFetchToken"] = key_fetch_token.hex()
result.key_fetch_token = key_fetch_token.hex()

return Response(
status_code=200,
content_type="application/json",
body=json.dumps(result),
body=result.model_dump_json(by_alias=True, exclude_none=True),
)

@staticmethod
Expand Down
4 changes: 3 additions & 1 deletion lambda/src/routes/auth/account_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from src.services.auth_account_manager import AuthAccountManager
from src.shared.base_route import BaseRoute
from src.shared.models import AccountStatusOutput


class AccountStatusRoute(BaseRoute):
Expand All @@ -30,8 +31,9 @@ def handle(self, event) -> Response:
)

account = self._account_manager.get_account_by_email(email)
result = AccountStatusOutput(exists=account is not None)
return Response(
status_code=200,
content_type="application/json",
body=json.dumps({"exists": account is not None}),
body=result.model_dump_json(),
)
14 changes: 7 additions & 7 deletions lambda/src/routes/auth/oauth_authorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from src.services.oauth_code_manager import OAuthCodeManager
from src.shared.base_route import BaseRoute
from src.shared.models import OAuthAuthorizationOutput

ALLOWED_REDIRECT_URIS = {
"urn:ietf:wg:oauth:2.0:oob",
Expand Down Expand Up @@ -74,16 +75,15 @@ def handle(self, event) -> Response:
keys_jwe=keys_jwe,
)

result = OAuthAuthorizationOutput(
code=code,
state=state,
redirect=redirect_uri,
)
return Response(
status_code=200,
content_type="application/json",
body=json.dumps(
{
"code": code,
"state": state,
"redirect": redirect_uri,
}
),
body=result.model_dump_json(),
)

@staticmethod
Expand Down
Loading
Loading