Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
working-directory: backend
run: |
pip install ruff
ruff check . --output-format=concise
ruff check . --output-format=concise --exclude database/migrations

simulator-lint:
runs-on: ubuntu-latest
Expand Down
196 changes: 186 additions & 10 deletions backend/api/context.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,64 @@
import asyncio
import logging
from typing import Any

import strawberry
from auth import get_user_payload
from database.models.user import User
from database.models.location import LocationNode, location_organizations
from database.models.user import User, user_root_locations
from database.session import get_db_session
from fastapi import Depends
from sqlalchemy import select
from graphql import GraphQLError
from sqlalchemy import delete, select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.requests import HTTPConnection
from strawberry.fastapi import BaseContext


class LockedAsyncSession:
def __init__(self, session: AsyncSession, lock: asyncio.Lock):
self._session = session
self._lock = lock

async def execute(self, *args, **kwargs):
async with self._lock:
return await self._session.execute(*args, **kwargs)

async def commit(self, *args, **kwargs):
async with self._lock:
return await self._session.commit(*args, **kwargs)

async def rollback(self, *args, **kwargs):
async with self._lock:
return await self._session.rollback(*args, **kwargs)

async def flush(self, *args, **kwargs):
async with self._lock:
return await self._session.flush(*args, **kwargs)

async def refresh(self, *args, **kwargs):
async with self._lock:
return await self._session.refresh(*args, **kwargs)

def add(self, *args, **kwargs):
return self._session.add(*args, **kwargs)

def __getattr__(self, name):
return getattr(self._session, name)


class Context(BaseContext):
def __init__(self, db: AsyncSession, user: "User | None" = None):
def __init__(self, db: AsyncSession, user: "User | None" = None, organizations: str | None = None):
super().__init__()
self.db = db
self._db = db
self.user = user
self.organizations = organizations
self._accessible_location_ids: set[str] | None = None
self._accessible_location_ids_lock = asyncio.Lock()
self._db_lock = asyncio.Lock()
self.db = LockedAsyncSession(db, self._db_lock)


Info = strawberry.Info[Context, Any]
Expand All @@ -28,6 +70,7 @@ async def get_context(
) -> Context:
user_payload = get_user_payload(connection)
db_user = None
organizations = None

