diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 68c8266..30d6473 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,7 +27,7 @@ jobs: working-directory: backend run: | pip install ruff - ruff check . --output-format=concise + ruff check . --output-format=concise --exclude database/migrations simulator-lint: runs-on: ubuntu-latest diff --git a/backend/api/context.py b/backend/api/context.py index b736cc4..a3772ff 100644 --- a/backend/api/context.py +++ b/backend/api/context.py @@ -1,22 +1,64 @@ +import asyncio +import logging from typing import Any import strawberry from auth import get_user_payload -from database.models.user import User +from database.models.location import LocationNode, location_organizations +from database.models.user import User, user_root_locations from database.session import get_db_session from fastapi import Depends -from sqlalchemy import select +from graphql import GraphQLError +from sqlalchemy import delete, select +from sqlalchemy.dialects.postgresql import insert from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from starlette.requests import HTTPConnection from strawberry.fastapi import BaseContext +class LockedAsyncSession: + def __init__(self, session: AsyncSession, lock: asyncio.Lock): + self._session = session + self._lock = lock + + async def execute(self, *args, **kwargs): + async with self._lock: + return await self._session.execute(*args, **kwargs) + + async def commit(self, *args, **kwargs): + async with self._lock: + return await self._session.commit(*args, **kwargs) + + async def rollback(self, *args, **kwargs): + async with self._lock: + return await self._session.rollback(*args, **kwargs) + + async def flush(self, *args, **kwargs): + async with self._lock: + return await self._session.flush(*args, **kwargs) + + async def refresh(self, *args, **kwargs): + async with self._lock: + return await self._session.refresh(*args, **kwargs) + + def add(self, *args, **kwargs): + return self._session.add(*args, **kwargs) + + def __getattr__(self, name): + return getattr(self._session, name) + + class Context(BaseContext): - def __init__(self, db: AsyncSession, user: "User | None" = None): + def __init__(self, db: AsyncSession, user: "User | None" = None, organizations: str | None = None): super().__init__() - self.db = db + self._db = db self.user = user + self.organizations = organizations + self._accessible_location_ids: set[str] | None = None + self._accessible_location_ids_lock = asyncio.Lock() + self._db_lock = asyncio.Lock() + self.db = LockedAsyncSession(db, self._db_lock) Info = strawberry.Info[Context, Any] @@ -28,6 +70,7 @@ async def get_context( ) -> Context: user_payload = get_user_payload(connection) db_user = None + organizations = None if user_payload: user_id = user_payload.get("sub") @@ -39,13 +82,27 @@ async def get_context( email = user_payload.get("email") picture = user_payload.get("picture") + logger = logging.getLogger(__name__) organizations_raw = user_payload.get("organization") + + if organizations_raw is None: + scope = user_payload.get("scope", "") + has_org_scope = "organization" in scope.split() if scope else False + logger.warning( + f"Organization claim not found in token for user {user_payload.get('sub', 'unknown')}. " + f"Has organization scope: {has_org_scope}. " + f"Token scope: {scope}. " + f"Available claims: {sorted(user_payload.keys())}" + ) + organizations = None if organizations_raw: if isinstance(organizations_raw, list): - organizations = ",".join(str(org) for org in organizations_raw) + org_list = [str(org) for org in organizations_raw if org] + if org_list: + organizations = ",".join(org_list) else: - organizations = str(organizations_raw) + organizations = str(organizations_raw) if organizations_raw else None if user_id: result = await session.execute( @@ -63,7 +120,6 @@ async def get_context( lastname=lastname, title="User", avatar_url=picture, - organizations=organizations, ) session.add(new_user) await session.commit() @@ -75,6 +131,12 @@ async def get_context( select(User).where(User.id == user_id), ) db_user = result.scalars().first() + except Exception as e: + await session.rollback() + raise GraphQLError( + "Failed to create user. Please contact an administrator if you believe this is an error.", + extensions={"code": "INTERNAL_SERVER_ERROR"}, + ) from e if db_user and ( db_user.username != username @@ -82,7 +144,6 @@ async def get_context( or db_user.lastname != lastname or db_user.email != email or db_user.avatar_url != picture - or db_user.organizations != organizations ): db_user.username = username db_user.firstname = firstname @@ -90,9 +151,124 @@ async def get_context( db_user.email = email if picture: db_user.avatar_url = picture - db_user.organizations = organizations session.add(db_user) await session.commit() await session.refresh(db_user) - return Context(db=session, user=db_user) + if db_user: + try: + + await _update_user_root_locations( + session, + db_user, + organizations, + ) + except Exception as e: + raise GraphQLError( + "Failed to update user root locations. Please contact an administrator if you believe this is an error.", + extensions={"code": "INTERNAL_SERVER_ERROR"}, + ) from e + + return Context(db=session, user=db_user, organizations=organizations) + + +async def _update_user_root_locations( + session: AsyncSession, + user: User, + organizations: str | None, +) -> None: + organization_ids: list[str] = [] + if organizations: + organization_ids = [ + org_id.strip() + for org_id in organizations.split(",") + if org_id.strip() + ] + + root_location_ids: list[str] = [] + + if organization_ids: + result = await session.execute( + select(LocationNode) + .join( + location_organizations, + LocationNode.id == location_organizations.c.location_id, + ) + .where( + location_organizations.c.organization_id.in_(organization_ids), + ), + ) + found_locations = result.scalars().all() + root_location_ids = [loc.id for loc in found_locations] + + found_org_ids = set() + for loc in found_locations: + org_result = await session.execute( + select(location_organizations.c.organization_id).where( + location_organizations.c.location_id == loc.id, + location_organizations.c.organization_id.in_( + organization_ids, + ), + ), + ) + found_org_ids.update(row[0] for row in org_result.all()) + + for org_id in organization_ids: + if org_id not in found_org_ids: + new_location = LocationNode( + title=f"Organization {org_id[:8]}", + kind="CLINIC", + parent_id=None, + ) + session.add(new_location) + await session.flush() + await session.refresh(new_location) + + await session.execute( + location_organizations.insert().values( + location_id=new_location.id, + organization_id=org_id, + ), + ) + root_location_ids.append(new_location.id) + + if not root_location_ids: + personal_org_title = f"{user.username}'s Organization" + result = await session.execute( + select(LocationNode).where( + LocationNode.title == personal_org_title, + LocationNode.parent_id.is_(None), + ), + ) + personal_location = result.scalars().first() + + if not personal_location: + personal_location = LocationNode( + title=personal_org_title, + kind="CLINIC", + parent_id=None, + ) + session.add(personal_location) + await session.flush() + + root_location_ids = [personal_location.id] + + await session.execute( + delete(user_root_locations).where( + user_root_locations.c.user_id == user.id, + ), + ) + + if root_location_ids: + stmt = insert(user_root_locations).values( + [ + {"user_id": user.id, "location_id": loc_id} + for loc_id in root_location_ids + ], + ) + stmt = stmt.on_conflict_do_nothing( + index_elements=["user_id", "location_id"], + ) + await session.execute(stmt) + + await session.commit() diff --git a/backend/api/inputs.py b/backend/api/inputs.py index 03c7827..4936f82 100644 --- a/backend/api/inputs.py +++ b/backend/api/inputs.py @@ -69,12 +69,12 @@ class CreatePatientInput: sex: Sex assigned_location_id: strawberry.ID | None = None assigned_location_ids: list[strawberry.ID] | None = None - clinic_id: strawberry.ID # Required: location node from kind CLINIC + clinic_id: strawberry.ID position_id: strawberry.ID | None = ( - None # Optional: location node from type hospital, practice, clinic, ward, bed or room + None ) team_ids: list[strawberry.ID] | None = ( - None # Array: location nodes from type clinic, team, practice, hospital + None ) properties: list[PropertyValueInput] | None = None state: PatientState | None = None diff --git a/backend/api/resolvers/__init__.py b/backend/api/resolvers/__init__.py index a8312ab..228d89e 100644 --- a/backend/api/resolvers/__init__.py +++ b/backend/api/resolvers/__init__.py @@ -1,6 +1,6 @@ import strawberry -from .location import LocationQuery +from .location import LocationMutation, LocationQuery, LocationSubscription from .patient import PatientMutation, PatientQuery, PatientSubscription from .property import PropertyDefinitionMutation, PropertyDefinitionQuery from .task import TaskMutation, TaskQuery, TaskSubscription @@ -23,10 +23,11 @@ class Mutation( PatientMutation, TaskMutation, PropertyDefinitionMutation, + LocationMutation, ): pass @strawberry.type -class Subscription(PatientSubscription, TaskSubscription): +class Subscription(PatientSubscription, TaskSubscription, LocationSubscription): pass diff --git a/backend/api/resolvers/location.py b/backend/api/resolvers/location.py index 819b84a..54673d7 100644 --- a/backend/api/resolvers/location.py +++ b/backend/api/resolvers/location.py @@ -1,8 +1,14 @@ +from collections.abc import AsyncGenerator + import strawberry +from api.audit import audit_log from api.context import Info -from api.inputs import LocationType +from api.inputs import CreateLocationNodeInput, LocationType, UpdateLocationNodeInput +from api.resolvers.base import BaseMutationResolver, BaseSubscriptionResolver +from api.services.authorization import AuthorizationService from api.types.location import LocationNodeType from database import models +from graphql import GraphQLError from sqlalchemy import select @@ -10,10 +16,21 @@ class LocationQuery: @strawberry.field async def location_roots(self, info: Info) -> list[LocationNodeType]: + auth_service = AuthorizationService(info.context.db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + + if not accessible_location_ids: + return [] + result = await info.context.db.execute( - select(models.LocationNode).where( + select(models.LocationNode) + .where( models.LocationNode.parent_id.is_(None), - ), + models.LocationNode.id.in_(accessible_location_ids), + ) + .distinct() ) return result.scalars().all() @@ -26,7 +43,20 @@ async def location_node( result = await info.context.db.execute( select(models.LocationNode).where(models.LocationNode.id == id), ) - return result.scalars().first() + location = result.scalars().first() + + if location: + auth_service = AuthorizationService(info.context.db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + if location.id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + + return location @strawberry.field async def location_nodes( @@ -40,7 +70,21 @@ async def location_nodes( ) -> list[LocationNodeType]: db = info.context.db + auth_service = AuthorizationService(db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + + if not accessible_location_ids: + return [] + if recursive and parent_id: + if parent_id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + cte = ( select(models.LocationNode) .where(models.LocationNode.id == parent_id) @@ -52,10 +96,17 @@ async def location_nodes( models.LocationNode.parent_id == cte.c.id, ) cte = cte.union_all(parent) - query = select(cte) + query = select(cte).where(cte.c.id.in_(accessible_location_ids)) else: - query = select(models.LocationNode) + query = select(models.LocationNode).where( + models.LocationNode.id.in_(accessible_location_ids) + ) if parent_id: + if parent_id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) query = query.where(models.LocationNode.parent_id == parent_id) if kind: @@ -69,3 +120,160 @@ async def location_nodes( result = await db.execute(query) return result.scalars().all() + + +@strawberry.type +class LocationMutation(BaseMutationResolver[models.LocationNode]): + @strawberry.mutation + @audit_log("create_location_node") + async def create_location_node( + self, + info: Info, + data: CreateLocationNodeInput, + ) -> LocationNodeType: + db = info.context.db + auth_service = AuthorizationService(db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + + if not accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + + if data.parent_id and data.parent_id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + + location = models.LocationNode( + title=data.title, + kind=data.kind.value, + parent_id=data.parent_id, + ) + + location = await BaseMutationResolver.create_and_notify( + info, location, models.LocationNode, "location_node" + ) + return location + + @strawberry.mutation + @audit_log("update_location_node") + async def update_location_node( + self, + info: Info, + id: strawberry.ID, + data: UpdateLocationNodeInput, + ) -> LocationNodeType: + db = info.context.db + auth_service = AuthorizationService(db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + + if not accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + + result = await db.execute( + select(models.LocationNode).where(models.LocationNode.id == id) + ) + location = result.scalars().first() + + if not location: + raise GraphQLError( + "Location not found.", + extensions={"code": "NOT_FOUND"}, + ) + + if location.id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + + if data.parent_id is not None and data.parent_id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + + if data.title is not None: + location.title = data.title + if data.kind is not None: + location.kind = data.kind.value + if data.parent_id is not None: + location.parent_id = data.parent_id + + location = await BaseMutationResolver.update_and_notify( + info, location, models.LocationNode, "location_node" + ) + return location + + @strawberry.mutation + @audit_log("delete_location_node") + async def delete_location_node(self, info: Info, id: strawberry.ID) -> bool: + db = info.context.db + auth_service = AuthorizationService(db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + + if not accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + + result = await db.execute( + select(models.LocationNode).where(models.LocationNode.id == id) + ) + location = result.scalars().first() + + if not location: + raise GraphQLError( + "Location not found.", + extensions={"code": "NOT_FOUND"}, + ) + + if location.id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + + await BaseMutationResolver.delete_entity( + info, location, models.LocationNode, "location_node" + ) + return True + + +@strawberry.type +class LocationSubscription(BaseSubscriptionResolver): + @strawberry.subscription + async def location_node_created( + self, info: Info + ) -> AsyncGenerator[strawberry.ID, None]: + async for location_id in BaseSubscriptionResolver.entity_created(info, "location_node"): + yield location_id + + @strawberry.subscription + async def location_node_updated( + self, + info: Info, + location_id: strawberry.ID | None = None, + ) -> AsyncGenerator[strawberry.ID, None]: + async for updated_id in BaseSubscriptionResolver.entity_updated(info, "location_node", location_id): + yield updated_id + + @strawberry.subscription + async def location_node_deleted( + self, info: Info + ) -> AsyncGenerator[strawberry.ID, None]: + async for location_id in BaseSubscriptionResolver.entity_deleted(info, "location_node"): + yield location_id diff --git a/backend/api/resolvers/patient.py b/backend/api/resolvers/patient.py index bb07f57..cb3357f 100644 --- a/backend/api/resolvers/patient.py +++ b/backend/api/resolvers/patient.py @@ -5,11 +5,13 @@ from api.context import Info from api.inputs import CreatePatientInput, PatientState, UpdatePatientInput from api.resolvers.base import BaseMutationResolver, BaseSubscriptionResolver +from api.services.authorization import AuthorizationService from api.services.checksum import validate_checksum from api.services.location import LocationService from api.services.property import PropertyService from api.types.patient import PatientType from database import models +from graphql import GraphQLError from sqlalchemy import select from sqlalchemy.orm import aliased, selectinload @@ -31,13 +33,22 @@ async def patient( selectinload(models.Patient.teams), ), ) - return result.scalars().first() + patient = result.scalars().first() + if patient: + auth_service = AuthorizationService(info.context.db) + if not await auth_service.can_access_patient(info.context.user, patient, info.context): + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + return patient @strawberry.field async def patients( self, info: Info, location_node_id: strawberry.ID | None = None, + root_location_ids: list[strawberry.ID] | None = None, states: list[PatientState] | None = None, ) -> list[PatientType]: query = select(models.Patient).options( @@ -53,37 +64,73 @@ async def patients( query = query.where( models.Patient.state == PatientState.ADMITTED.value ) - if location_node_id: - cte = ( + auth_service = AuthorizationService(info.context.db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + + # If user has no accessible locations, return empty list + if not accessible_location_ids: + return [] + + query = auth_service.filter_patients_by_access( + info.context.user, query, accessible_location_ids + ) + + filter_cte = None + if root_location_ids: + # Filter to only include root_location_ids that the user has access to + valid_root_location_ids = [lid for lid in root_location_ids if lid in accessible_location_ids] + if not valid_root_location_ids: + # If none of the requested root_location_ids are accessible, return empty list + return [] + root_location_ids = valid_root_location_ids + filter_cte = ( + select(models.LocationNode.id) + .where(models.LocationNode.id.in_(root_location_ids)) + .cte(name="root_location_descendants", recursive=True) + ) + root_children = select(models.LocationNode.id).join( + filter_cte, models.LocationNode.parent_id == filter_cte.c.id + ) + filter_cte = filter_cte.union_all(root_children) + elif location_node_id: + filter_cte = ( select(models.LocationNode.id) .where(models.LocationNode.id == location_node_id) .cte(name="location_descendants", recursive=True) ) - parent = select(models.LocationNode.id).join( - cte, - models.LocationNode.parent_id == cte.c.id, + filter_cte, + models.LocationNode.parent_id == filter_cte.c.id, ) - cte = cte.union_all(parent) + filter_cte = filter_cte.union_all(parent) - patient_locations = aliased(models.patient_locations) - patient_teams = aliased(models.patient_teams) + if filter_cte is not None: + patient_locations_filter = aliased(models.patient_locations) + patient_teams_filter = aliased(models.patient_teams) query = ( query.outerjoin( - patient_locations, - models.Patient.id == patient_locations.c.patient_id, + patient_locations_filter, + models.Patient.id == patient_locations_filter.c.patient_id, ) .outerjoin( - patient_teams, - models.Patient.id == patient_teams.c.patient_id, + patient_teams_filter, + models.Patient.id == patient_teams_filter.c.patient_id, ) .where( - (models.Patient.assigned_location_id.in_(select(cte.c.id))) - | (patient_locations.c.location_id.in_(select(cte.c.id))) - | (models.Patient.clinic_id.in_(select(cte.c.id))) - | (models.Patient.position_id.in_(select(cte.c.id))) - | (patient_teams.c.location_id.in_(select(cte.c.id))), + (models.Patient.clinic_id.in_(select(filter_cte.c.id))) + | ( + models.Patient.position_id.isnot(None) + & models.Patient.position_id.in_(select(filter_cte.c.id)) + ) + | ( + models.Patient.assigned_location_id.isnot(None) + & models.Patient.assigned_location_id.in_(select(filter_cte.c.id)) + ) + | (patient_locations_filter.c.location_id.in_(select(filter_cte.c.id))) + | (patient_teams_filter.c.location_id.in_(select(filter_cte.c.id))) ) .distinct() ) @@ -106,6 +153,13 @@ async def recent_patients( ) .limit(limit) ) + auth_service = AuthorizationService(info.context.db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + query = auth_service.filter_patients_by_access( + info.context.user, query, accessible_location_ids + ) result = await info.context.db.execute(query) return result.scalars().all() @@ -133,13 +187,41 @@ async def create_patient( data.state.value if data.state else PatientState.WAIT.value ) + auth_service = AuthorizationService(db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + + if not accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + + if data.clinic_id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + await location_service.validate_and_get_clinic(data.clinic_id) if data.position_id: + if data.position_id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) await location_service.validate_and_get_position(data.position_id) teams = [] if data.team_ids: + for team_id in data.team_ids: + if team_id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) teams = await location_service.validate_and_get_teams( data.team_ids ) @@ -159,11 +241,22 @@ async def create_patient( new_patient.teams = teams if data.assigned_location_ids: + for loc_id in data.assigned_location_ids: + if loc_id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) locations = await location_service.get_locations_by_ids( data.assigned_location_ids ) new_patient.assigned_locations = locations elif data.assigned_location_id: + if data.assigned_location_id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) location = await location_service.get_location_by_id( data.assigned_location_id ) @@ -204,6 +297,13 @@ async def update_patient( if not patient: raise Exception("Patient not found") + auth_service = AuthorizationService(db) + if not await auth_service.can_access_patient(info.context.user, patient, info.context): + raise GraphQLError( + "Forbidden: You do not have access to this patient", + extensions={"code": "FORBIDDEN"}, + ) + if data.checksum: validate_checksum(patient, data.checksum, "Patient") @@ -217,8 +317,16 @@ async def update_patient( patient.sex = data.sex.value location_service = PatientMutation._get_location_service(db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user + ) if data.clinic_id is not None: + if data.clinic_id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) await location_service.validate_and_get_clinic(data.clinic_id) patient.clinic_id = data.clinic_id @@ -226,6 +334,11 @@ async def update_patient( if data.position_id is None: patient.position_id = None else: + if data.position_id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) await location_service.validate_and_get_position( data.position_id ) @@ -235,16 +348,33 @@ async def update_patient( if data.team_ids is None or len(data.team_ids) == 0: patient.teams = [] else: + for team_id in data.team_ids: + if team_id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) patient.teams = await location_service.validate_and_get_teams( data.team_ids ) if data.assigned_location_ids is not None: + for loc_id in data.assigned_location_ids: + if loc_id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) locations = await location_service.get_locations_by_ids( data.assigned_location_ids ) patient.assigned_locations = locations elif data.assigned_location_id is not None: + if data.assigned_location_id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) location = await location_service.get_location_by_id( data.assigned_location_id ) @@ -269,6 +399,14 @@ async def delete_patient(self, info: Info, id: strawberry.ID) -> bool: patient = await repo.get_by_id(id) if not patient: return False + + auth_service = AuthorizationService(info.context.db) + if not await auth_service.can_access_patient(info.context.user, patient, info.context): + raise GraphQLError( + "Forbidden: You do not have access to this patient", + extensions={"code": "FORBIDDEN"}, + ) + await BaseMutationResolver.delete_entity( info, patient, models.Patient, "patient" ) @@ -293,6 +431,14 @@ async def _update_patient_state( patient = result.scalars().first() if not patient: raise Exception("Patient not found") + + auth_service = AuthorizationService(db) + if not await auth_service.can_access_patient(info.context.user, patient, info.context): + raise GraphQLError( + "Forbidden: You do not have access to this patient", + extensions={"code": "FORBIDDEN"}, + ) + patient.state = state.value await BaseMutationResolver.update_and_notify( info, patient, models.Patient, "patient" diff --git a/backend/api/resolvers/task.py b/backend/api/resolvers/task.py index 83dd089..fef46d2 100644 --- a/backend/api/resolvers/task.py +++ b/backend/api/resolvers/task.py @@ -5,21 +5,35 @@ from api.context import Info from api.inputs import CreateTaskInput, UpdateTaskInput from api.resolvers.base import BaseMutationResolver, BaseSubscriptionResolver -from api.services.base import BaseRepository +from api.services.authorization import AuthorizationService from api.services.checksum import validate_checksum from api.services.datetime import normalize_datetime_to_utc from api.services.property import PropertyService from api.types.task import TaskType from database import models +from graphql import GraphQLError from sqlalchemy import desc, select +from sqlalchemy.orm import aliased, selectinload @strawberry.type class TaskQuery: @strawberry.field async def task(self, info: Info, id: strawberry.ID) -> TaskType | None: - repo = BaseRepository(info.context.db, models.Task) - return await repo.get_by_id(id) + result = await info.context.db.execute( + select(models.Task) + .where(models.Task.id == id) + .options(selectinload(models.Task.patient).selectinload(models.Patient.assigned_locations)) + ) + task = result.scalars().first() + if task and task.patient: + auth_service = AuthorizationService(info.context.db) + if not await auth_service.can_access_patient(info.context.user, task.patient, info.context): + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + return task @strawberry.field async def tasks( @@ -27,10 +41,97 @@ async def tasks( info: Info, patient_id: strawberry.ID | None = None, assignee_id: strawberry.ID | None = None, + root_location_ids: list[strawberry.ID] | None = None, ) -> list[TaskType]: - query = select(models.Task) + auth_service = AuthorizationService(info.context.db) + if patient_id: - query = query.where(models.Task.patient_id == patient_id) + if not await auth_service.can_access_patient_id(info.context.user, patient_id, info.context): + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + + query = select(models.Task).options( + selectinload(models.Task.patient).selectinload(models.Patient.assigned_locations) + ).where(models.Task.patient_id == patient_id) + + if assignee_id: + query = query.where(models.Task.assignee_id == assignee_id) + + result = await info.context.db.execute(query) + return result.scalars().all() + + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + + if not accessible_location_ids: + return [] + + patient_locations = aliased(models.patient_locations) + patient_teams = aliased(models.patient_teams) + + cte = ( + select(models.LocationNode.id) + .where(models.LocationNode.id.in_(accessible_location_ids)) + .cte(name="accessible_locations", recursive=True) + ) + + children = select(models.LocationNode.id).join( + cte, models.LocationNode.parent_id == cte.c.id + ) + cte = cte.union_all(children) + + if root_location_ids: + invalid_ids = [lid for lid in root_location_ids if lid not in accessible_location_ids] + if invalid_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + root_cte = ( + select(models.LocationNode.id) + .where(models.LocationNode.id.in_(root_location_ids)) + .cte(name="root_location_descendants", recursive=True) + ) + root_children = select(models.LocationNode.id).join( + root_cte, models.LocationNode.parent_id == root_cte.c.id + ) + root_cte = root_cte.union_all(root_children) + else: + root_cte = cte + + query = ( + select(models.Task) + .options( + selectinload(models.Task.patient).selectinload(models.Patient.assigned_locations) + ) + .join(models.Patient, models.Task.patient_id == models.Patient.id) + .outerjoin( + patient_locations, + models.Patient.id == patient_locations.c.patient_id, + ) + .outerjoin( + patient_teams, + models.Patient.id == patient_teams.c.patient_id, + ) + .where( + (models.Patient.clinic_id.in_(select(root_cte.c.id))) + | ( + models.Patient.position_id.isnot(None) + & models.Patient.position_id.in_(select(root_cte.c.id)) + ) + | ( + models.Patient.assigned_location_id.isnot(None) + & models.Patient.assigned_location_id.in_(select(root_cte.c.id)) + ) + | (patient_locations.c.location_id.in_(select(root_cte.c.id))) + | (patient_teams.c.location_id.in_(select(root_cte.c.id))) + ) + .distinct() + ) + if assignee_id: query = query.where(models.Task.assignee_id == assignee_id) @@ -43,11 +144,61 @@ async def recent_tasks( info: Info, limit: int = 10, ) -> list[TaskType]: - result = await info.context.db.execute( + auth_service = AuthorizationService(info.context.db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + + if not accessible_location_ids: + return [] + + patient_locations = aliased(models.patient_locations) + patient_teams = aliased(models.patient_teams) + + cte = ( + select(models.LocationNode.id) + .where(models.LocationNode.id.in_(accessible_location_ids)) + .cte(name="accessible_locations", recursive=True) + ) + + children = select(models.LocationNode.id).join( + cte, models.LocationNode.parent_id == cte.c.id + ) + cte = cte.union_all(children) + + query = ( select(models.Task) + .options( + selectinload(models.Task.patient).selectinload(models.Patient.assigned_locations) + ) + .join(models.Patient, models.Task.patient_id == models.Patient.id) + .outerjoin( + patient_locations, + models.Patient.id == patient_locations.c.patient_id, + ) + .outerjoin( + patient_teams, + models.Patient.id == patient_teams.c.patient_id, + ) + .where( + (models.Patient.clinic_id.in_(select(cte.c.id))) + | ( + models.Patient.position_id.isnot(None) + & models.Patient.position_id.in_(select(cte.c.id)) + ) + | ( + models.Patient.assigned_location_id.isnot(None) + & models.Patient.assigned_location_id.in_(select(cte.c.id)) + ) + | (patient_locations.c.location_id.in_(select(cte.c.id))) + | (patient_teams.c.location_id.in_(select(cte.c.id))) + ) .order_by(desc(models.Task.update_date)) - .limit(limit), + .limit(limit) + .distinct() ) + + result = await info.context.db.execute(query) return result.scalars().all() @@ -60,6 +211,13 @@ def _get_property_service(db) -> PropertyService: @strawberry.mutation @audit_log("create_task") async def create_task(self, info: Info, data: CreateTaskInput) -> TaskType: + auth_service = AuthorizationService(info.context.db) + if not await auth_service.can_access_patient_id(info.context.user, data.patient_id, info.context): + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + new_task = models.Task( title=data.title, description=data.description, @@ -92,8 +250,22 @@ async def update_task( data: UpdateTaskInput, ) -> TaskType: db = info.context.db - repo = BaseMutationResolver.get_repository(db, models.Task) - task = await repo.get_by_id_or_raise(id, "Task not found") + result = await db.execute( + select(models.Task) + .where(models.Task.id == id) + .options(selectinload(models.Task.patient).selectinload(models.Patient.assigned_locations)) + ) + task = result.scalars().first() + if not task: + raise Exception("Task not found") + + if task.patient: + auth_service = AuthorizationService(db) + if not await auth_service.can_access_patient(info.context.user, task.patient, info.context): + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) if data.checksum: validate_checksum(task, data.checksum, "Task") @@ -134,8 +306,23 @@ async def _update_task_field( field_updater, ) -> TaskType: db = info.context.db - repo = BaseMutationResolver.get_repository(db, models.Task) - task = await repo.get_by_id_or_raise(id, "Task not found") + result = await db.execute( + select(models.Task) + .where(models.Task.id == id) + .options(selectinload(models.Task.patient).selectinload(models.Patient.assigned_locations)) + ) + task = result.scalars().first() + if not task: + raise Exception("Task not found") + + if task.patient: + auth_service = AuthorizationService(db) + if not await auth_service.can_access_patient(info.context.user, task.patient, info.context): + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + field_updater(task) await BaseMutationResolver.update_and_notify( info, task, models.Task, "task", "patient", task.patient_id @@ -187,11 +374,23 @@ async def reopen_task(self, info: Info, id: strawberry.ID) -> TaskType: @audit_log("delete_task") async def delete_task(self, info: Info, id: strawberry.ID) -> bool: db = info.context.db - repo = BaseMutationResolver.get_repository(db, models.Task) - task = await repo.get_by_id(id) + result = await db.execute( + select(models.Task) + .where(models.Task.id == id) + .options(selectinload(models.Task.patient).selectinload(models.Patient.assigned_locations)) + ) + task = result.scalars().first() if not task: return False + if task.patient: + auth_service = AuthorizationService(db) + if not await auth_service.can_access_patient(info.context.user, task.patient, info.context): + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + patient_id = task.patient_id await BaseMutationResolver.delete_entity( info, task, models.Task, "task", "patient", patient_id diff --git a/backend/api/services/authorization.py b/backend/api/services/authorization.py new file mode 100644 index 0000000..8844967 --- /dev/null +++ b/backend/api/services/authorization.py @@ -0,0 +1,174 @@ +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import aliased, selectinload + +from database import models + + +class AuthorizationService: + def __init__(self, db: AsyncSession): + self.db = db + + async def get_user_accessible_location_ids( + self, user: models.User | None, context=None + ) -> set[str]: + if context and hasattr(context, '_accessible_location_ids') and context._accessible_location_ids is not None: + return context._accessible_location_ids + + if not context or not hasattr(context, '_accessible_location_ids_lock'): + return await self._compute_accessible_location_ids(user, context) + + async with context._accessible_location_ids_lock: + if context._accessible_location_ids is not None: + return context._accessible_location_ids + return await self._compute_accessible_location_ids(user, context) + + async def _compute_accessible_location_ids( + self, user: models.User | None, context=None + ) -> set[str]: + if not user: + result = set() + if context: + context._accessible_location_ids = result + return result + + result = await self.db.execute( + select(models.user_root_locations.c.location_id).where( + models.user_root_locations.c.user_id == user.id + ) + ) + rows = result.fetchall() + root_location_ids = {row[0] for row in rows} + + if not root_location_ids: + result = set() + if context: + context._accessible_location_ids = result + return result + + cte = ( + select(models.LocationNode.id) + .where(models.LocationNode.id.in_(root_location_ids)) + .cte(name="accessible_locations", recursive=True) + ) + + children = select(models.LocationNode.id).join( + cte, models.LocationNode.parent_id == cte.c.id + ) + cte = cte.union_all(children) + + result = await self.db.execute(select(cte.c.id)) + rows = result.fetchall() + accessible_ids = {row[0] for row in rows} + + if context: + context._accessible_location_ids = accessible_ids + + return accessible_ids + + async def can_access_patient( + self, user: models.User | None, patient: models.Patient, context=None + ) -> bool: + if not user: + return False + + accessible_location_ids = await self.get_user_accessible_location_ids(user, context) + + if not accessible_location_ids: + return False + + if patient.clinic_id in accessible_location_ids: + return True + + if patient.position_id and patient.position_id in accessible_location_ids: + return True + + if ( + patient.assigned_location_id + and patient.assigned_location_id in accessible_location_ids + ): + return True + + if patient.assigned_locations: + for location in patient.assigned_locations: + if location.id in accessible_location_ids: + return True + + if patient.teams: + for team in patient.teams: + if team.id in accessible_location_ids: + return True + + return False + + async def can_access_patient_id( + self, user: models.User | None, patient_id: str, context=None + ) -> bool: + if not user: + return False + + result = await self.db.execute( + select(models.Patient) + .where(models.Patient.id == patient_id) + .options( + selectinload(models.Patient.assigned_locations), + selectinload(models.Patient.teams), + ) + ) + patient = result.scalars().first() + + if not patient: + return False + + return await self.can_access_patient(user, patient, context) + + def filter_patients_by_access( + self, user: models.User | None, query, accessible_location_ids: set[str] | None = None + ): + if not user: + return query.where(False) + + if accessible_location_ids is None: + return query + + if not accessible_location_ids: + return query.where(False) + + cte = ( + select(models.LocationNode.id) + .where(models.LocationNode.id.in_(accessible_location_ids)) + .cte(name="accessible_locations", recursive=True) + ) + + children = select(models.LocationNode.id).join( + cte, models.LocationNode.parent_id == cte.c.id + ) + cte = cte.union_all(children) + + patient_locations = aliased(models.patient_locations) + patient_teams = aliased(models.patient_teams) + + return ( + query.outerjoin( + patient_locations, + models.Patient.id == patient_locations.c.patient_id, + ) + .outerjoin( + patient_teams, + models.Patient.id == patient_teams.c.patient_id, + ) + .where( + (models.Patient.clinic_id.in_(select(cte.c.id))) + | ( + models.Patient.position_id.isnot(None) + & models.Patient.position_id.in_(select(cte.c.id)) + ) + | ( + models.Patient.assigned_location_id.isnot(None) + & models.Patient.assigned_location_id.in_(select(cte.c.id)) + ) + | (patient_locations.c.location_id.in_(select(cte.c.id))) + | (patient_teams.c.location_id.in_(select(cte.c.id))) + ) + .distinct() + ) diff --git a/backend/api/types/location.py b/backend/api/types/location.py index 6704a2a..3fe2bd3 100644 --- a/backend/api/types/location.py +++ b/backend/api/types/location.py @@ -63,3 +63,12 @@ async def patients( ), ) return result.scalars().all() + + @strawberry.field + async def organization_ids(self, info: Info) -> list[str]: + result = await info.context.db.execute( + select(models.location_organizations.c.organization_id).where( + models.location_organizations.c.location_id == self.id, + ), + ) + return [row[0] for row in result.all()] diff --git a/backend/api/types/user.py b/backend/api/types/user.py index d0b3a18..85578ed 100644 --- a/backend/api/types/user.py +++ b/backend/api/types/user.py @@ -5,6 +5,7 @@ from sqlalchemy import select if TYPE_CHECKING: + from api.types.location import LocationNodeType from api.types.task import TaskType @@ -17,7 +18,6 @@ class UserType: lastname: str | None title: str | None avatar_url: str | None - organizations: str | None @strawberry.field def name(self) -> str: @@ -25,13 +25,116 @@ def name(self) -> str: return f"{self.firstname} {self.lastname}" return self.username + @strawberry.field + def organizations(self, info) -> str | None: + """Get organizations from the context""" + return info.context.organizations + @strawberry.field async def tasks( self, info, ) -> list[Annotated["TaskType", strawberry.lazy("api.types.task")]]: + from api.services.authorization import AuthorizationService - result = await info.context.db.execute( - select(models.Task).where(models.Task.assignee_id == self.id), + auth_service = AuthorizationService(info.context.db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + + if not accessible_location_ids: + return [] + + from sqlalchemy.orm import aliased + patient_locations = aliased(models.patient_locations) + patient_teams = aliased(models.patient_teams) + + from sqlalchemy import select + cte = ( + select(models.LocationNode.id) + .where(models.LocationNode.id.in_(accessible_location_ids)) + .cte(name="accessible_locations", recursive=True) + ) + + children = select(models.LocationNode.id).join( + cte, models.LocationNode.parent_id == cte.c.id ) + cte = cte.union_all(children) + + query = ( + select(models.Task) + .join(models.Patient, models.Task.patient_id == models.Patient.id) + .outerjoin( + patient_locations, + models.Patient.id == patient_locations.c.patient_id, + ) + .outerjoin( + patient_teams, + models.Patient.id == patient_teams.c.patient_id, + ) + .where( + models.Task.assignee_id == self.id, + ( + (models.Patient.clinic_id.in_(select(cte.c.id))) + | ( + models.Patient.position_id.isnot(None) + & models.Patient.position_id.in_(select(cte.c.id)) + ) + | ( + models.Patient.assigned_location_id.isnot(None) + & models.Patient.assigned_location_id.in_(select(cte.c.id)) + ) + | (patient_locations.c.location_id.in_(select(cte.c.id))) + | (patient_teams.c.location_id.in_(select(cte.c.id))) + ) + ) + .distinct() + ) + + result = await info.context.db.execute(query) return result.scalars().all() + + @strawberry.field + async def root_locations( + self, + info, + ) -> list[Annotated["LocationNodeType", strawberry.lazy("api.types.location")]]: + import logging + logger = logging.getLogger(__name__) + + # First check what's in user_root_locations table + user_root_check = await info.context.db.execute( + select(models.user_root_locations.c.location_id).where( + models.user_root_locations.c.user_id == self.id + ) + ) + user_root_location_ids = [row[0] for row in user_root_check.all()] + logger.info(f"User {self.id} has {len(user_root_location_ids)} entries in user_root_locations: {user_root_location_ids}") + + result = await info.context.db.execute( + select(models.LocationNode) + .join( + models.user_root_locations, + models.LocationNode.id == models.user_root_locations.c.location_id, + ) + .where(models.user_root_locations.c.user_id == self.id) + .distinct() + ) + locations = result.scalars().all() + logger.info(f"User {self.id} root_locations query returned {len(locations)} locations: {[loc.id for loc in locations]}") + + # If we have user_root_locations entries but no locations returned, check if locations exist + if user_root_location_ids and not locations: + location_check = await info.context.db.execute( + select(models.LocationNode).where( + models.LocationNode.id.in_(user_root_location_ids) + ) + ) + existing_locations = location_check.scalars().all() + logger.warning( + f"User {self.id} has {len(user_root_location_ids)} root location IDs but query returned empty. " + f"Checking if locations exist: {[loc.id for loc in existing_locations]} " + f"with parent_ids: {[loc.parent_id for loc in existing_locations]}" + ) + + return locations diff --git a/backend/database/migrations/versions/0de3078888ba_merge_location_type_enum_and_remove_.py b/backend/database/migrations/versions/0de3078888ba_merge_location_type_enum_and_remove_.py new file mode 100644 index 0000000..1b31f78 --- /dev/null +++ b/backend/database/migrations/versions/0de3078888ba_merge_location_type_enum_and_remove_.py @@ -0,0 +1,28 @@ +"""Merge location type enum and remove organizations + +Revision ID: 0de3078888ba +Revises: add_location_type_enum, remove_organizations_from_users +Create Date: 2025-12-23 22:12:52.604315 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '0de3078888ba' +down_revision: Union[str, Sequence[str], None] = ('add_location_type_enum', 'remove_organizations_from_users') +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + pass + + +def downgrade() -> None: + """Downgrade schema.""" + pass diff --git a/backend/database/migrations/versions/add_location_organizations_table.py b/backend/database/migrations/versions/add_location_organizations_table.py new file mode 100644 index 0000000..c6e6c56 --- /dev/null +++ b/backend/database/migrations/versions/add_location_organizations_table.py @@ -0,0 +1,35 @@ +"""Add location organizations table. + +Revision ID: add_location_organizations_table +Revises: add_user_root_locations +Create Date: 2025-01-17 00:00:00.000000 +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "add_location_organizations_table" +down_revision: Union[str, Sequence[str], None] = "add_user_root_locations" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.create_table( + "location_organizations", + sa.Column("location_id", sa.String(), nullable=False), + sa.Column("organization_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint(["location_id"], ["location_nodes.id"]), + sa.PrimaryKeyConstraint("location_id", "organization_id"), + ) + + +def downgrade() -> None: + """Downgrade schema.""" + op.drop_table("location_organizations") + diff --git a/backend/database/migrations/versions/add_location_type_enum.py b/backend/database/migrations/versions/add_location_type_enum.py new file mode 100644 index 0000000..cda3918 --- /dev/null +++ b/backend/database/migrations/versions/add_location_type_enum.py @@ -0,0 +1,51 @@ +"""Add location type enum. + +Revision ID: add_location_type_enum +Revises: add_location_organizations_table +Create Date: 2025-01-17 00:00:00.000000 +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "add_location_type_enum" +down_revision: Union[str, Sequence[str], None] = "add_location_organizations_table" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.alter_column( + 'location_nodes', + 'kind', + type_=sa.Enum( + 'HOSPITAL', 'PRACTICE', 'CLINIC', 'TEAM', 'WARD', 'ROOM', 'BED', 'OTHER', + name='locationtypeenum', + native_enum=False, + length=50 + ), + existing_type=sa.String(), + existing_nullable=False + ) + + +def downgrade() -> None: + """Downgrade schema.""" + op.alter_column( + 'location_nodes', + 'kind', + type_=sa.String(), + existing_type=sa.Enum( + 'HOSPITAL', 'PRACTICE', 'CLINIC', 'TEAM', 'WARD', 'ROOM', 'BED', 'OTHER', + name='locationtypeenum', + native_enum=False, + length=50 + ), + existing_nullable=False + ) + diff --git a/backend/database/migrations/versions/add_user_root_locations.py b/backend/database/migrations/versions/add_user_root_locations.py new file mode 100644 index 0000000..afbba08 --- /dev/null +++ b/backend/database/migrations/versions/add_user_root_locations.py @@ -0,0 +1,36 @@ +"""Add user root locations table. + +Revision ID: add_user_root_locations +Revises: add_patient_location_mapping +Create Date: 2025-01-16 00:00:00.000000 +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "add_user_root_locations" +down_revision: Union[str, Sequence[str], None] = "add_patient_location_mapping" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.create_table( + "user_root_locations", + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("location_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["users.id"]), + sa.ForeignKeyConstraint(["location_id"], ["location_nodes.id"]), + sa.PrimaryKeyConstraint("user_id", "location_id"), + ) + + +def downgrade() -> None: + """Downgrade schema.""" + op.drop_table("user_root_locations") + diff --git a/backend/database/migrations/versions/remove_organizations_from_users.py b/backend/database/migrations/versions/remove_organizations_from_users.py new file mode 100644 index 0000000..1d4f7bf --- /dev/null +++ b/backend/database/migrations/versions/remove_organizations_from_users.py @@ -0,0 +1,29 @@ +"""Remove organizations field from users table. + +Revision ID: remove_organizations_from_users +Revises: add_location_organizations_table +Create Date: 2025-01-16 12:00:00.000000 +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "remove_organizations_from_users" +down_revision: Union[str, Sequence[str], None] = "add_location_organizations_table" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.drop_column("users", "organizations") + + +def downgrade() -> None: + """Downgrade schema.""" + op.add_column("users", sa.Column("organizations", sa.String(), nullable=True)) + diff --git a/backend/database/models/__init__.py b/backend/database/models/__init__.py index faea722..5de948a 100644 --- a/backend/database/models/__init__.py +++ b/backend/database/models/__init__.py @@ -1,5 +1,5 @@ -from .user import User # noqa: F401 -from .location import LocationNode # noqa: F401 +from .user import User, user_root_locations # noqa: F401 +from .location import LocationNode, location_organizations # noqa: F401 from .patient import Patient, patient_locations, patient_teams # noqa: F401 from .task import Task, task_dependencies # noqa: F401 from .property import PropertyDefinition, PropertyValue # noqa: F401 diff --git a/backend/database/models/location.py b/backend/database/models/location.py index 61d0a41..1042953 100644 --- a/backend/database/models/location.py +++ b/backend/database/models/location.py @@ -1,14 +1,35 @@ from __future__ import annotations import uuid +from enum import Enum from typing import TYPE_CHECKING from database.models.base import Base -from sqlalchemy import ForeignKey, String +from sqlalchemy import Column, ForeignKey, String, Table, Enum as SQLEnum from sqlalchemy.orm import Mapped, mapped_column, relationship if TYPE_CHECKING: from .patient import Patient + from .user import User + + +class LocationTypeEnum(str, Enum): + HOSPITAL = "HOSPITAL" + PRACTICE = "PRACTICE" + CLINIC = "CLINIC" + TEAM = "TEAM" + WARD = "WARD" + ROOM = "ROOM" + BED = "BED" + OTHER = "OTHER" + + +location_organizations = Table( + "location_organizations", + Base.metadata, + Column("location_id", ForeignKey("location_nodes.id"), primary_key=True), + Column("organization_id", String, primary_key=True), +) class LocationNode(Base): @@ -20,7 +41,9 @@ class LocationNode(Base): default=lambda: str(uuid.uuid4()), ) title: Mapped[str] = mapped_column(String) - kind: Mapped[str] = mapped_column(String) + kind: Mapped[LocationTypeEnum] = mapped_column( + SQLEnum(LocationTypeEnum, native_enum=False, length=50) + ) parent_id: Mapped[str | None] = mapped_column( ForeignKey("location_nodes.id"), nullable=True, @@ -60,3 +83,8 @@ class LocationNode(Base): secondary="patient_teams", back_populates="teams", ) + root_users: Mapped[list[User]] = relationship( + "User", + secondary="user_root_locations", + back_populates="root_locations", + ) diff --git a/backend/database/models/user.py b/backend/database/models/user.py index 9b79212..67b9cac 100644 --- a/backend/database/models/user.py +++ b/backend/database/models/user.py @@ -4,12 +4,20 @@ from typing import TYPE_CHECKING from database.models.base import Base -from sqlalchemy import String +from sqlalchemy import Column, ForeignKey, String, Table from sqlalchemy.orm import Mapped, mapped_column, relationship if TYPE_CHECKING: + from .location import LocationNode from .task import Task +user_root_locations = Table( + "user_root_locations", + Base.metadata, + Column("user_id", ForeignKey("users.id"), primary_key=True), + Column("location_id", ForeignKey("location_nodes.id"), primary_key=True), +) + class User(Base): __tablename__ = "users" @@ -29,6 +37,10 @@ class User(Base): nullable=True, default="https://cdn.helpwave.de/boringavatar.svg", ) - organizations: Mapped[str | None] = mapped_column(String, nullable=True) tasks: Mapped[list[Task]] = relationship("Task", back_populates="assignee") + root_locations: Mapped[list[LocationNode]] = relationship( + "LocationNode", + secondary=user_root_locations, + back_populates="root_users", + ) diff --git a/backend/scaffold.py b/backend/scaffold.py index 2855a9c..0016bf8 100644 --- a/backend/scaffold.py +++ b/backend/scaffold.py @@ -5,7 +5,7 @@ from api.inputs import LocationType from config import LOGGER, SCAFFOLD_DIRECTORY -from database.models.location import LocationNode +from database.models.location import LocationNode, location_organizations from database.session import async_session from sqlalchemy import select @@ -118,6 +118,32 @@ async def _create_location_tree( location_id = location.id logger.debug(f"Created location '{name}' ({location_type.value})") + organization_ids = data.get("organization_ids", []) + if organization_ids: + allowed_types_for_orgs = {"HOSPITAL", "CLINIC", "PRACTICE", "TEAM"} + if location_type.value not in allowed_types_for_orgs: + logger.warning( + f"Organization IDs can only be assigned to HOSPITAL, CLINIC, PRACTICE, or TEAM. " + f"Skipping organization assignment for location '{name}' (type: {location_type.value})" + ) + else: + for org_id in organization_ids: + stmt = select(location_organizations).where( + location_organizations.c.location_id == location_id, + location_organizations.c.organization_id == org_id, + ) + result = await session.execute(stmt) + existing_org = result.first() + if not existing_org: + await session.execute( + location_organizations.insert().values( + location_id=location_id, organization_id=org_id + ) + ) + logger.debug( + f"Assigned organization '{org_id}' to location '{name}'" + ) + for child_data in children: await _create_location_tree(session, child_data, location_id) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 4877679..98fb584 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -99,3 +99,27 @@ async def sample_task( await db_session.commit() await db_session.refresh(task) return task + + +@pytest.fixture +async def sample_user_with_location_access( + db_session: AsyncSession, sample_user: User, sample_location: LocationNode +) -> User: + from database.models.user import user_root_locations + from sqlalchemy import select, insert + + result = await db_session.execute( + select(user_root_locations).where( + user_root_locations.c.user_id == sample_user.id, + user_root_locations.c.location_id == sample_location.id, + ) + ) + existing = result.first() + if not existing: + stmt = insert(user_root_locations).values( + user_id=sample_user.id, location_id=sample_location.id + ) + await db_session.execute(stmt) + await db_session.commit() + await db_session.refresh(sample_user) + return sample_user diff --git a/backend/tests/integration/test_patient_resolver.py b/backend/tests/integration/test_patient_resolver.py index be75266..15139e2 100644 --- a/backend/tests/integration/test_patient_resolver.py +++ b/backend/tests/integration/test_patient_resolver.py @@ -5,13 +5,13 @@ class MockInfo: - def __init__(self, db): - self.context = Context(db=db) + def __init__(self, db, user=None): + self.context = Context(db=db, user=user) @pytest.mark.asyncio -async def test_patient_query_get_patient(db_session, sample_patient): - info = MockInfo(db_session) +async def test_patient_query_get_patient(db_session, sample_patient, sample_user_with_location_access): + info = MockInfo(db_session, sample_user_with_location_access) query = PatientQuery() result = await query.patient(info, sample_patient.id) assert result is not None @@ -20,8 +20,8 @@ async def test_patient_query_get_patient(db_session, sample_patient): @pytest.mark.asyncio -async def test_patient_query_patients(db_session, sample_patient): - info = MockInfo(db_session) +async def test_patient_query_patients(db_session, sample_patient, sample_user_with_location_access): + info = MockInfo(db_session, sample_user_with_location_access) query = PatientQuery() results = await query.patients(info) assert len(results) >= 1 @@ -29,11 +29,11 @@ async def test_patient_query_patients(db_session, sample_patient): @pytest.mark.asyncio -async def test_patient_mutation_create_patient(db_session, sample_location): +async def test_patient_mutation_create_patient(db_session, sample_location, sample_user_with_location_access): from api.inputs import CreatePatientInput from datetime import date - info = MockInfo(db_session) + info = MockInfo(db_session, sample_user_with_location_access) mutation = PatientMutation() input_data = CreatePatientInput( firstname="Jane", @@ -50,10 +50,10 @@ async def test_patient_mutation_create_patient(db_session, sample_location): @pytest.mark.asyncio -async def test_patient_mutation_update_patient(db_session, sample_patient): +async def test_patient_mutation_update_patient(db_session, sample_patient, sample_user_with_location_access): from api.inputs import UpdatePatientInput - info = MockInfo(db_session) + info = MockInfo(db_session, sample_user_with_location_access) mutation = PatientMutation() input_data = UpdatePatientInput(firstname="Updated Name") result = await mutation.update_patient(info, sample_patient.id, input_data) @@ -62,16 +62,16 @@ async def test_patient_mutation_update_patient(db_session, sample_patient): @pytest.mark.asyncio -async def test_patient_mutation_admit_patient(db_session, sample_patient): - info = MockInfo(db_session) +async def test_patient_mutation_admit_patient(db_session, sample_patient, sample_user_with_location_access): + info = MockInfo(db_session, sample_user_with_location_access) mutation = PatientMutation() result = await mutation.admit_patient(info, sample_patient.id) assert result.state == PatientState.ADMITTED.value @pytest.mark.asyncio -async def test_patient_mutation_discharge_patient(db_session, sample_patient): - info = MockInfo(db_session) +async def test_patient_mutation_discharge_patient(db_session, sample_patient, sample_user_with_location_access): + info = MockInfo(db_session, sample_user_with_location_access) mutation = PatientMutation() result = await mutation.discharge_patient(info, sample_patient.id) assert result.state == PatientState.DISCHARGED.value diff --git a/backend/tests/integration/test_task_resolver.py b/backend/tests/integration/test_task_resolver.py index 3cf68c7..14a4d8d 100644 --- a/backend/tests/integration/test_task_resolver.py +++ b/backend/tests/integration/test_task_resolver.py @@ -5,13 +5,13 @@ class MockInfo: - def __init__(self, db): - self.context = Context(db=db) + def __init__(self, db, user=None): + self.context = Context(db=db, user=user) @pytest.mark.asyncio -async def test_task_query_get_task(db_session, sample_task): - info = MockInfo(db_session) +async def test_task_query_get_task(db_session, sample_task, sample_user_with_location_access): + info = MockInfo(db_session, sample_user_with_location_access) query = TaskQuery() result = await query.task(info, sample_task.id) assert result is not None @@ -20,8 +20,8 @@ async def test_task_query_get_task(db_session, sample_task): @pytest.mark.asyncio -async def test_task_query_tasks_by_patient(db_session, sample_patient): - info = MockInfo(db_session) +async def test_task_query_tasks_by_patient(db_session, sample_patient, sample_user_with_location_access): + info = MockInfo(db_session, sample_user_with_location_access) task1 = Task(title="Task 1", patient_id=sample_patient.id) task2 = Task(title="Task 2", patient_id=sample_patient.id) db_session.add(task1) @@ -37,10 +37,10 @@ async def test_task_query_tasks_by_patient(db_session, sample_patient): @pytest.mark.asyncio -async def test_task_mutation_create_task(db_session, sample_patient): +async def test_task_mutation_create_task(db_session, sample_patient, sample_user_with_location_access): from api.inputs import CreateTaskInput - info = MockInfo(db_session) + info = MockInfo(db_session, sample_user_with_location_access) mutation = TaskMutation() input_data = CreateTaskInput( title="New Task", @@ -54,10 +54,10 @@ async def test_task_mutation_create_task(db_session, sample_patient): @pytest.mark.asyncio -async def test_task_mutation_update_task(db_session, sample_task): +async def test_task_mutation_update_task(db_session, sample_task, sample_user_with_location_access): from api.inputs import UpdateTaskInput - info = MockInfo(db_session) + info = MockInfo(db_session, sample_user_with_location_access) mutation = TaskMutation() input_data = UpdateTaskInput(title="Updated Title") result = await mutation.update_task(info, sample_task.id, input_data) @@ -66,8 +66,8 @@ async def test_task_mutation_update_task(db_session, sample_task): @pytest.mark.asyncio -async def test_task_mutation_complete_task(db_session, sample_task): - info = MockInfo(db_session) +async def test_task_mutation_complete_task(db_session, sample_task, sample_user_with_location_access): + info = MockInfo(db_session, sample_user_with_location_access) mutation = TaskMutation() result = await mutation.complete_task(info, sample_task.id) assert result.done is True @@ -75,8 +75,8 @@ async def test_task_mutation_complete_task(db_session, sample_task): @pytest.mark.asyncio -async def test_task_mutation_delete_task(db_session, sample_task): - info = MockInfo(db_session) +async def test_task_mutation_delete_task(db_session, sample_task, sample_user_with_location_access): + info = MockInfo(db_session, sample_user_with_location_access) mutation = TaskMutation() task_id = sample_task.id result = await mutation.delete_task(info, task_id) diff --git a/keycloak/tasks.json b/keycloak/tasks.json index f560215..217244b 100644 --- a/keycloak/tasks.json +++ b/keycloak/tasks.json @@ -400,7 +400,36 @@ "requiredActions" : [ ], "realmRoles" : [ "default-roles-tasks" ], "notBefore" : 0, - "groups" : [ ] + "groups" : [ ], + "attributes" : { + "organization" : [ "test-org-1", "test-org-2" ] + } + }, { + "id" : "15d5c9b3-2d30-4a0d-9c2d-ed3212d5e2db", + "username" : "test2", + "firstName" : "Jane", + "lastName" : "Smith", + "email" : "jane.smith@helpwave.de", + "emailVerified" : true, + "enabled" : true, + "createdTimestamp" : 1764546986928, + "totp" : false, + "credentials" : [ { + "id" : "f15dab03-6c5b-4b53-ad9e-91d1024f2a1c", + "type" : "password", + "userLabel" : "My password", + "createdDate" : 1764546995066, + "secretData" : "{\"value\":\"Yg9msci7ctlF0zTXiQe+vjPkTrDq7lBkKyeWcxHLxlE=\",\"salt\":\"mYbxcfft3FMwjIDx03aHdw==\",\"additionalParameters\":{}}", + "credentialData" : "{\"hashIterations\":5,\"algorithm\":\"argon2\",\"additionalParameters\":{\"hashLength\":[\"32\"],\"memory\":[\"7168\"],\"type\":[\"id\"],\"version\":[\"1.3\"],\"parallelism\":[\"1\"]}}" + } ], + "disableableCredentialTypes" : [ ], + "requiredActions" : [ ], + "realmRoles" : [ "default-roles-tasks" ], + "notBefore" : 0, + "groups" : [ ], + "attributes" : { + "organization" : [ "test-org-2" ] + } } ], "scopeMappings" : [ { "clientScope" : "offline_access", @@ -661,8 +690,8 @@ "authenticationFlowBindingOverrides" : { }, "fullScopeAllowed" : true, "nodeReRegistrationTimeout" : -1, - "defaultClientScopes" : [ "web-origins", "acr", "roles", "profile", "basic", "email" ], - "optionalClientScopes" : [ "address", "phone", "organization", "offline_access", "microprofile-jwt" ] + "defaultClientScopes" : [ "web-origins", "acr", "roles", "profile", "basic", "email", "organization" ], + "optionalClientScopes" : [ "address", "phone", "offline_access", "microprofile-jwt" ] }, { "id" : "1c76a254-0a04-4dc1-b287-4e642ae3e9be", "clientId" : "tasks-web", @@ -1283,12 +1312,13 @@ "id" : "32ae8303-353f-43ad-bc9d-130923bdea7d", "name" : "organization", "protocol" : "openid-connect", - "protocolMapper" : "oidc-organization-membership-mapper", + "protocolMapper" : "oidc-usermodel-attribute-mapper", "consentRequired" : false, "config" : { "introspection.token.claim" : "true", "multivalued" : "true", "userinfo.token.claim" : "true", + "user.attribute" : "organization", "id.token.claim" : "true", "access.token.claim" : "true", "claim.name" : "organization", @@ -1327,8 +1357,8 @@ } } ] } ], - "defaultDefaultClientScopes" : [ "role_list", "saml_organization", "profile", "email", "roles", "web-origins", "acr", "basic" ], - "defaultOptionalClientScopes" : [ "offline_access", "address", "phone", "microprofile-jwt", "organization" ], + "defaultDefaultClientScopes" : [ "role_list", "saml_organization", "profile", "email", "roles", "web-origins", "acr", "basic", "organization" ], + "defaultOptionalClientScopes" : [ "offline_access", "address", "phone", "microprofile-jwt" ], "browserSecurityHeaders" : { "contentSecurityPolicyReportOnly" : "", "xContentTypeOptions" : "nosniff", diff --git a/scaffold/initial.json b/scaffold/initial.json index afeb895..5c2f141 100644 --- a/scaffold/initial.json +++ b/scaffold/initial.json @@ -6,6 +6,7 @@ { "name": "Cardiology", "type": "CLINIC", + "organization_ids": ["test-org-1"], "children": [ { "name": "Cardio-Diagnostics Team", "type": "TEAM" }, { "name": "Interventional Team", "type": "TEAM" }, @@ -36,6 +37,7 @@ { "name": "Orthopedics", "type": "CLINIC", + "organization_ids": ["test-org-2"], "children": [ { "name": "Surgical Excellence Team", "type": "TEAM" }, { "name": "Physio-Support Team", "type": "TEAM" }, diff --git a/web/api/gql/generated.ts b/web/api/gql/generated.ts index fd85967..f6d0c5f 100644 --- a/web/api/gql/generated.ts +++ b/web/api/gql/generated.ts @@ -18,6 +18,12 @@ export type Scalars = { DateTime: { input: any; output: any; } }; +export type CreateLocationNodeInput = { + kind: LocationType; + parentId?: InputMaybe; + title: Scalars['String']['input']; +}; + export type CreatePatientInput = { assignedLocationId?: InputMaybe; assignedLocationIds?: InputMaybe>; @@ -67,6 +73,7 @@ export type LocationNodeType = { children: Array; id: Scalars['ID']['output']; kind: LocationType; + organizationIds: Array; parent?: Maybe; parentId?: Maybe; patients: Array; @@ -89,9 +96,11 @@ export type Mutation = { admitPatient: PatientType; assignTask: TaskType; completeTask: TaskType; + createLocationNode: LocationNodeType; createPatient: PatientType; createPropertyDefinition: PropertyDefinitionType; createTask: TaskType; + deleteLocationNode: Scalars['Boolean']['output']; deletePatient: Scalars['Boolean']['output']; deletePropertyDefinition: Scalars['Boolean']['output']; deleteTask: Scalars['Boolean']['output']; @@ -99,6 +108,7 @@ export type Mutation = { markPatientDead: PatientType; reopenTask: TaskType; unassignTask: TaskType; + updateLocationNode: LocationNodeType; updatePatient: PatientType; updatePropertyDefinition: PropertyDefinitionType; updateTask: TaskType; @@ -122,6 +132,11 @@ export type MutationCompleteTaskArgs = { }; +export type MutationCreateLocationNodeArgs = { + data: CreateLocationNodeInput; +}; + + export type MutationCreatePatientArgs = { data: CreatePatientInput; }; @@ -137,6 +152,11 @@ export type MutationCreateTaskArgs = { }; +export type MutationDeleteLocationNodeArgs = { + id: Scalars['ID']['input']; +}; + + export type MutationDeletePatientArgs = { id: Scalars['ID']['input']; }; @@ -172,6 +192,12 @@ export type MutationUnassignTaskArgs = { }; +export type MutationUpdateLocationNodeArgs = { + data: UpdateLocationNodeInput; + id: Scalars['ID']['input']; +}; + + export type MutationUpdatePatientArgs = { data: UpdatePatientInput; id: Scalars['ID']['input']; @@ -307,6 +333,7 @@ export type QueryPatientArgs = { export type QueryPatientsArgs = { locationNodeId?: InputMaybe; + rootLocationIds?: InputMaybe>; states?: InputMaybe>; }; @@ -329,6 +356,7 @@ export type QueryTaskArgs = { export type QueryTasksArgs = { assigneeId?: InputMaybe; patientId?: InputMaybe; + rootLocationIds?: InputMaybe>; }; @@ -344,6 +372,9 @@ export enum Sex { export type Subscription = { __typename?: 'Subscription'; + locationNodeCreated: Scalars['ID']['output']; + locationNodeDeleted: Scalars['ID']['output']; + locationNodeUpdated: Scalars['ID']['output']; patientCreated: Scalars['ID']['output']; patientStateChanged: Scalars['ID']['output']; patientUpdated: Scalars['ID']['output']; @@ -353,6 +384,11 @@ export type Subscription = { }; +export type SubscriptionLocationNodeUpdatedArgs = { + locationId?: InputMaybe; +}; + + export type SubscriptionPatientStateChangedArgs = { patientId?: InputMaybe; }; @@ -384,6 +420,12 @@ export type TaskType = { updateDate?: Maybe; }; +export type UpdateLocationNodeInput = { + kind?: InputMaybe; + parentId?: InputMaybe; + title?: InputMaybe; +}; + export type UpdatePatientInput = { assignedLocationId?: InputMaybe; assignedLocationIds?: InputMaybe>; @@ -426,6 +468,7 @@ export type UserType = { lastname?: Maybe; name: Scalars['String']['output']; organizations?: Maybe; + rootLocations: Array; tasks: Array; title?: Maybe; username: Scalars['String']['output']; @@ -462,6 +505,7 @@ export type GetPatientQuery = { __typename?: 'Query', patient?: { __typename?: ' export type GetPatientsQueryVariables = Exact<{ locationId?: InputMaybe; + rootLocationIds?: InputMaybe | Scalars['ID']['input']>; states?: InputMaybe | PatientState>; }>; @@ -475,15 +519,25 @@ export type GetTaskQueryVariables = Exact<{ export type GetTaskQuery = { __typename?: 'Query', task?: { __typename?: 'TaskType', id: string, title: string, description?: string | null, done: boolean, dueDate?: any | null, checksum: string, patient: { __typename?: 'PatientType', id: string, name: string }, assignee?: { __typename?: 'UserType', id: string, name: string } | null, properties: Array<{ __typename?: 'PropertyValueType', textValue?: string | null, numberValue?: number | null, booleanValue?: boolean | null, dateValue?: any | null, dateTimeValue?: any | null, selectValue?: string | null, multiSelectValues?: Array | null, definition: { __typename?: 'PropertyDefinitionType', id: string, name: string, description?: string | null, fieldType: FieldType, isActive: boolean, allowedEntities: Array, options: Array } }> } | null }; +export type GetTasksQueryVariables = Exact<{ + rootLocationIds?: InputMaybe | Scalars['ID']['input']>; + assigneeId?: InputMaybe; +}>; + + +export type GetTasksQuery = { __typename?: 'Query', tasks: Array<{ __typename?: 'TaskType', id: string, title: string, description?: string | null, done: boolean, dueDate?: any | null, creationDate: any, updateDate?: any | null, patient: { __typename?: 'PatientType', id: string, name: string, assignedLocation?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null, assignedLocations: Array<{ __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null }> }, assignee?: { __typename?: 'UserType', id: string, name: string, avatarUrl?: string | null } | null }> }; + export type GetUsersQueryVariables = Exact<{ [key: string]: never; }>; export type GetUsersQuery = { __typename?: 'Query', users: Array<{ __typename?: 'UserType', id: string, name: string, avatarUrl?: string | null }> }; -export type GetGlobalDataQueryVariables = Exact<{ [key: string]: never; }>; +export type GetGlobalDataQueryVariables = Exact<{ + rootLocationIds?: InputMaybe | Scalars['ID']['input']>; +}>; -export type GetGlobalDataQuery = { __typename?: 'Query', me?: { __typename?: 'UserType', id: string, username: string, name: string, firstname?: string | null, lastname?: string | null, avatarUrl?: string | null, tasks: Array<{ __typename?: 'TaskType', id: string, done: boolean }> } | null, wards: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, teams: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, clinics: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, patients: Array<{ __typename?: 'PatientType', id: string, state: PatientState, assignedLocation?: { __typename?: 'LocationNodeType', id: string } | null }>, waitingPatients: Array<{ __typename?: 'PatientType', id: string, state: PatientState }> }; +export type GetGlobalDataQuery = { __typename?: 'Query', me?: { __typename?: 'UserType', id: string, username: string, name: string, firstname?: string | null, lastname?: string | null, avatarUrl?: string | null, organizations?: string | null, rootLocations: Array<{ __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType }>, tasks: Array<{ __typename?: 'TaskType', id: string, done: boolean }> } | null, wards: Array<{ __typename?: 'LocationNodeType', id: string, title: string, parentId?: string | null }>, teams: Array<{ __typename?: 'LocationNodeType', id: string, title: string, parentId?: string | null }>, clinics: Array<{ __typename?: 'LocationNodeType', id: string, title: string, parentId?: string | null }>, patients: Array<{ __typename?: 'PatientType', id: string, state: PatientState, assignedLocation?: { __typename?: 'LocationNodeType', id: string } | null }>, waitingPatients: Array<{ __typename?: 'PatientType', id: string, state: PatientState }> }; export type CreatePatientMutationVariables = Exact<{ data: CreatePatientInput; @@ -599,6 +653,23 @@ export type TaskDeletedSubscriptionVariables = Exact<{ [key: string]: never; }>; export type TaskDeletedSubscription = { __typename?: 'Subscription', taskDeleted: string }; +export type LocationNodeUpdatedSubscriptionVariables = Exact<{ + locationId?: InputMaybe; +}>; + + +export type LocationNodeUpdatedSubscription = { __typename?: 'Subscription', locationNodeUpdated: string }; + +export type LocationNodeCreatedSubscriptionVariables = Exact<{ [key: string]: never; }>; + + +export type LocationNodeCreatedSubscription = { __typename?: 'Subscription', locationNodeCreated: string }; + +export type LocationNodeDeletedSubscriptionVariables = Exact<{ [key: string]: never; }>; + + +export type LocationNodeDeletedSubscription = { __typename?: 'Subscription', locationNodeDeleted: string }; + export type CreateTaskMutationVariables = Exact<{ data: CreateTaskInput; }>; @@ -991,8 +1062,12 @@ export const useGetPatientQuery = < )}; export const GetPatientsDocument = ` - query GetPatients($locationId: ID, $states: [PatientState!]) { - patients(locationNodeId: $locationId, states: $states) { + query GetPatients($locationId: ID, $rootLocationIds: [ID!], $states: [PatientState!]) { + patients( + locationNodeId: $locationId + rootLocationIds: $rootLocationIds + states: $states + ) { id name firstname @@ -1183,6 +1258,66 @@ export const useGetTaskQuery = < } )}; +export const GetTasksDocument = ` + query GetTasks($rootLocationIds: [ID!], $assigneeId: ID) { + tasks(rootLocationIds: $rootLocationIds, assigneeId: $assigneeId) { + id + title + description + done + dueDate + creationDate + updateDate + patient { + id + name + assignedLocation { + id + title + parent { + id + title + } + } + assignedLocations { + id + title + kind + parent { + id + title + parent { + id + title + } + } + } + } + assignee { + id + name + avatarUrl + } + } +} + `; + +export const useGetTasksQuery = < + TData = GetTasksQuery, + TError = unknown + >( + variables?: GetTasksQueryVariables, + options?: Omit, 'queryKey'> & { queryKey?: UseQueryOptions['queryKey'] } + ) => { + + return useQuery( + { + queryKey: variables === undefined ? ['GetTasks'] : ['GetTasks', variables], + queryFn: fetcher(GetTasksDocument, variables), + ...options + } + )}; + export const GetUsersDocument = ` query GetUsers { users { @@ -1210,7 +1345,7 @@ export const useGetUsersQuery = < )}; export const GetGlobalDataDocument = ` - query GetGlobalData { + query GetGlobalData($rootLocationIds: [ID!]) { me { id username @@ -1218,6 +1353,12 @@ export const GetGlobalDataDocument = ` firstname lastname avatarUrl + organizations + rootLocations { + id + title + kind + } tasks { id done @@ -1226,23 +1367,26 @@ export const GetGlobalDataDocument = ` wards: locationNodes(kind: WARD) { id title + parentId } teams: locationNodes(kind: TEAM) { id title + parentId } clinics: locationNodes(kind: CLINIC) { id title + parentId } - patients { + patients(rootLocationIds: $rootLocationIds) { id state assignedLocation { id } } - waitingPatients: patients(states: [WAIT]) { + waitingPatients: patients(states: [WAIT], rootLocationIds: $rootLocationIds) { id state } @@ -1617,6 +1761,21 @@ export const TaskDeletedDocument = ` taskDeleted } `; +export const LocationNodeUpdatedDocument = ` + subscription LocationNodeUpdated($locationId: ID) { + locationNodeUpdated(locationId: $locationId) +} + `; +export const LocationNodeCreatedDocument = ` + subscription LocationNodeCreated { + locationNodeCreated +} + `; +export const LocationNodeDeletedDocument = ` + subscription LocationNodeDeleted { + locationNodeDeleted +} + `; export const CreateTaskDocument = ` mutation CreateTask($data: CreateTaskInput!) { createTask(data: $data) { diff --git a/web/api/graphql/GetPatients.graphql b/web/api/graphql/GetPatients.graphql index e2be458..17bc1a0 100644 --- a/web/api/graphql/GetPatients.graphql +++ b/web/api/graphql/GetPatients.graphql @@ -1,5 +1,5 @@ -query GetPatients($locationId: ID, $states: [PatientState!]) { - patients(locationNodeId: $locationId, states: $states) { +query GetPatients($locationId: ID, $rootLocationIds: [ID!], $states: [PatientState!]) { + patients(locationNodeId: $locationId, rootLocationIds: $rootLocationIds, states: $states) { id name firstname diff --git a/web/api/graphql/GetTasks.graphql b/web/api/graphql/GetTasks.graphql new file mode 100644 index 0000000..ec6f7b8 --- /dev/null +++ b/web/api/graphql/GetTasks.graphql @@ -0,0 +1,42 @@ +query GetTasks($rootLocationIds: [ID!], $assigneeId: ID) { + tasks(rootLocationIds: $rootLocationIds, assigneeId: $assigneeId) { + id + title + description + done + dueDate + creationDate + updateDate + patient { + id + name + assignedLocation { + id + title + parent { + id + title + } + } + assignedLocations { + id + title + kind + parent { + id + title + parent { + id + title + } + } + } + } + assignee { + id + name + avatarUrl + } + } +} + diff --git a/web/api/graphql/GlobalData.graphql b/web/api/graphql/GlobalData.graphql index 3eb3695..319978f 100644 --- a/web/api/graphql/GlobalData.graphql +++ b/web/api/graphql/GlobalData.graphql @@ -1,4 +1,4 @@ -query GetGlobalData { +query GetGlobalData($rootLocationIds: [ID!]) { me { id username @@ -6,6 +6,12 @@ query GetGlobalData { firstname lastname avatarUrl + organizations + rootLocations { + id + title + kind + } tasks { id done @@ -14,23 +20,26 @@ query GetGlobalData { wards: locationNodes(kind: WARD) { id title + parentId } teams: locationNodes(kind: TEAM) { id title + parentId } clinics: locationNodes(kind: CLINIC) { id title + parentId } - patients { + patients(rootLocationIds: $rootLocationIds) { id state assignedLocation { id } } - waitingPatients: patients(states: [WAIT]) { + waitingPatients: patients(states: [WAIT], rootLocationIds: $rootLocationIds) { id state } diff --git a/web/api/graphql/Subscriptions.graphql b/web/api/graphql/Subscriptions.graphql index bebc73b..4ecc7fe 100644 --- a/web/api/graphql/Subscriptions.graphql +++ b/web/api/graphql/Subscriptions.graphql @@ -22,3 +22,15 @@ subscription TaskDeleted { taskDeleted } +subscription LocationNodeUpdated($locationId: ID) { + locationNodeUpdated(locationId: $locationId) +} + +subscription LocationNodeCreated { + locationNodeCreated +} + +subscription LocationNodeDeleted { + locationNodeDeleted +} + diff --git a/web/components/layout/Page.tsx b/web/components/layout/Page.tsx index d102ae4..7fcff95 100644 --- a/web/components/layout/Page.tsx +++ b/web/components/layout/Page.tsx @@ -10,7 +10,6 @@ import { Button, Dialog, Expandable, - LoadingContainer, MarkdownInterpreter, useLocalStorage } from '@helpwave/hightide' @@ -33,8 +32,10 @@ import { Notifications } from '@/components/Notifications' import { TasksLogo } from '@/components/TasksLogo' import { useRouter } from 'next/router' import { useTasksContext } from '@/hooks/useTasksContext' +import { useGetLocationsQuery } from '@/api/gql/generated' import { hashString } from '@/utils/hash' import { useSwipeGesture } from '@/hooks/useSwipeGesture' +import { LocationSelectionDialog } from '@/components/locations/LocationSelectionDialog' export const StagingDisclaimerDialog = () => { const config = getConfig() @@ -204,8 +205,69 @@ type HeaderProps = HTMLAttributes & { export const Header = ({ onMenuClick, isMenuOpen, ...props }: HeaderProps) => { const router = useRouter() - const { user } = useTasksContext() + const { user, rootLocations, selectedRootLocationIds, update } = useTasksContext() + const translation = useTasksTranslation() + const [isLocationPickerOpen, setIsLocationPickerOpen] = useState(false) + const [selectedLocationsCache, setSelectedLocationsCache] = useState>([]) + + const { data: locationsData } = useGetLocationsQuery( + {}, + { + enabled: !!selectedRootLocationIds && selectedRootLocationIds.length > 0, + refetchInterval: 30000, + refetchOnWindowFocus: true, + } + ) + useEffect(() => { + if (selectedRootLocationIds && selectedRootLocationIds.length > 0) { + const foundInRoot = rootLocations?.filter(loc => selectedRootLocationIds.includes(loc.id)) || [] + + if (foundInRoot.length === selectedRootLocationIds.length) { + setSelectedLocationsCache(foundInRoot.map(loc => ({ id: loc.id, title: loc.title, kind: loc.kind }))) + } else if (locationsData?.locationNodes) { + const allLocations = locationsData.locationNodes + const foundLocations: Array<{ id: string, title: string, kind?: string }> = [] + for (const id of selectedRootLocationIds) { + const inRoot = rootLocations?.find(loc => loc.id === id) + if (inRoot) { + foundLocations.push({ id: inRoot.id, title: inRoot.title, kind: inRoot.kind }) + } else { + const inAll = allLocations.find(loc => loc.id === id) + if (inAll) { + foundLocations.push({ id: inAll.id, title: inAll.title, kind: inAll.kind }) + } + } + } + + if (foundLocations.length > 0) { + setSelectedLocationsCache(foundLocations) + } + } else if (foundInRoot.length > 0) { + setSelectedLocationsCache(foundInRoot.map(loc => ({ id: loc.id, title: loc.title, kind: loc.kind }))) + } + } else { + setSelectedLocationsCache([]) + } + }, [rootLocations, selectedRootLocationIds, locationsData]) + + const selectedRootLocations = selectedLocationsCache.length > 0 + ? selectedLocationsCache + : (rootLocations?.filter(loc => selectedRootLocationIds?.includes(loc.id)) || []) + const firstSelectedRootLocation = selectedRootLocations[0] + + const handleRootLocationSelect = (locations: Array<{ id: string, title: string, kind?: string }>) => { + if (locations.length === 0) return + const locationIds = locations.map(loc => loc.id) + setSelectedLocationsCache(locations) + update(prevState => { + return { + ...prevState, + selectedRootLocationIds: locationIds, + } + }) + setIsLocationPickerOpen(false) + } return (
{ {isMenuOpen ? : } -
+
+ {rootLocations && rootLocations.length > 0 && ( +
+ + setIsLocationPickerOpen(false)} + onSelect={handleRootLocationSelect} + initialSelectedIds={selectedRootLocationIds || []} + multiSelect={true} + useCase="root" + /> +
+ )}
@@ -342,89 +430,89 @@ export const Sidebar = ({ isOpen, onClose, ...props }: SidebarProps) => { {context?.totalPatientsCount !== undefined && ({context.totalPatientsCount})} - - - {translation('teams')} -
- )} - headerClassName="!px-2.5 !py-1.5 hover:bg-black/30" - contentClassName="!px-0 !pb-0 gap-y-0" - className="!shadow-none" - isExpanded={context.sidebar.isShowingTeams} - onChange={isExpanded => context.update(prevState => ({ - ...prevState, - sidebar: { - ...prevState.sidebar, - isShowingTeams: isExpanded, - } - }))} - > - {!context?.teams ? ( - - ) : context.teams.map(team => ( - - {team.title} - - ))} - - - - - {translation('wards')} -
- )} - headerClassName="!px-2.5 !py-1.5 hover:bg-black/30" - contentClassName="!px-0 !pb-0 gap-y-0" - className="!shadow-none" - isExpanded={context.sidebar.isShowingWards} - onChange={isExpanded => context.update(prevState => ({ - ...prevState, - sidebar: { - ...prevState.sidebar, - isShowingWards: isExpanded, - } - }))} - > - {!context?.wards ? ( - - ) : context.wards.map(ward => ( - - {ward.title} - - ))} - - - - - {translation('clinics')} - - )} - headerClassName="!px-2.5 !py-1.5 hover:bg-black/30" - contentClassName="!px-0 !pb-0 gap-y-0" - className="!shadow-none" - isExpanded={context.sidebar.isShowingClinics} - onChange={isExpanded => context.update(prevState => ({ - ...prevState, - sidebar: { - ...prevState.sidebar, - isShowingClinics: isExpanded, - } - }))} - > - {!context?.clinics ? ( - - ) : context.clinics.map(clinic => ( - - {clinic.title} - - ))} - + {context?.teams && context.teams.length > 0 && ( + + + {translation('teams')} + + )} + headerClassName="!px-2.5 !py-1.5 hover:bg-black/30" + contentClassName="!px-0 !pb-0 gap-y-0" + className="!shadow-none" + isExpanded={context.sidebar.isShowingTeams} + onChange={isExpanded => context.update(prevState => ({ + ...prevState, + sidebar: { + ...prevState.sidebar, + isShowingTeams: isExpanded, + } + }))} + > + {context.teams.map(team => ( + + {team.title} + + ))} + + )} + + {context?.wards && context.wards.length > 0 && ( + + + {translation('wards')} + + )} + headerClassName="!px-2.5 !py-1.5 hover:bg-black/30" + contentClassName="!px-0 !pb-0 gap-y-0" + className="!shadow-none" + isExpanded={context.sidebar.isShowingWards} + onChange={isExpanded => context.update(prevState => ({ + ...prevState, + sidebar: { + ...prevState.sidebar, + isShowingWards: isExpanded, + } + }))} + > + {context.wards.map(ward => ( + + {ward.title} + + ))} + + )} + + {context?.clinics && context.clinics.length > 0 && ( + + + {translation('clinics')} + + )} + headerClassName="!px-2.5 !py-1.5 hover:bg-black/30" + contentClassName="!px-0 !pb-0 gap-y-0" + className="!shadow-none" + isExpanded={context.sidebar.isShowingClinics} + onChange={isExpanded => context.update(prevState => ({ + ...prevState, + sidebar: { + ...prevState.sidebar, + isShowingClinics: isExpanded, + } + }))} + > + {context.clinics.map(clinic => ( + + {clinic.title} + + ))} + + )} diff --git a/web/components/locations/LocationSelectionDialog.tsx b/web/components/locations/LocationSelectionDialog.tsx index 6293a36..173dc8f 100644 --- a/web/components/locations/LocationSelectionDialog.tsx +++ b/web/components/locations/LocationSelectionDialog.tsx @@ -26,6 +26,7 @@ export type LocationPickerUseCase = | 'clinic' | 'position' | 'teams' + | 'root' interface LocationSelectionDialogProps { isOpen: boolean, @@ -48,11 +49,13 @@ interface LocationTreeItemProps { const getKindStyles = (kind: string) => { const k = kind.toUpperCase() - if (k.includes('CLINIC')) return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-300' - if (k.includes('WARD')) return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-300' - if (k.includes('TEAM')) return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-300' - if (k.includes('ROOM')) return 'bg-gray-100 text-gray-700 dark:bg-gray-800 dark:text-gray-300' - if (k.includes('BED')) return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-300' + if (k === 'HOSPITAL') return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-300' + if (k === 'PRACTICE') return 'bg-indigo-100 text-indigo-700 dark:bg-indigo-900/30 dark:text-indigo-300' + if (k === 'CLINIC') return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-300' + if (k === 'TEAM') return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-300' + if (k === 'WARD') return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-300' + if (k === 'ROOM') return 'bg-gray-100 text-gray-700 dark:bg-gray-800 dark:text-gray-300' + if (k === 'BED') return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-300' return 'bg-surface-subdued text-text-tertiary' } @@ -176,6 +179,7 @@ export const LocationSelectionDialog = ({ const hasInitialized = useRef(false) + useEffect(() => { if (isOpen) { setSelectedIds(new Set(initialSelectedIds)) @@ -191,7 +195,23 @@ export const LocationSelectionDialog = ({ return () => true } - if (useCase === 'clinic') { + if (useCase === 'root') { + const allowedKinds = new Set([ + LocationType.Hospital, + LocationType.Practice, + LocationType.Clinic, + LocationType.Team, + 'HOSPITAL', + 'PRACTICE', + 'CLINIC', + 'TEAM', + ]) + return (node: LocationNodeType) => { + const kindStr = node.kind.toString().toUpperCase() + return allowedKinds.has(node.kind as LocationType) || + allowedKinds.has(kindStr) + } + } else if (useCase === 'clinic') { return (node: LocationNodeType) => { const kindStr = node.kind.toString().toUpperCase() return kindStr === 'CLINIC' || node.kind === LocationType.Clinic @@ -314,10 +334,11 @@ export const LocationSelectionDialog = ({ const handleToggleSelect = (node: LocationNodeType, checked: boolean) => { const newSet = new Set(selectedIds) if (checked) { - // For clinic useCase, enforce exactly one selection (clear all first) if (useCase === 'clinic' || !multiSelect) { newSet.clear() } + + newSet.add(node.id) } else { newSet.delete(node.id) @@ -347,8 +368,11 @@ export const LocationSelectionDialog = ({ const handleConfirm = () => { if (!data?.locationNodes) return + if (selectedIds.size === 0) return const nodes = data.locationNodes as LocationNodeType[] - const selectedNodes = nodes.filter(n => selectedIds.has(n.id)) + + const finalSelectedIds = Array.from(selectedIds) + const selectedNodes = nodes.filter(n => finalSelectedIds.includes(n.id)) onSelect(selectedNodes) onClose() } @@ -373,6 +397,7 @@ export const LocationSelectionDialog = ({ {useCase === 'clinic' ? translation('pickClinic') : useCase === 'position' ? translation('pickPosition') : useCase === 'teams' ? translation('pickTeams') : + useCase === 'root' ? translation('selectRootLocation') : translation('selectLocation')} )} @@ -380,6 +405,7 @@ export const LocationSelectionDialog = ({ useCase === 'clinic' ? translation('pickClinicDescription') : useCase === 'position' ? translation('pickPositionDescription') : useCase === 'teams' ? translation('pickTeamsDescription') : + useCase === 'root' ? translation('selectRootLocationDescription') : translation('selectLocationDescription') } className="w-[600px] h-[80vh] flex flex-col max-w-full" @@ -405,7 +431,7 @@ export const LocationSelectionDialog = ({ - {multiSelect && ( + {multiSelect && useCase !== 'root' && (
diff --git a/web/components/patients/LocationChips.tsx b/web/components/patients/LocationChips.tsx index 59a4eb8..513eacb 100644 --- a/web/components/patients/LocationChips.tsx +++ b/web/components/patients/LocationChips.tsx @@ -21,11 +21,13 @@ interface LocationChipsProps { const getKindStyles = (kind: string | undefined) => { if (!kind) return 'bg-surface-subdued text-text-tertiary' const k = kind.toUpperCase() - if (k.includes('CLINIC')) return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-300' - if (k.includes('WARD')) return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-300' - if (k.includes('TEAM')) return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-300' - if (k.includes('ROOM')) return 'bg-gray-100 text-gray-700 dark:bg-gray-800 dark:text-gray-300' - if (k.includes('BED')) return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-300' + if (k === 'HOSPITAL') return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-300' + if (k === 'PRACTICE') return 'bg-indigo-100 text-indigo-700 dark:bg-indigo-900/30 dark:text-indigo-300' + if (k === 'CLINIC') return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-300' + if (k === 'TEAM') return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-300' + if (k === 'WARD') return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-300' + if (k === 'ROOM') return 'bg-gray-100 text-gray-700 dark:bg-gray-800 dark:text-gray-300' + if (k === 'BED') return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-300' return 'bg-surface-subdued text-text-tertiary' } diff --git a/web/components/patients/PatientDetailView.tsx b/web/components/patients/PatientDetailView.tsx index 8f177f3..87ff093 100644 --- a/web/components/patients/PatientDetailView.tsx +++ b/web/components/patients/PatientDetailView.tsx @@ -171,7 +171,8 @@ export const PatientDetailView = ({ initialCreateData = {} }: PatientDetailViewProps) => { const translation = useTasksTranslation() - const { selectedLocationId } = useTasksContext() + const { selectedLocationId, selectedRootLocationIds, rootLocations } = useTasksContext() + const firstSelectedRootLocationId = selectedRootLocationIds && selectedRootLocationIds.length > 0 ? selectedRootLocationIds[0] : undefined const queryClient = useQueryClient() const [taskId, setTaskId] = useState(null) const [isCreatingTask, setIsCreatingTask] = useState(false) @@ -226,6 +227,8 @@ export const PatientDetailView = ({ const [selectedTeams, setSelectedTeams] = useState([]) const [isMarkDeadDialogOpen, setIsMarkDeadDialogOpen] = useState(false) const [isDischargeDialogOpen, setIsDischargeDialogOpen] = useState(false) + const [isLocationChangeConfirmOpen, setIsLocationChangeConfirmOpen] = useState(false) + const [pendingLocationUpdate, setPendingLocationUpdate] = useState<(() => void) | null>(null) // Validation state for required fields const [firstnameError, setFirstnameError] = useState(null) @@ -238,7 +241,8 @@ export const PatientDetailView = ({ if (patientData?.patient) { const patient = patientData.patient const { firstname, lastname, sex, birthdate, assignedLocations, clinic, position, teams } = patient - setFormData({ + setFormData(prev => ({ + ...prev, firstname, lastname, sex, @@ -247,13 +251,42 @@ export const PatientDetailView = ({ clinicId: clinic?.id || undefined, positionId: position?.id || undefined, teamIds: teams?.map(t => t.id) || undefined, - } as CreatePatientInput & { clinicId?: string, positionId?: string, teamIds?: string[] }) + } as CreatePatientInput & { clinicId?: string, positionId?: string, teamIds?: string[] })) setSelectedClinic(clinic ? (clinic as LocationNodeType) : null) setSelectedPosition(position ? (position as LocationNodeType) : null) setSelectedTeams((teams || []) as LocationNodeType[]) } }, [patientData]) + useEffect(() => { + if (!isEditMode && locationsData?.locationNodes && !formData.clinicId) { + // Try to find a CLINIC in the selected root locations first + let clinicLocation: LocationNodeType | undefined + if (firstSelectedRootLocationId) { + const selectedRootLocation = locationsData.locationNodes.find( + loc => loc.id === firstSelectedRootLocationId && loc.kind === 'CLINIC' + ) + if (selectedRootLocation) { + clinicLocation = selectedRootLocation as LocationNodeType + } + } + // If no CLINIC found in selected, try first root location that is a CLINIC + if (!clinicLocation && rootLocations && rootLocations.length > 0) { + const firstClinic = rootLocations.find(loc => loc.kind === 'CLINIC') + if (firstClinic) { + clinicLocation = firstClinic as LocationNodeType + } + } + if (clinicLocation) { + setSelectedClinic(clinicLocation) + setFormData(prev => ({ + ...prev, + clinicId: clinicLocation!.id, + })) + } + } + }, [isEditMode, firstSelectedRootLocationId, locationsData, formData.clinicId, rootLocations]) + const { mutate: createPatient, isLoading: isCreating } = useCreatePatientMutation({ onSuccess: () => { queryClient.invalidateQueries({ queryKey: ['GetGlobalData'] }) @@ -432,16 +465,30 @@ export const PatientDetailView = ({ setSelectedClinic(clinic) updateLocalState({ clinicId: clinic.id } as Partial) if (isEditMode) { - persistChanges({ clinicId: clinic.id } as Partial) + const updateFn = () => { + persistChanges({ clinicId: clinic.id } as Partial) + setIsLocationChangeConfirmOpen(false) + setPendingLocationUpdate(null) + } + setPendingLocationUpdate(() => updateFn) + setIsLocationChangeConfirmOpen(true) + } else { + validateClinic(clinic) } - validateClinic(clinic) } else { setSelectedClinic(null) updateLocalState({ clinicId: undefined } as Partial) if (isEditMode) { - persistChanges({ clinicId: undefined } as Partial) + const updateFn = () => { + persistChanges({ clinicId: undefined } as Partial) + setIsLocationChangeConfirmOpen(false) + setPendingLocationUpdate(null) + } + setPendingLocationUpdate(() => updateFn) + setIsLocationChangeConfirmOpen(true) + } else { + validateClinic(null) } - validateClinic(null) } setIsClinicDialogOpen(false) } @@ -451,11 +498,31 @@ export const PatientDetailView = ({ if (position) { setSelectedPosition(position) updateLocalState({ positionId: position.id } as Partial) - persistChanges({ positionId: position.id } as Partial) + if (isEditMode) { + const updateFn = () => { + persistChanges({ positionId: position.id } as Partial) + setIsLocationChangeConfirmOpen(false) + setPendingLocationUpdate(null) + } + setPendingLocationUpdate(() => updateFn) + setIsLocationChangeConfirmOpen(true) + } else { + persistChanges({ positionId: position.id } as Partial) + } } else { setSelectedPosition(null) updateLocalState({ positionId: undefined } as Partial) - persistChanges({ positionId: undefined } as Partial) + if (isEditMode) { + const updateFn = () => { + persistChanges({ positionId: undefined } as Partial) + setIsLocationChangeConfirmOpen(false) + setPendingLocationUpdate(null) + } + setPendingLocationUpdate(() => updateFn) + setIsLocationChangeConfirmOpen(true) + } else { + persistChanges({ positionId: undefined } as Partial) + } } setIsPositionDialogOpen(false) } @@ -464,7 +531,17 @@ export const PatientDetailView = ({ setSelectedTeams(locations) const teamIds = locations.map(loc => loc.id) updateLocalState({ teamIds } as Partial) - persistChanges({ teamIds } as Partial) + if (isEditMode) { + const updateFn = () => { + persistChanges({ teamIds } as Partial) + setIsLocationChangeConfirmOpen(false) + setPendingLocationUpdate(null) + } + setPendingLocationUpdate(() => updateFn) + setIsLocationChangeConfirmOpen(true) + } else { + persistChanges({ teamIds } as Partial) + } setIsTeamsDialogOpen(false) } @@ -1062,6 +1139,22 @@ export const PatientDetailView = ({ confirmType="neutral" /> + { + setIsLocationChangeConfirmOpen(false) + setPendingLocationUpdate(null) + }} + onConfirm={() => { + if (pendingLocationUpdate) { + pendingLocationUpdate() + } + }} + titleElement={translation('updateLocation')} + description={translation('updateLocationConfirmation')} + confirmType="neutral" + /> + setIsClinicDialogOpen(false)} diff --git a/web/components/patients/PatientList.tsx b/web/components/patients/PatientList.tsx index 067c3f0..02e6984 100644 --- a/web/components/patients/PatientList.tsx +++ b/web/components/patients/PatientList.tsx @@ -8,6 +8,7 @@ import { SmartDate } from '@/utils/date' import { LocationChips } from '@/components/patients/LocationChips' import { PatientStateChip } from '@/components/patients/PatientStateChip' import { useTasksTranslation } from '@/i18n/useTasksTranslation' +import { useTasksContext } from '@/hooks/useTasksContext' import type { ColumnDef } from '@tanstack/table-core' type PatientViewModel = { @@ -38,6 +39,7 @@ type PatientListProps = { export const PatientList = forwardRef(({ locationId, initialPatientId, onInitialPatientOpened, acceptedStates }, ref) => { const translation = useTasksTranslation() + const { selectedRootLocationIds } = useTasksContext() const [isPanelOpen, setIsPanelOpen] = useState(false) const [selectedPatient, setSelectedPatient] = useState(undefined) const [searchQuery, setSearchQuery] = useState('') @@ -46,6 +48,7 @@ export const PatientList = forwardRef(({ locat const { data: queryData, refetch } = useGetPatientsQuery( { locationId: locationId, + rootLocationIds: selectedRootLocationIds && selectedRootLocationIds.length > 0 ? selectedRootLocationIds : undefined, states: acceptedStates }, { diff --git a/web/hooks/useTasksContext.tsx b/web/hooks/useTasksContext.tsx index 567a445..0e0ee49 100644 --- a/web/hooks/useTasksContext.tsx +++ b/web/hooks/useTasksContext.tsx @@ -1,18 +1,65 @@ import type { Dispatch, SetStateAction } from 'react' -import { createContext, type PropsWithChildren, useContext, useEffect, useState } from 'react' +import { createContext, type PropsWithChildren, useContext, useEffect, useRef, useState } from 'react' import { usePathname } from 'next/navigation' -import { useGetGlobalDataQuery } from '@/api/gql/generated' +import { useGetGlobalDataQuery, useGetLocationsQuery } from '@/api/gql/generated' +import { useQueryClient } from '@tanstack/react-query' import { useAuth } from './useAuth' +import { useLocalStorage } from '@helpwave/hightide' + +function filterLocationsByRootSubtree( + locations: Array<{ id: string, title: string, parentId?: string | null }>, + selectedRootLocationIds: string[], + rootLocations: Array<{ id: string, title: string, kind?: string }>, + allLocations?: Array<{ id: string, title: string, parentId?: string | null }> +): Array<{ id: string, title: string, kind?: string }> { + if (!selectedRootLocationIds || selectedRootLocationIds.length === 0) { + return [] + } + + const rootLocationSet = new Set(selectedRootLocationIds) + const allLocationsMap = new Map() + + if (allLocations) { + allLocations.forEach(loc => allLocationsMap.set(loc.id, loc)) + } + locations.forEach(loc => allLocationsMap.set(loc.id, loc)) + rootLocations.forEach(loc => allLocationsMap.set(loc.id, { id: loc.id, title: loc.title, parentId: null })) + + const isDescendantOfRoot = (locationId: string): boolean => { + if (rootLocationSet.has(locationId)) { + return true + } + + let current = allLocationsMap.get(locationId) + const visited = new Set() + + while (current?.parentId && !visited.has(current.id)) { + visited.add(current.id) + if (rootLocationSet.has(current.parentId)) { + return true + } + current = allLocationsMap.get(current.parentId) + } + + return false + } + + return locations + .filter(loc => isDescendantOfRoot(loc.id)) + .map(loc => ({ id: loc.id, title: loc.title })) +} type User = { id: string, name: string, avatarUrl?: string | null, + organizations?: string | null, } type LocationNode = { id: string, title: string, + kind?: string, } type SidebarContextType = { @@ -30,8 +77,10 @@ export type TasksContextState = { wards?: LocationNode[], clinics?: LocationNode[], selectedLocationId?: string, + selectedRootLocationIds?: string[], sidebar: SidebarContextType, user?: User, + rootLocations?: LocationNode[], } export type TasksContextType = TasksContextState & { @@ -53,48 +102,146 @@ export const useTasksContext = (): TasksContextType => { export const TasksContextProvider = ({ children }: PropsWithChildren) => { const pathName = usePathname() const { identity, isLoading: isAuthLoading } = useAuth() + const queryClient = useQueryClient() + const { + value: storedSelectedRootLocationIds, + setValue: setStoredSelectedRootLocationIds + } = useLocalStorage('selected-root-location-ids', []) const [state, setState] = useState({ sidebar: { isShowingTeams: false, isShowingWards: false, isShowingClinics: false, - } + }, + selectedRootLocationIds: storedSelectedRootLocationIds.length > 0 ? storedSelectedRootLocationIds : undefined, }) - const { data } = useGetGlobalDataQuery(undefined, { - enabled: !isAuthLoading && !!identity, - refetchInterval: 5000, - refetchOnWindowFocus: true, - refetchOnMount: true, - }) + const { data: allLocationsData } = useGetLocationsQuery( + {}, + { + enabled: !isAuthLoading && !!identity, + refetchInterval: 30000, + refetchOnWindowFocus: true, + } + ) + + const selectedRootLocationIdsForQuery = state.selectedRootLocationIds && state.selectedRootLocationIds.length > 0 + ? state.selectedRootLocationIds + : undefined + + const { data } = useGetGlobalDataQuery( + { + rootLocationIds: selectedRootLocationIdsForQuery + }, + { + enabled: !isAuthLoading && !!identity, + refetchInterval: 5000, + refetchOnWindowFocus: true, + refetchOnMount: true, + } + ) + + + const prevRootLocationIdsRef = useRef('') + const prevSelectedRootLocationIdsRef = useRef('') + + useEffect(() => { + const currentSelectedIds = (state.selectedRootLocationIds || []).sort().join(',') + if (prevSelectedRootLocationIdsRef.current !== currentSelectedIds) { + prevSelectedRootLocationIdsRef.current = currentSelectedIds + queryClient.invalidateQueries({ queryKey: ['GetGlobalData'] }) + } + }, [state.selectedRootLocationIds, queryClient]) + + useEffect(() => { + if (data?.me?.rootLocations) { + const currentRootLocationIds = data.me.rootLocations.map(loc => loc.id).sort().join(',') + + if (prevRootLocationIdsRef.current && prevRootLocationIdsRef.current !== currentRootLocationIds) { + queryClient.invalidateQueries() + } + prevRootLocationIdsRef.current = currentRootLocationIds + } + }, [data?.me?.rootLocations, queryClient]) useEffect(() => { const totalPatientsCount = data?.patients?.length ?? 0 const waitingPatientsCount = data?.waitingPatients?.length ?? 0 - setState(prevState => ({ - ...prevState, - user: data?.me ? { - id: data.me.id, - name: data.me.name, - avatarUrl: data.me.avatarUrl - } : undefined, - myTasksCount: data?.me?.tasks?.filter(t => !t.done).length ?? 0, - totalPatientsCount, - waitingPatientsCount, - locationPatientsCount: prevState.selectedLocationId - ? data?.patients?.filter(p => p.assignedLocation?.id === prevState.selectedLocationId).length ?? 0 - : totalPatientsCount, - teams: data?.teams, - wards: data?.wards, - clinics: data?.clinics, - })) - }, [data]) + const rootLocations = data?.me?.rootLocations?.map(loc => ({ id: loc.id, title: loc.title, kind: loc.kind })) ?? [] + + setState(prevState => { + let selectedRootLocationIds = prevState.selectedRootLocationIds || [] + + if (rootLocations.length > 0 && selectedRootLocationIds.length === 0 && storedSelectedRootLocationIds.length === 0) { + selectedRootLocationIds = [rootLocations[0]!.id] + } + + return { + ...prevState, + user: data?.me ? { + id: data.me.id, + name: data.me.name, + avatarUrl: data.me.avatarUrl, + organizations: data.me.organizations ?? null + } : undefined, + myTasksCount: data?.me?.tasks?.filter(t => !t.done).length ?? 0, + totalPatientsCount, + waitingPatientsCount, + locationPatientsCount: prevState.selectedLocationId + ? data?.patients?.filter(p => p.assignedLocation?.id === prevState.selectedLocationId).length ?? 0 + : totalPatientsCount, + teams: filterLocationsByRootSubtree( + data?.teams || [], + selectedRootLocationIds, + rootLocations, + allLocationsData?.locationNodes + ), + wards: filterLocationsByRootSubtree( + data?.wards || [], + selectedRootLocationIds, + rootLocations, + allLocationsData?.locationNodes + ), + clinics: filterLocationsByRootSubtree( + data?.clinics || [], + selectedRootLocationIds, + rootLocations, + allLocationsData?.locationNodes + ), + rootLocations, + selectedRootLocationIds, + } + }) + }, [data, storedSelectedRootLocationIds, allLocationsData]) + + const lastWrittenLocationIdsRef = useRef(undefined) + + useEffect(() => { + if (state.selectedRootLocationIds !== undefined) { + const currentIds = state.selectedRootLocationIds + const lastWritten = lastWrittenLocationIdsRef.current + if (JSON.stringify(currentIds) !== JSON.stringify(lastWritten)) { + lastWrittenLocationIdsRef.current = currentIds + setStoredSelectedRootLocationIds(currentIds) + } + } + }, [state.selectedRootLocationIds, setStoredSelectedRootLocationIds]) + + const updateState: Dispatch> = (updater) => { + setState(prevState => { + const newState = typeof updater === 'function' ? updater(prevState) : updater + if (newState.selectedRootLocationIds !== prevState.selectedRootLocationIds) { + setStoredSelectedRootLocationIds(newState.selectedRootLocationIds || []) + } + return newState + }) + } return ( diff --git a/web/i18n/translations.ts b/web/i18n/translations.ts index 9c7e666..3b3fb95 100644 --- a/web/i18n/translations.ts +++ b/web/i18n/translations.ts @@ -74,6 +74,7 @@ export type TasksTranslationEntries = { 'lastUpdate': string, 'loading': string, 'location': string, + 'locations': string, 'locationType': (values: { type: string }) => string, 'login': string, 'loginRequired': string, @@ -110,6 +111,7 @@ export type TasksTranslationEntries = { 'openSurvey': string, 'openTasks': string, 'option': string, + 'organizations': string, 'overview': string, 'pages.404.notFound': string, 'pages.404.notFoundDescription1': string, @@ -155,8 +157,12 @@ export type TasksTranslationEntries = { 'selectLocation': string, 'selectLocationDescription': string, 'selectOptions': string, + 'selectOrganization': string, + 'selectOrganizations': string, 'selectPatient': string, 'selectPosition': string, + 'selectRootLocation': string, + 'selectRootLocationDescription': string, 'selectTeams': string, 'settings': string, 'settingsDescription': string, @@ -182,6 +188,8 @@ export type TasksTranslationEntries = { 'type': string, 'unassigned': string, 'updated': string, + 'updateLocation': string, + 'updateLocationConfirmation': string, 'visibility': string, 'waitingForPatient': string, 'waitingroom': string, @@ -259,6 +267,7 @@ export const tasksTranslation: Translation { return TranslationGen.resolveSelect(type, { 'CLINIC': `Klinik`, @@ -357,6 +366,7 @@ export const tasksTranslation: Translation { return TranslationGen.resolveSelect(type, { 'CLINIC': `Clinic`, @@ -646,6 +663,7 @@ export const tasksTranslation: Translation { const k = kind.toUpperCase() - if (k.includes('CLINIC')) return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-300' - if (k.includes('WARD')) return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-300' - if (k.includes('TEAM')) return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-300' - if (k.includes('ROOM')) return 'bg-gray-100 text-gray-700 dark:bg-gray-800 dark:text-gray-300' - if (k.includes('BED')) return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-300' + if (k === 'HOSPITAL') return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-300' + if (k === 'PRACTICE') return 'bg-indigo-100 text-indigo-700 dark:bg-indigo-900/30 dark:text-indigo-300' + if (k === 'CLINIC') return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-300' + if (k === 'TEAM') return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-300' + if (k === 'WARD') return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-300' + if (k === 'ROOM') return 'bg-gray-100 text-gray-700 dark:bg-gray-800 dark:text-gray-300' + if (k === 'BED') return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-300' return 'bg-surface-subdued text-text-tertiary' } const LocationPage: NextPage = () => { const translation = useTasksTranslation() const router = useRouter() + const { selectedRootLocationIds } = useTasksContext() const id = Array.isArray(router.query['id']) ? router.query['id'][0] : router.query['id'] const { data: locationData, isLoading: isLoadingLocation, isError: isLocationError } = useGetLocationNodeQuery( @@ -37,7 +41,7 @@ const LocationPage: NextPage = () => { ) const { data: patientsData, refetch: refetchPatients, isLoading: isLoadingPatients } = useGetPatientsQuery( - { locationId: id }, + { locationId: id, rootLocationIds: selectedRootLocationIds && selectedRootLocationIds.length > 0 ? selectedRootLocationIds : undefined }, { enabled: !!id, refetchInterval: 5000, diff --git a/web/pages/patients/index.tsx b/web/pages/patients/index.tsx index de11bd5..83a6065 100644 --- a/web/pages/patients/index.tsx +++ b/web/pages/patients/index.tsx @@ -10,14 +10,15 @@ import { useRouter } from 'next/router' const PatientsPage: NextPage = () => { const translation = useTasksTranslation() const router = useRouter() - const { selectedLocationId } = useTasksContext() + const { selectedRootLocationIds } = useTasksContext() const patientId = router.query['patientId'] as string | undefined + const firstSelectedRootLocationId = selectedRootLocationIds && selectedRootLocationIds.length > 0 ? selectedRootLocationIds[0] : undefined return ( router.replace('/patients', undefined, { shallow: true })} /> diff --git a/web/pages/settings/index.tsx b/web/pages/settings/index.tsx index de322ca..68143bc 100644 --- a/web/pages/settings/index.tsx +++ b/web/pages/settings/index.tsx @@ -16,7 +16,7 @@ import { import type { HightideTranslationLocales, ThemeType } from '@helpwave/hightide' import { useTasksContext } from '@/hooks/useTasksContext' import { useAuth } from '@/hooks/useAuth' -import { LogOut, MonitorCog, MoonIcon, SunIcon, Trash2, ClipboardList, Shield, TableProperties } from 'lucide-react' +import { LogOut, MonitorCog, MoonIcon, SunIcon, Trash2, ClipboardList, Shield, TableProperties, Building2 } from 'lucide-react' import { useRouter } from 'next/router' import clsx from 'clsx' import { removeUser } from '@/api/auth/authService' @@ -112,6 +112,28 @@ const SettingsPage: NextPage = () => {
+ {/* Organizations Section */} + {user?.organizations && ( +
+

{translation('organizations') || 'Organizations'}

+
+ +
+ {user.organizations.split(',').map((org, index) => { + const trimmedOrg = org.trim() + if (!trimmedOrg) return null + return ( + + {trimmedOrg} + {index < user.organizations!.split(',').length - 1 && ,} + + ) + })} +
+
+
+ )} + {/* System / Management */}

{translation('system')}

diff --git a/web/pages/tasks/index.tsx b/web/pages/tasks/index.tsx index f24e335..0eb6a82 100644 --- a/web/pages/tasks/index.tsx +++ b/web/pages/tasks/index.tsx @@ -5,14 +5,19 @@ import { useTasksTranslation } from '@/i18n/useTasksTranslation' import { ContentPanel } from '@/components/layout/ContentPanel' import { TaskList, type TaskViewModel } from '@/components/tasks/TaskList' import { useMemo } from 'react' -import { useGetMyTasksQuery } from '@/api/gql/generated' +import { useGetTasksQuery } from '@/api/gql/generated' import { useRouter } from 'next/router' +import { useTasksContext } from '@/hooks/useTasksContext' const TasksPage: NextPage = () => { const translation = useTasksTranslation() const router = useRouter() - const { data: queryData, refetch } = useGetMyTasksQuery( - undefined, + const { selectedRootLocationIds, user } = useTasksContext() + const { data: queryData, refetch } = useGetTasksQuery( + { + rootLocationIds: selectedRootLocationIds, + assigneeId: user?.id, + }, { refetchInterval: 5000, refetchOnWindowFocus: true, @@ -22,9 +27,9 @@ const TasksPage: NextPage = () => { const taskId = router.query['taskId'] as string | undefined const tasks: TaskViewModel[] = useMemo(() => { - if (!queryData?.me?.tasks) return [] + if (!queryData?.tasks) return [] - return queryData.me.tasks.map((task) => ({ + return queryData.tasks.map((task) => ({ id: task.id, name: task.title, description: task.description || undefined,