if user_payload:
user_id = user_payload.get("sub")
Expand All @@ -39,13 +82,27 @@ async def get_context(
email = user_payload.get("email")
picture = user_payload.get("picture")

logger = logging.getLogger(__name__)
organizations_raw = user_payload.get("organization")

if organizations_raw is None:
scope = user_payload.get("scope", "")
has_org_scope = "organization" in scope.split() if scope else False
logger.warning(
f"Organization claim not found in token for user {user_payload.get('sub', 'unknown')}. "
f"Has organization scope: {has_org_scope}. "
f"Token scope: {scope}. "
f"Available claims: {sorted(user_payload.keys())}"
)

organizations = None
if organizations_raw:
if isinstance(organizations_raw, list):
organizations = ",".join(str(org) for org in organizations_raw)
org_list = [str(org) for org in organizations_raw if org]
if org_list:
organizations = ",".join(org_list)
else:
organizations = str(organizations_raw)
organizations = str(organizations_raw) if organizations_raw else None

if user_id:
result = await session.execute(
Expand All @@ -63,7 +120,6 @@ async def get_context(
lastname=lastname,
title="User",
avatar_url=picture,
organizations=organizations,
)
session.add(new_user)
await session.commit()
Expand All @@ -75,24 +131,144 @@ async def get_context(
select(User).where(User.id == user_id),
)
db_user = result.scalars().first()
except Exception as e:
await session.rollback()
raise GraphQLError(
"Failed to create user. Please contact an administrator if you believe this is an error.",
extensions={"code": "INTERNAL_SERVER_ERROR"},
) from e

if db_user and (
db_user.username != username
or db_user.firstname != firstname
or db_user.lastname != lastname
or db_user.email != email
or db_user.avatar_url != picture
or db_user.organizations != organizations
):
db_user.username = username
db_user.firstname = firstname
db_user.lastname = lastname
db_user.email = email
if picture:
db_user.avatar_url = picture
db_user.organizations = organizations
session.add(db_user)
await session.commit()
await session.refresh(db_user)

return Context(db=session, user=db_user)
if db_user:
try:

await _update_user_root_locations(
session,
db_user,
organizations,
)
except Exception as e:
raise GraphQLError(
"Failed to update user root locations. Please contact an administrator if you believe this is an error.",
extensions={"code": "INTERNAL_SERVER_ERROR"},
) from e

return Context(db=session, user=db_user, organizations=organizations)


async def _update_user_root_locations(
session: AsyncSession,
user: User,
organizations: str | None,
) -> None:
organization_ids: list[str] = []
if organizations:
organization_ids = [
org_id.strip()
for org_id in organizations.split(",")
if org_id.strip()
]

root_location_ids: list[str] = []

if organization_ids:
result = await session.execute(
select(LocationNode)
.join(
location_organizations,
LocationNode.id == location_organizations.c.location_id,
)
.where(
location_organizations.c.organization_id.in_(organization_ids),
),
)
found_locations = result.scalars().all()
root_location_ids = [loc.id for loc in found_locations]

found_org_ids = set()
for loc in found_locations:
org_result = await session.execute(
select(location_organizations.c.organization_id).where(
location_organizations.c.location_id == loc.id,
location_organizations.c.organization_id.in_(
organization_ids,
),
),
)
found_org_ids.update(row[0] for row in org_result.all())

for org_id in organization_ids:
if org_id not in found_org_ids:
new_location = LocationNode(
title=f"Organization {org_id[:8]}",
kind="CLINIC",
parent_id=None,
)
session.add(new_location)
await session.flush()
await session.refresh(new_location)

await session.execute(
location_organizations.insert().values(
location_id=new_location.id,
organization_id=org_id,
),
)
root_location_ids.append(new_location.id)

if not root_location_ids:
personal_org_title = f"{user.username}'s Organization"
result = await session.execute(
select(LocationNode).where(
LocationNode.title == personal_org_title,
LocationNode.parent_id.is_(None),
),
)
personal_location = result.scalars().first()

if not personal_location:
personal_location = LocationNode(
title=personal_org_title,
kind="CLINIC",
parent_id=None,
)
session.add(personal_location)
await session.flush()

root_location_ids = [personal_location.id]

await session.execute(
delete(user_root_locations).where(
user_root_locations.c.user_id == user.id,
),
)

if root_location_ids:
stmt = insert(user_root_locations).values(
[
{"user_id": user.id, "location_id": loc_id}
for loc_id in root_location_ids
],
)
stmt = stmt.on_conflict_do_nothing(
index_elements=["user_id", "location_id"],
)
await session.execute(stmt)

await session.commit()
6 changes: 3 additions & 3 deletions backend/api/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ class CreatePatientInput:
sex: Sex
assigned_location_id: strawberry.ID | None = None
assigned_location_ids: list[strawberry.ID] | None = None
clinic_id: strawberry.ID # Required: location node from kind CLINIC
clinic_id: strawberry.ID
position_id: strawberry.ID | None = (
None # Optional: location node from type hospital, practice, clinic, ward, bed or room
None
)
team_ids: list[strawberry.ID] | None = (
None # Array: location nodes from type clinic, team, practice, hospital
None
)
properties: list[PropertyValueInput] | None = None
state: PatientState | None = None
Expand Down
5 changes: 3 additions & 2 deletions backend/api/resolvers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import strawberry

from .location import LocationQuery
from .location import LocationMutation, LocationQuery, LocationSubscription
from .patient import PatientMutation, PatientQuery, PatientSubscription
from .property import PropertyDefinitionMutation, PropertyDefinitionQuery
from .task import TaskMutation, TaskQuery, TaskSubscription
Expand All @@ -23,10 +23,11 @@ class Mutation(
PatientMutation,
TaskMutation,
PropertyDefinitionMutation,
LocationMutation,
):
pass


@strawberry.type
class Subscription(PatientSubscription, TaskSubscription):
class Subscription(PatientSubscription, TaskSubscription, LocationSubscription):
pass
Loading
Loading