From 901891335b388702923a0a27cb17b2eb7e93eb01 Mon Sep 17 00:00:00 2001 From: Felix Evers Date: Mon, 22 Dec 2025 00:19:06 +0100 Subject: [PATCH 01/16] scope resolvers --- backend/api/context.py | 116 +++++++++- backend/api/resolvers/location.py | 51 ++++- backend/api/resolvers/patient.py | 119 +++++++++- backend/api/resolvers/task.py | 206 ++++++++++++++++-- backend/api/services/authorization.py | 175 +++++++++++++++ backend/api/types/location.py | 9 + backend/api/types/user.py | 17 ++ .../add_location_organizations_table.py | 35 +++ .../versions/add_user_root_locations.py | 36 +++ backend/database/models/__init__.py | 4 +- backend/database/models/location.py | 15 +- backend/database/models/user.py | 15 +- backend/scaffold.py | 75 ++++++- keycloak/tasks.json | 31 ++- scaffold/initial.json | 2 + web/api/gql/generated.ts | 9 +- web/api/graphql/GlobalData.graphql | 6 + web/components/layout/Page.tsx | 20 +- web/components/patients/PatientDetailView.tsx | 76 ++++++- web/hooks/useTasksContext.tsx | 6 +- 20 files changed, 987 insertions(+), 36 deletions(-) create mode 100644 backend/api/services/authorization.py create mode 100644 backend/database/migrations/versions/add_location_organizations_table.py create mode 100644 backend/database/migrations/versions/add_user_root_locations.py diff --git a/backend/api/context.py b/backend/api/context.py index b736cc4..c0084cb 100644 --- a/backend/api/context.py +++ b/backend/api/context.py @@ -1,22 +1,60 @@ +import asyncio from typing import Any import strawberry from auth import get_user_payload -from database.models.user import User +from database.models.location import LocationNode +from database.models.user import User, user_root_locations from database.session import get_db_session from fastapi import Depends -from sqlalchemy import select +from sqlalchemy import delete, select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from starlette.requests import HTTPConnection from strawberry.fastapi import BaseContext +class LockedAsyncSession: + def __init__(self, session: AsyncSession, lock: asyncio.Lock): + self._session = session + self._lock = lock + + async def execute(self, *args, **kwargs): + async with self._lock: + return await self._session.execute(*args, **kwargs) + + async def commit(self, *args, **kwargs): + async with self._lock: + return await self._session.commit(*args, **kwargs) + + async def rollback(self, *args, **kwargs): + async with self._lock: + return await self._session.rollback(*args, **kwargs) + + async def flush(self, *args, **kwargs): + async with self._lock: + return await self._session.flush(*args, **kwargs) + + async def refresh(self, *args, **kwargs): + async with self._lock: + return await self._session.refresh(*args, **kwargs) + + def add(self, *args, **kwargs): + return self._session.add(*args, **kwargs) + + def __getattr__(self, name): + return getattr(self._session, name) + + class Context(BaseContext): def __init__(self, db: AsyncSession, user: "User | None" = None): super().__init__() - self.db = db + self._db = db self.user = user + self._accessible_location_ids: set[str] | None = None + self._accessible_location_ids_lock = asyncio.Lock() + self._db_lock = asyncio.Lock() + self.db = LockedAsyncSession(db, self._db_lock) Info = strawberry.Info[Context, Any] @@ -95,4 +133,76 @@ async def get_context( await session.commit() await session.refresh(db_user) + if db_user: + await _update_user_root_locations(session, db_user, organizations) + return Context(db=session, user=db_user) + + +async def _update_user_root_locations( + session: AsyncSession, user: User, organizations: str | None +) -> None: + organization_ids: list[str] = [] + if organizations: + organization_ids = [ + org_id.strip() for org_id in organizations.split(",") if org_id.strip() + ] + + root_location_ids: list[str] = [] + + if organization_ids: + result = await session.execute( + select(LocationNode).where(LocationNode.id.in_(organization_ids)) + ) + found_locations = result.scalars().all() + root_location_ids = [loc.id for loc in found_locations] + + found_ids = {loc.id for loc in found_locations} + for org_id in organization_ids: + if org_id not in found_ids: + new_location = LocationNode( + id=org_id, + title=f"Organization {org_id[:8]}", + kind="CLINIC", + parent_id=None, + ) + session.add(new_location) + root_location_ids.append(org_id) + + if not root_location_ids: + personal_org_title = f"{user.username}'s Organization" + result = await session.execute( + select(LocationNode).where( + LocationNode.title == personal_org_title, + LocationNode.parent_id.is_(None), + ) + ) + personal_location = result.scalars().first() + + if not personal_location: + personal_location = LocationNode( + title=personal_org_title, + kind="CLINIC", + parent_id=None, + ) + session.add(personal_location) + await session.flush() + + root_location_ids = [personal_location.id] + + await session.execute( + delete(user_root_locations).where(user_root_locations.c.user_id == user.id) + ) + + if root_location_ids: + from sqlalchemy.dialects.postgresql import insert + stmt = insert(user_root_locations).values( + [ + {"user_id": user.id, "location_id": loc_id} + for loc_id in root_location_ids + ] + ) + stmt = stmt.on_conflict_do_nothing(index_elements=["user_id", "location_id"]) + await session.execute(stmt) + + await session.commit() diff --git a/backend/api/resolvers/location.py b/backend/api/resolvers/location.py index 819b84a..5df6ad8 100644 --- a/backend/api/resolvers/location.py +++ b/backend/api/resolvers/location.py @@ -1,8 +1,10 @@ import strawberry from api.context import Info from api.inputs import LocationType +from api.services.authorization import AuthorizationService from api.types.location import LocationNodeType from database import models +from graphql import GraphQLError from sqlalchemy import select @@ -10,9 +12,18 @@ class LocationQuery: @strawberry.field async def location_roots(self, info: Info) -> list[LocationNodeType]: + auth_service = AuthorizationService(info.context.db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + + if not accessible_location_ids: + return [] + result = await info.context.db.execute( select(models.LocationNode).where( models.LocationNode.parent_id.is_(None), + models.LocationNode.id.in_(accessible_location_ids), ), ) return result.scalars().all() @@ -26,7 +37,20 @@ async def location_node( result = await info.context.db.execute( select(models.LocationNode).where(models.LocationNode.id == id), ) - return result.scalars().first() + location = result.scalars().first() + + if location: + auth_service = AuthorizationService(info.context.db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + if location.id not in accessible_location_ids: + raise GraphQLError( + "Forbidden: You do not have access to this location", + extensions={"code": "FORBIDDEN"}, + ) + + return location @strawberry.field async def location_nodes( @@ -39,8 +63,22 @@ async def location_nodes( order_by_name: bool = False, ) -> list[LocationNodeType]: db = info.context.db + + auth_service = AuthorizationService(db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + + if not accessible_location_ids: + return [] if recursive and parent_id: + if parent_id not in accessible_location_ids: + raise GraphQLError( + "Forbidden: You do not have access to this location", + extensions={"code": "FORBIDDEN"}, + ) + cte = ( select(models.LocationNode) .where(models.LocationNode.id == parent_id) @@ -52,10 +90,17 @@ async def location_nodes( models.LocationNode.parent_id == cte.c.id, ) cte = cte.union_all(parent) - query = select(cte) + query = select(cte).where(cte.c.id.in_(accessible_location_ids)) else: - query = select(models.LocationNode) + query = select(models.LocationNode).where( + models.LocationNode.id.in_(accessible_location_ids) + ) if parent_id: + if parent_id not in accessible_location_ids: + raise GraphQLError( + "Forbidden: You do not have access to this location", + extensions={"code": "FORBIDDEN"}, + ) query = query.where(models.LocationNode.parent_id == parent_id) if kind: diff --git a/backend/api/resolvers/patient.py b/backend/api/resolvers/patient.py index bb07f57..0f72ec3 100644 --- a/backend/api/resolvers/patient.py +++ b/backend/api/resolvers/patient.py @@ -5,11 +5,13 @@ from api.context import Info from api.inputs import CreatePatientInput, PatientState, UpdatePatientInput from api.resolvers.base import BaseMutationResolver, BaseSubscriptionResolver +from api.services.authorization import AuthorizationService from api.services.checksum import validate_checksum from api.services.location import LocationService from api.services.property import PropertyService from api.types.patient import PatientType from database import models +from graphql import GraphQLError from sqlalchemy import select from sqlalchemy.orm import aliased, selectinload @@ -31,7 +33,15 @@ async def patient( selectinload(models.Patient.teams), ), ) - return result.scalars().first() + patient = result.scalars().first() + if patient: + auth_service = AuthorizationService(info.context.db) + if not await auth_service.can_access_patient(info.context.user, patient, info.context): + raise GraphQLError( + "Forbidden: You do not have access to this patient", + extensions={"code": "FORBIDDEN"}, + ) + return patient @strawberry.field async def patients( @@ -88,6 +98,14 @@ async def patients( .distinct() ) + auth_service = AuthorizationService(info.context.db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + query = auth_service.filter_patients_by_access( + info.context.user, query, accessible_location_ids + ) + result = await info.context.db.execute(query) return result.scalars().all() @@ -106,6 +124,13 @@ async def recent_patients( ) .limit(limit) ) + auth_service = AuthorizationService(info.context.db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + query = auth_service.filter_patients_by_access( + info.context.user, query, accessible_location_ids + ) result = await info.context.db.execute(query) return result.scalars().all() @@ -133,13 +158,41 @@ async def create_patient( data.state.value if data.state else PatientState.WAIT.value ) + auth_service = AuthorizationService(db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + + if not accessible_location_ids: + raise GraphQLError( + "Forbidden: You do not have access to create patients", + extensions={"code": "FORBIDDEN"}, + ) + + if data.clinic_id not in accessible_location_ids: + raise GraphQLError( + "Forbidden: You do not have access to this clinic", + extensions={"code": "FORBIDDEN"}, + ) + await location_service.validate_and_get_clinic(data.clinic_id) if data.position_id: + if data.position_id not in accessible_location_ids: + raise GraphQLError( + "Forbidden: You do not have access to this position", + extensions={"code": "FORBIDDEN"}, + ) await location_service.validate_and_get_position(data.position_id) teams = [] if data.team_ids: + for team_id in data.team_ids: + if team_id not in accessible_location_ids: + raise GraphQLError( + "Forbidden: You do not have access to one or more teams", + extensions={"code": "FORBIDDEN"}, + ) teams = await location_service.validate_and_get_teams( data.team_ids ) @@ -159,11 +212,22 @@ async def create_patient( new_patient.teams = teams if data.assigned_location_ids: + for loc_id in data.assigned_location_ids: + if loc_id not in accessible_location_ids: + raise GraphQLError( + "Forbidden: You do not have access to one or more assigned locations", + extensions={"code": "FORBIDDEN"}, + ) locations = await location_service.get_locations_by_ids( data.assigned_location_ids ) new_patient.assigned_locations = locations elif data.assigned_location_id: + if data.assigned_location_id not in accessible_location_ids: + raise GraphQLError( + "Forbidden: You do not have access to this assigned location", + extensions={"code": "FORBIDDEN"}, + ) location = await location_service.get_location_by_id( data.assigned_location_id ) @@ -204,6 +268,13 @@ async def update_patient( if not patient: raise Exception("Patient not found") + auth_service = AuthorizationService(db) + if not await auth_service.can_access_patient(info.context.user, patient, info.context): + raise GraphQLError( + "Forbidden: You do not have access to this patient", + extensions={"code": "FORBIDDEN"}, + ) + if data.checksum: validate_checksum(patient, data.checksum, "Patient") @@ -217,8 +288,16 @@ async def update_patient( patient.sex = data.sex.value location_service = PatientMutation._get_location_service(db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user + ) if data.clinic_id is not None: + if data.clinic_id not in accessible_location_ids: + raise GraphQLError( + "Forbidden: You do not have access to this clinic", + extensions={"code": "FORBIDDEN"}, + ) await location_service.validate_and_get_clinic(data.clinic_id) patient.clinic_id = data.clinic_id @@ -226,6 +305,11 @@ async def update_patient( if data.position_id is None: patient.position_id = None else: + if data.position_id not in accessible_location_ids: + raise GraphQLError( + "Forbidden: You do not have access to this position", + extensions={"code": "FORBIDDEN"}, + ) await location_service.validate_and_get_position( data.position_id ) @@ -235,16 +319,33 @@ async def update_patient( if data.team_ids is None or len(data.team_ids) == 0: patient.teams = [] else: + for team_id in data.team_ids: + if team_id not in accessible_location_ids: + raise GraphQLError( + "Forbidden: You do not have access to one or more teams", + extensions={"code": "FORBIDDEN"}, + ) patient.teams = await location_service.validate_and_get_teams( data.team_ids ) if data.assigned_location_ids is not None: + for loc_id in data.assigned_location_ids: + if loc_id not in accessible_location_ids: + raise GraphQLError( + "Forbidden: You do not have access to one or more assigned locations", + extensions={"code": "FORBIDDEN"}, + ) locations = await location_service.get_locations_by_ids( data.assigned_location_ids ) patient.assigned_locations = locations elif data.assigned_location_id is not None: + if data.assigned_location_id not in accessible_location_ids: + raise GraphQLError( + "Forbidden: You do not have access to this assigned location", + extensions={"code": "FORBIDDEN"}, + ) location = await location_service.get_location_by_id( data.assigned_location_id ) @@ -269,6 +370,14 @@ async def delete_patient(self, info: Info, id: strawberry.ID) -> bool: patient = await repo.get_by_id(id) if not patient: return False + + auth_service = AuthorizationService(info.context.db) + if not await auth_service.can_access_patient(info.context.user, patient, info.context): + raise GraphQLError( + "Forbidden: You do not have access to this patient", + extensions={"code": "FORBIDDEN"}, + ) + await BaseMutationResolver.delete_entity( info, patient, models.Patient, "patient" ) @@ -293,6 +402,14 @@ async def _update_patient_state( patient = result.scalars().first() if not patient: raise Exception("Patient not found") + + auth_service = AuthorizationService(db) + if not await auth_service.can_access_patient(info.context.user, patient, info.context): + raise GraphQLError( + "Forbidden: You do not have access to this patient", + extensions={"code": "FORBIDDEN"}, + ) + patient.state = state.value await BaseMutationResolver.update_and_notify( info, patient, models.Patient, "patient" diff --git a/backend/api/resolvers/task.py b/backend/api/resolvers/task.py index 83dd089..b02345b 100644 --- a/backend/api/resolvers/task.py +++ b/backend/api/resolvers/task.py @@ -5,21 +5,36 @@ from api.context import Info from api.inputs import CreateTaskInput, UpdateTaskInput from api.resolvers.base import BaseMutationResolver, BaseSubscriptionResolver +from api.services.authorization import AuthorizationService from api.services.base import BaseRepository from api.services.checksum import validate_checksum from api.services.datetime import normalize_datetime_to_utc from api.services.property import PropertyService from api.types.task import TaskType from database import models +from graphql import GraphQLError from sqlalchemy import desc, select +from sqlalchemy.orm import aliased, selectinload @strawberry.type class TaskQuery: @strawberry.field async def task(self, info: Info, id: strawberry.ID) -> TaskType | None: - repo = BaseRepository(info.context.db, models.Task) - return await repo.get_by_id(id) + result = await info.context.db.execute( + select(models.Task) + .where(models.Task.id == id) + .options(selectinload(models.Task.patient).selectinload(models.Patient.assigned_locations)) + ) + task = result.scalars().first() + if task and task.patient: + auth_service = AuthorizationService(info.context.db) + if not await auth_service.can_access_patient(info.context.user, task.patient, info.context): + raise GraphQLError( + "Forbidden: You do not have access to this task", + extensions={"code": "FORBIDDEN"}, + ) + return task @strawberry.field async def tasks( @@ -28,12 +43,79 @@ async def tasks( patient_id: strawberry.ID | None = None, assignee_id: strawberry.ID | None = None, ) -> list[TaskType]: - query = select(models.Task) + auth_service = AuthorizationService(info.context.db) + if patient_id: - query = query.where(models.Task.patient_id == patient_id) + if not await auth_service.can_access_patient_id(info.context.user, patient_id, info.context): + raise GraphQLError( + "Forbidden: You do not have access to this patient's tasks", + extensions={"code": "FORBIDDEN"}, + ) + + query = select(models.Task).options( + selectinload(models.Task.patient).selectinload(models.Patient.assigned_locations) + ).where(models.Task.patient_id == patient_id) + + if assignee_id: + query = query.where(models.Task.assignee_id == assignee_id) + + result = await info.context.db.execute(query) + return result.scalars().all() + + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + + if not accessible_location_ids: + return [] + + patient_locations = aliased(models.patient_locations) + patient_teams = aliased(models.patient_teams) + + cte = ( + select(models.LocationNode.id) + .where(models.LocationNode.id.in_(accessible_location_ids)) + .cte(name="accessible_locations", recursive=True) + ) + + children = select(models.LocationNode.id).join( + cte, models.LocationNode.parent_id == cte.c.id + ) + cte = cte.union_all(children) + + query = ( + select(models.Task) + .options( + selectinload(models.Task.patient).selectinload(models.Patient.assigned_locations) + ) + .join(models.Patient, models.Task.patient_id == models.Patient.id) + .outerjoin( + patient_locations, + models.Patient.id == patient_locations.c.patient_id, + ) + .outerjoin( + patient_teams, + models.Patient.id == patient_teams.c.patient_id, + ) + .where( + (models.Patient.clinic_id.in_(select(cte.c.id))) + | ( + models.Patient.position_id.isnot(None) + & models.Patient.position_id.in_(select(cte.c.id)) + ) + | ( + models.Patient.assigned_location_id.isnot(None) + & models.Patient.assigned_location_id.in_(select(cte.c.id)) + ) + | (patient_locations.c.location_id.in_(select(cte.c.id))) + | (patient_teams.c.location_id.in_(select(cte.c.id))) + ) + .distinct() + ) + if assignee_id: query = query.where(models.Task.assignee_id == assignee_id) - + result = await info.context.db.execute(query) return result.scalars().all() @@ -43,11 +125,61 @@ async def recent_tasks( info: Info, limit: int = 10, ) -> list[TaskType]: - result = await info.context.db.execute( + auth_service = AuthorizationService(info.context.db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + + if not accessible_location_ids: + return [] + + patient_locations = aliased(models.patient_locations) + patient_teams = aliased(models.patient_teams) + + cte = ( + select(models.LocationNode.id) + .where(models.LocationNode.id.in_(accessible_location_ids)) + .cte(name="accessible_locations", recursive=True) + ) + + children = select(models.LocationNode.id).join( + cte, models.LocationNode.parent_id == cte.c.id + ) + cte = cte.union_all(children) + + query = ( select(models.Task) + .options( + selectinload(models.Task.patient).selectinload(models.Patient.assigned_locations) + ) + .join(models.Patient, models.Task.patient_id == models.Patient.id) + .outerjoin( + patient_locations, + models.Patient.id == patient_locations.c.patient_id, + ) + .outerjoin( + patient_teams, + models.Patient.id == patient_teams.c.patient_id, + ) + .where( + (models.Patient.clinic_id.in_(select(cte.c.id))) + | ( + models.Patient.position_id.isnot(None) + & models.Patient.position_id.in_(select(cte.c.id)) + ) + | ( + models.Patient.assigned_location_id.isnot(None) + & models.Patient.assigned_location_id.in_(select(cte.c.id)) + ) + | (patient_locations.c.location_id.in_(select(cte.c.id))) + | (patient_teams.c.location_id.in_(select(cte.c.id))) + ) .order_by(desc(models.Task.update_date)) - .limit(limit), + .limit(limit) + .distinct() ) + + result = await info.context.db.execute(query) return result.scalars().all() @@ -60,6 +192,13 @@ def _get_property_service(db) -> PropertyService: @strawberry.mutation @audit_log("create_task") async def create_task(self, info: Info, data: CreateTaskInput) -> TaskType: + auth_service = AuthorizationService(info.context.db) + if not await auth_service.can_access_patient_id(info.context.user, data.patient_id, info.context): + raise GraphQLError( + "Forbidden: You do not have access to create tasks for this patient", + extensions={"code": "FORBIDDEN"}, + ) + new_task = models.Task( title=data.title, description=data.description, @@ -92,8 +231,22 @@ async def update_task( data: UpdateTaskInput, ) -> TaskType: db = info.context.db - repo = BaseMutationResolver.get_repository(db, models.Task) - task = await repo.get_by_id_or_raise(id, "Task not found") + result = await db.execute( + select(models.Task) + .where(models.Task.id == id) + .options(selectinload(models.Task.patient).selectinload(models.Patient.assigned_locations)) + ) + task = result.scalars().first() + if not task: + raise Exception("Task not found") + + if task.patient: + auth_service = AuthorizationService(db) + if not await auth_service.can_access_patient(info.context.user, task.patient, info.context): + raise GraphQLError( + "Forbidden: You do not have access to this task", + extensions={"code": "FORBIDDEN"}, + ) if data.checksum: validate_checksum(task, data.checksum, "Task") @@ -134,8 +287,23 @@ async def _update_task_field( field_updater, ) -> TaskType: db = info.context.db - repo = BaseMutationResolver.get_repository(db, models.Task) - task = await repo.get_by_id_or_raise(id, "Task not found") + result = await db.execute( + select(models.Task) + .where(models.Task.id == id) + .options(selectinload(models.Task.patient).selectinload(models.Patient.assigned_locations)) + ) + task = result.scalars().first() + if not task: + raise Exception("Task not found") + + if task.patient: + auth_service = AuthorizationService(db) + if not await auth_service.can_access_patient(info.context.user, task.patient, info.context): + raise GraphQLError( + "Forbidden: You do not have access to this task", + extensions={"code": "FORBIDDEN"}, + ) + field_updater(task) await BaseMutationResolver.update_and_notify( info, task, models.Task, "task", "patient", task.patient_id @@ -187,11 +355,23 @@ async def reopen_task(self, info: Info, id: strawberry.ID) -> TaskType: @audit_log("delete_task") async def delete_task(self, info: Info, id: strawberry.ID) -> bool: db = info.context.db - repo = BaseMutationResolver.get_repository(db, models.Task) - task = await repo.get_by_id(id) + result = await db.execute( + select(models.Task) + .where(models.Task.id == id) + .options(selectinload(models.Task.patient).selectinload(models.Patient.assigned_locations)) + ) + task = result.scalars().first() if not task: return False + if task.patient: + auth_service = AuthorizationService(db) + if not await auth_service.can_access_patient(info.context.user, task.patient, info.context): + raise GraphQLError( + "Forbidden: You do not have access to this task", + extensions={"code": "FORBIDDEN"}, + ) + patient_id = task.patient_id await BaseMutationResolver.delete_entity( info, task, models.Task, "task", "patient", patient_id diff --git a/backend/api/services/authorization.py b/backend/api/services/authorization.py new file mode 100644 index 0000000..062167e --- /dev/null +++ b/backend/api/services/authorization.py @@ -0,0 +1,175 @@ +import asyncio +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import aliased, selectinload + +from database import models + + +class AuthorizationService: + def __init__(self, db: AsyncSession): + self.db = db + + async def get_user_accessible_location_ids( + self, user: models.User | None, context=None + ) -> set[str]: + if context and hasattr(context, '_accessible_location_ids') and context._accessible_location_ids is not None: + return context._accessible_location_ids + + if not context or not hasattr(context, '_accessible_location_ids_lock'): + return await self._compute_accessible_location_ids(user, context) + + async with context._accessible_location_ids_lock: + if context._accessible_location_ids is not None: + return context._accessible_location_ids + return await self._compute_accessible_location_ids(user, context) + + async def _compute_accessible_location_ids( + self, user: models.User | None, context=None + ) -> set[str]: + if not user: + result = set() + if context: + context._accessible_location_ids = result + return result + + result = await self.db.execute( + select(models.user_root_locations.c.location_id).where( + models.user_root_locations.c.user_id == user.id + ) + ) + rows = result.fetchall() + root_location_ids = {row[0] for row in rows} + + if not root_location_ids: + result = set() + if context: + context._accessible_location_ids = result + return result + + cte = ( + select(models.LocationNode.id) + .where(models.LocationNode.id.in_(root_location_ids)) + .cte(name="accessible_locations", recursive=True) + ) + + children = select(models.LocationNode.id).join( + cte, models.LocationNode.parent_id == cte.c.id + ) + cte = cte.union_all(children) + + result = await self.db.execute(select(cte.c.id)) + rows = result.fetchall() + accessible_ids = {row[0] for row in rows} + + if context: + context._accessible_location_ids = accessible_ids + + return accessible_ids + + async def can_access_patient( + self, user: models.User | None, patient: models.Patient, context=None + ) -> bool: + if not user: + return False + + accessible_location_ids = await self.get_user_accessible_location_ids(user, context) + + if not accessible_location_ids: + return False + + if patient.clinic_id in accessible_location_ids: + return True + + if patient.position_id and patient.position_id in accessible_location_ids: + return True + + if ( + patient.assigned_location_id + and patient.assigned_location_id in accessible_location_ids + ): + return True + + if patient.assigned_locations: + for location in patient.assigned_locations: + if location.id in accessible_location_ids: + return True + + if patient.teams: + for team in patient.teams: + if team.id in accessible_location_ids: + return True + + return False + + async def can_access_patient_id( + self, user: models.User | None, patient_id: str, context=None + ) -> bool: + if not user: + return False + + result = await self.db.execute( + select(models.Patient) + .where(models.Patient.id == patient_id) + .options( + selectinload(models.Patient.assigned_locations), + selectinload(models.Patient.teams), + ) + ) + patient = result.scalars().first() + + if not patient: + return False + + return await self.can_access_patient(user, patient, context) + + def filter_patients_by_access( + self, user: models.User | None, query, accessible_location_ids: set[str] | None = None + ): + if not user: + return query.where(False) + + if accessible_location_ids is None: + return query + + if not accessible_location_ids: + return query.where(False) + + cte = ( + select(models.LocationNode.id) + .where(models.LocationNode.id.in_(accessible_location_ids)) + .cte(name="accessible_locations", recursive=True) + ) + + children = select(models.LocationNode.id).join( + cte, models.LocationNode.parent_id == cte.c.id + ) + cte = cte.union_all(children) + + patient_locations = aliased(models.patient_locations) + patient_teams = aliased(models.patient_teams) + + return ( + query.outerjoin( + patient_locations, + models.Patient.id == patient_locations.c.patient_id, + ) + .outerjoin( + patient_teams, + models.Patient.id == patient_teams.c.patient_id, + ) + .where( + (models.Patient.clinic_id.in_(select(cte.c.id))) + | ( + models.Patient.position_id.isnot(None) + & models.Patient.position_id.in_(select(cte.c.id)) + ) + | ( + models.Patient.assigned_location_id.isnot(None) + & models.Patient.assigned_location_id.in_(select(cte.c.id)) + ) + | (patient_locations.c.location_id.in_(select(cte.c.id))) + | (patient_teams.c.location_id.in_(select(cte.c.id))) + ) + .distinct() + ) diff --git a/backend/api/types/location.py b/backend/api/types/location.py index 6704a2a..3fe2bd3 100644 --- a/backend/api/types/location.py +++ b/backend/api/types/location.py @@ -63,3 +63,12 @@ async def patients( ), ) return result.scalars().all() + + @strawberry.field + async def organization_ids(self, info: Info) -> list[str]: + result = await info.context.db.execute( + select(models.location_organizations.c.organization_id).where( + models.location_organizations.c.location_id == self.id, + ), + ) + return [row[0] for row in result.all()] diff --git a/backend/api/types/user.py b/backend/api/types/user.py index d0b3a18..14d86ee 100644 --- a/backend/api/types/user.py +++ b/backend/api/types/user.py @@ -3,8 +3,10 @@ import strawberry from database import models from sqlalchemy import select +from sqlalchemy.orm import selectinload if TYPE_CHECKING: + from api.types.location import LocationNodeType from api.types.task import TaskType @@ -35,3 +37,18 @@ async def tasks( select(models.Task).where(models.Task.assignee_id == self.id), ) return result.scalars().all() + + @strawberry.field + async def root_locations( + self, + info, + ) -> list[Annotated["LocationNodeType", strawberry.lazy("api.types.location")]]: + result = await info.context.db.execute( + select(models.User) + .where(models.User.id == self.id) + .options(selectinload(models.User.root_locations)) + ) + user = result.scalars().first() + if not user: + return [] + return user.root_locations or [] diff --git a/backend/database/migrations/versions/add_location_organizations_table.py b/backend/database/migrations/versions/add_location_organizations_table.py new file mode 100644 index 0000000..c6e6c56 --- /dev/null +++ b/backend/database/migrations/versions/add_location_organizations_table.py @@ -0,0 +1,35 @@ +"""Add location organizations table. + +Revision ID: add_location_organizations_table +Revises: add_user_root_locations +Create Date: 2025-01-17 00:00:00.000000 +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "add_location_organizations_table" +down_revision: Union[str, Sequence[str], None] = "add_user_root_locations" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.create_table( + "location_organizations", + sa.Column("location_id", sa.String(), nullable=False), + sa.Column("organization_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint(["location_id"], ["location_nodes.id"]), + sa.PrimaryKeyConstraint("location_id", "organization_id"), + ) + + +def downgrade() -> None: + """Downgrade schema.""" + op.drop_table("location_organizations") + diff --git a/backend/database/migrations/versions/add_user_root_locations.py b/backend/database/migrations/versions/add_user_root_locations.py new file mode 100644 index 0000000..afbba08 --- /dev/null +++ b/backend/database/migrations/versions/add_user_root_locations.py @@ -0,0 +1,36 @@ +"""Add user root locations table. + +Revision ID: add_user_root_locations +Revises: add_patient_location_mapping +Create Date: 2025-01-16 00:00:00.000000 +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "add_user_root_locations" +down_revision: Union[str, Sequence[str], None] = "add_patient_location_mapping" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.create_table( + "user_root_locations", + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("location_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["users.id"]), + sa.ForeignKeyConstraint(["location_id"], ["location_nodes.id"]), + sa.PrimaryKeyConstraint("user_id", "location_id"), + ) + + +def downgrade() -> None: + """Downgrade schema.""" + op.drop_table("user_root_locations") + diff --git a/backend/database/models/__init__.py b/backend/database/models/__init__.py index faea722..5de948a 100644 --- a/backend/database/models/__init__.py +++ b/backend/database/models/__init__.py @@ -1,5 +1,5 @@ -from .user import User # noqa: F401 -from .location import LocationNode # noqa: F401 +from .user import User, user_root_locations # noqa: F401 +from .location import LocationNode, location_organizations # noqa: F401 from .patient import Patient, patient_locations, patient_teams # noqa: F401 from .task import Task, task_dependencies # noqa: F401 from .property import PropertyDefinition, PropertyValue # noqa: F401 diff --git a/backend/database/models/location.py b/backend/database/models/location.py index 61d0a41..7efaa51 100644 --- a/backend/database/models/location.py +++ b/backend/database/models/location.py @@ -4,11 +4,19 @@ from typing import TYPE_CHECKING from database.models.base import Base -from sqlalchemy import ForeignKey, String +from sqlalchemy import Column, ForeignKey, String, Table from sqlalchemy.orm import Mapped, mapped_column, relationship if TYPE_CHECKING: from .patient import Patient + from .user import User + +location_organizations = Table( + "location_organizations", + Base.metadata, + Column("location_id", ForeignKey("location_nodes.id"), primary_key=True), + Column("organization_id", String, primary_key=True), +) class LocationNode(Base): @@ -60,3 +68,8 @@ class LocationNode(Base): secondary="patient_teams", back_populates="teams", ) + root_users: Mapped[list[User]] = relationship( + "User", + secondary="user_root_locations", + back_populates="root_locations", + ) diff --git a/backend/database/models/user.py b/backend/database/models/user.py index 9b79212..c4231f8 100644 --- a/backend/database/models/user.py +++ b/backend/database/models/user.py @@ -4,12 +4,20 @@ from typing import TYPE_CHECKING from database.models.base import Base -from sqlalchemy import String +from sqlalchemy import Column, ForeignKey, String, Table from sqlalchemy.orm import Mapped, mapped_column, relationship if TYPE_CHECKING: + from .location import LocationNode from .task import Task +user_root_locations = Table( + "user_root_locations", + Base.metadata, + Column("user_id", ForeignKey("users.id"), primary_key=True), + Column("location_id", ForeignKey("location_nodes.id"), primary_key=True), +) + class User(Base): __tablename__ = "users" @@ -32,3 +40,8 @@ class User(Base): organizations: Mapped[str | None] = mapped_column(String, nullable=True) tasks: Mapped[list[Task]] = relationship("Task", back_populates="assignee") + root_locations: Mapped[list[LocationNode]] = relationship( + "LocationNode", + secondary=user_root_locations, + back_populates="root_users", + ) diff --git a/backend/scaffold.py b/backend/scaffold.py index 2855a9c..7431f57 100644 --- a/backend/scaffold.py +++ b/backend/scaffold.py @@ -5,7 +5,8 @@ from api.inputs import LocationType from config import LOGGER, SCAFFOLD_DIRECTORY -from database.models.location import LocationNode +from database.models.location import LocationNode, location_organizations +from database.models.user import User from database.session import async_session from sqlalchemy import select @@ -71,6 +72,7 @@ async def load_scaffold_data() -> None: logger.info( f"Successfully loaded scaffold data from {json_file}" ) + await _assign_clinics_to_users(session) except json.JSONDecodeError as e: logger.error(f"Failed to parse JSON file {json_file}: {e}") await session.rollback() @@ -118,7 +120,78 @@ async def _create_location_tree( location_id = location.id logger.debug(f"Created location '{name}' ({location_type.value})") + organization_ids = data.get("organization_ids", []) + if organization_ids: + for org_id in organization_ids: + stmt = select(location_organizations).where( + location_organizations.c.location_id == location_id, + location_organizations.c.organization_id == org_id, + ) + result = await session.execute(stmt) + existing_org = result.first() + if not existing_org: + await session.execute( + location_organizations.insert().values( + location_id=location_id, organization_id=org_id + ) + ) + logger.debug( + f"Assigned organization '{org_id}' to location '{name}'" + ) + for child_data in children: await _create_location_tree(session, child_data, location_id) return location_id + + +async def _assign_clinics_to_users(session: Any) -> None: + result = await session.execute(select(User)) + users = result.scalars().all() + + for user in users: + if not user.organizations: + continue + + org_ids = [ + org_id.strip() + for org_id in user.organizations.split(",") + if org_id.strip() + ] + + for org_id in org_ids: + clinic_result = await session.execute( + select(LocationNode) + .join( + location_organizations, + LocationNode.id == location_organizations.c.location_id, + ) + .where( + LocationNode.kind == "CLINIC", + location_organizations.c.organization_id == org_id, + ) + .limit(1) + ) + clinic = clinic_result.scalar_one_or_none() + + if clinic: + from database.models.user import user_root_locations + + existing_result = await session.execute( + select(user_root_locations).where( + user_root_locations.c.user_id == user.id, + user_root_locations.c.location_id == clinic.id, + ) + ) + existing = existing_result.first() + if not existing: + await session.execute( + user_root_locations.insert().values( + user_id=user.id, location_id=clinic.id + ) + ) + logger.info( + f"Assigned clinic '{clinic.title}' to user '{user.username}' based on organization '{org_id}'" + ) + + await session.commit() diff --git a/keycloak/tasks.json b/keycloak/tasks.json index f560215..c4cdb6b 100644 --- a/keycloak/tasks.json +++ b/keycloak/tasks.json @@ -400,7 +400,36 @@ "requiredActions" : [ ], "realmRoles" : [ "default-roles-tasks" ], "notBefore" : 0, - "groups" : [ ] + "groups" : [ ], + "attributes" : { + "organization" : [ "test-org-1", "test-org-2" ] + } + }, { + "id" : "15d5c9b3-2d30-4a0d-9c2d-ed3212d5e2db", + "username" : "test2", + "firstName" : "Jane", + "lastName" : "Smith", + "email" : "jane.smith@helpwave.de", + "emailVerified" : true, + "enabled" : true, + "createdTimestamp" : 1764546986928, + "totp" : false, + "credentials" : [ { + "id" : "f15dab03-6c5b-4b53-ad9e-91d1024f2a1c", + "type" : "password", + "userLabel" : "My password", + "createdDate" : 1764546995066, + "secretData" : "{\"value\":\"Yg9msci7ctlF0zTXiQe+vjPkTrDq7lBkKyeWcxHLxlE=\",\"salt\":\"mYbxcfft3FMwjIDx03aHdw==\",\"additionalParameters\":{}}", + "credentialData" : "{\"hashIterations\":5,\"algorithm\":\"argon2\",\"additionalParameters\":{\"hashLength\":[\"32\"],\"memory\":[\"7168\"],\"type\":[\"id\"],\"version\":[\"1.3\"],\"parallelism\":[\"1\"]}}" + } ], + "disableableCredentialTypes" : [ ], + "requiredActions" : [ ], + "realmRoles" : [ "default-roles-tasks" ], + "notBefore" : 0, + "groups" : [ ], + "attributes" : { + "organization" : [ "test-org-2" ] + } } ], "scopeMappings" : [ { "clientScope" : "offline_access", diff --git a/scaffold/initial.json b/scaffold/initial.json index afeb895..5c2f141 100644 --- a/scaffold/initial.json +++ b/scaffold/initial.json @@ -6,6 +6,7 @@ { "name": "Cardiology", "type": "CLINIC", + "organization_ids": ["test-org-1"], "children": [ { "name": "Cardio-Diagnostics Team", "type": "TEAM" }, { "name": "Interventional Team", "type": "TEAM" }, @@ -36,6 +37,7 @@ { "name": "Orthopedics", "type": "CLINIC", + "organization_ids": ["test-org-2"], "children": [ { "name": "Surgical Excellence Team", "type": "TEAM" }, { "name": "Physio-Support Team", "type": "TEAM" }, diff --git a/web/api/gql/generated.ts b/web/api/gql/generated.ts index fd85967..3bd3aca 100644 --- a/web/api/gql/generated.ts +++ b/web/api/gql/generated.ts @@ -426,6 +426,7 @@ export type UserType = { lastname?: Maybe; name: Scalars['String']['output']; organizations?: Maybe; + rootLocations: Array; tasks: Array; title?: Maybe; username: Scalars['String']['output']; @@ -483,7 +484,7 @@ export type GetUsersQuery = { __typename?: 'Query', users: Array<{ __typename?: export type GetGlobalDataQueryVariables = Exact<{ [key: string]: never; }>; -export type GetGlobalDataQuery = { __typename?: 'Query', me?: { __typename?: 'UserType', id: string, username: string, name: string, firstname?: string | null, lastname?: string | null, avatarUrl?: string | null, tasks: Array<{ __typename?: 'TaskType', id: string, done: boolean }> } | null, wards: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, teams: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, clinics: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, patients: Array<{ __typename?: 'PatientType', id: string, state: PatientState, assignedLocation?: { __typename?: 'LocationNodeType', id: string } | null }>, waitingPatients: Array<{ __typename?: 'PatientType', id: string, state: PatientState }> }; +export type GetGlobalDataQuery = { __typename?: 'Query', me?: { __typename?: 'UserType', id: string, username: string, name: string, firstname?: string | null, lastname?: string | null, avatarUrl?: string | null, organizations?: string | null, rootLocations: Array<{ __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType }>, tasks: Array<{ __typename?: 'TaskType', id: string, done: boolean }> } | null, wards: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, teams: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, clinics: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, patients: Array<{ __typename?: 'PatientType', id: string, state: PatientState, assignedLocation?: { __typename?: 'LocationNodeType', id: string } | null }>, waitingPatients: Array<{ __typename?: 'PatientType', id: string, state: PatientState }> }; export type CreatePatientMutationVariables = Exact<{ data: CreatePatientInput; @@ -1218,6 +1219,12 @@ export const GetGlobalDataDocument = ` firstname lastname avatarUrl + organizations + rootLocations { + id + title + kind + } tasks { id done diff --git a/web/api/graphql/GlobalData.graphql b/web/api/graphql/GlobalData.graphql index 3eb3695..20b7a73 100644 --- a/web/api/graphql/GlobalData.graphql +++ b/web/api/graphql/GlobalData.graphql @@ -6,6 +6,12 @@ query GetGlobalData { firstname lastname avatarUrl + organizations + rootLocations { + id + title + kind + } tasks { id done diff --git a/web/components/layout/Page.tsx b/web/components/layout/Page.tsx index d102ae4..04b1eb8 100644 --- a/web/components/layout/Page.tsx +++ b/web/components/layout/Page.tsx @@ -204,7 +204,9 @@ type HeaderProps = HTMLAttributes & { export const Header = ({ onMenuClick, isMenuOpen, ...props }: HeaderProps) => { const router = useRouter() - const { user } = useTasksContext() + const { user, rootLocations } = useTasksContext() + + const organizations = user?.organizations ? user.organizations.split(',').map(org => org.trim()).filter(org => org.length > 0) : [] return (
{
+ {organizations.length > 0 && ( +
+ + + {organizations.join(', ')} + +
+ )} + {rootLocations && rootLocations.length > 0 && ( +
+ + + {rootLocations.map(loc => loc.title).join(', ')} + +
+ )}
diff --git a/web/components/patients/PatientDetailView.tsx b/web/components/patients/PatientDetailView.tsx index 8f177f3..0bad2ad 100644 --- a/web/components/patients/PatientDetailView.tsx +++ b/web/components/patients/PatientDetailView.tsx @@ -226,6 +226,8 @@ export const PatientDetailView = ({ const [selectedTeams, setSelectedTeams] = useState([]) const [isMarkDeadDialogOpen, setIsMarkDeadDialogOpen] = useState(false) const [isDischargeDialogOpen, setIsDischargeDialogOpen] = useState(false) + const [isLocationChangeConfirmOpen, setIsLocationChangeConfirmOpen] = useState(false) + const [pendingLocationUpdate, setPendingLocationUpdate] = useState<(() => void) | null>(null) // Validation state for required fields const [firstnameError, setFirstnameError] = useState(null) @@ -432,16 +434,30 @@ export const PatientDetailView = ({ setSelectedClinic(clinic) updateLocalState({ clinicId: clinic.id } as Partial) if (isEditMode) { - persistChanges({ clinicId: clinic.id } as Partial) + const updateFn = () => { + persistChanges({ clinicId: clinic.id } as Partial) + setIsLocationChangeConfirmOpen(false) + setPendingLocationUpdate(null) + } + setPendingLocationUpdate(() => updateFn) + setIsLocationChangeConfirmOpen(true) + } else { + validateClinic(clinic) } - validateClinic(clinic) } else { setSelectedClinic(null) updateLocalState({ clinicId: undefined } as Partial) if (isEditMode) { - persistChanges({ clinicId: undefined } as Partial) + const updateFn = () => { + persistChanges({ clinicId: undefined } as Partial) + setIsLocationChangeConfirmOpen(false) + setPendingLocationUpdate(null) + } + setPendingLocationUpdate(() => updateFn) + setIsLocationChangeConfirmOpen(true) + } else { + validateClinic(null) } - validateClinic(null) } setIsClinicDialogOpen(false) } @@ -451,11 +467,31 @@ export const PatientDetailView = ({ if (position) { setSelectedPosition(position) updateLocalState({ positionId: position.id } as Partial) - persistChanges({ positionId: position.id } as Partial) + if (isEditMode) { + const updateFn = () => { + persistChanges({ positionId: position.id } as Partial) + setIsLocationChangeConfirmOpen(false) + setPendingLocationUpdate(null) + } + setPendingLocationUpdate(() => updateFn) + setIsLocationChangeConfirmOpen(true) + } else { + persistChanges({ positionId: position.id } as Partial) + } } else { setSelectedPosition(null) updateLocalState({ positionId: undefined } as Partial) - persistChanges({ positionId: undefined } as Partial) + if (isEditMode) { + const updateFn = () => { + persistChanges({ positionId: undefined } as Partial) + setIsLocationChangeConfirmOpen(false) + setPendingLocationUpdate(null) + } + setPendingLocationUpdate(() => updateFn) + setIsLocationChangeConfirmOpen(true) + } else { + persistChanges({ positionId: undefined } as Partial) + } } setIsPositionDialogOpen(false) } @@ -464,7 +500,17 @@ export const PatientDetailView = ({ setSelectedTeams(locations) const teamIds = locations.map(loc => loc.id) updateLocalState({ teamIds } as Partial) - persistChanges({ teamIds } as Partial) + if (isEditMode) { + const updateFn = () => { + persistChanges({ teamIds } as Partial) + setIsLocationChangeConfirmOpen(false) + setPendingLocationUpdate(null) + } + setPendingLocationUpdate(() => updateFn) + setIsLocationChangeConfirmOpen(true) + } else { + persistChanges({ teamIds } as Partial) + } setIsTeamsDialogOpen(false) } @@ -1062,6 +1108,22 @@ export const PatientDetailView = ({ confirmType="neutral" /> + { + setIsLocationChangeConfirmOpen(false) + setPendingLocationUpdate(null) + }} + onConfirm={() => { + if (pendingLocationUpdate) { + pendingLocationUpdate() + } + }} + titleElement={translation('updateLocation')} + description={translation('updateLocationConfirmation')} + confirmType="neutral" + /> + setIsClinicDialogOpen(false)} diff --git a/web/hooks/useTasksContext.tsx b/web/hooks/useTasksContext.tsx index 567a445..5e0f32d 100644 --- a/web/hooks/useTasksContext.tsx +++ b/web/hooks/useTasksContext.tsx @@ -8,6 +8,7 @@ type User = { id: string, name: string, avatarUrl?: string | null, + organizations?: string | null, } type LocationNode = { @@ -32,6 +33,7 @@ export type TasksContextState = { selectedLocationId?: string, sidebar: SidebarContextType, user?: User, + rootLocations?: LocationNode[], } export type TasksContextType = TasksContextState & { @@ -76,7 +78,8 @@ export const TasksContextProvider = ({ children }: PropsWithChildren) => { user: data?.me ? { id: data.me.id, name: data.me.name, - avatarUrl: data.me.avatarUrl + avatarUrl: data.me.avatarUrl, + organizations: data.me.organizations ?? null } : undefined, myTasksCount: data?.me?.tasks?.filter(t => !t.done).length ?? 0, totalPatientsCount, @@ -87,6 +90,7 @@ export const TasksContextProvider = ({ children }: PropsWithChildren) => { teams: data?.teams, wards: data?.wards, clinics: data?.clinics, + rootLocations: data?.me?.rootLocations?.map(loc => ({ id: loc.id, title: loc.title })) ?? [], })) }, [data]) From 14c6c623de2dd38500483592169911ea48c3a207 Mon Sep 17 00:00:00 2001 From: Felix Evers Date: Mon, 22 Dec 2025 13:52:49 +0100 Subject: [PATCH 02/16] fix linting --- backend/api/resolvers/location.py | 14 +++++----- backend/api/resolvers/patient.py | 2 +- backend/api/resolvers/task.py | 37 +++++++++++++-------------- backend/api/services/authorization.py | 9 +++---- web/i18n/translations.ts | 6 +++++ web/locales/de-DE.arb | 2 ++ web/locales/en-US.arb | 2 ++ 7 files changed, 40 insertions(+), 32 deletions(-) diff --git a/backend/api/resolvers/location.py b/backend/api/resolvers/location.py index 5df6ad8..b321a86 100644 --- a/backend/api/resolvers/location.py +++ b/backend/api/resolvers/location.py @@ -16,10 +16,10 @@ async def location_roots(self, info: Info) -> list[LocationNodeType]: accessible_location_ids = await auth_service.get_user_accessible_location_ids( info.context.user, info.context ) - + if not accessible_location_ids: return [] - + result = await info.context.db.execute( select(models.LocationNode).where( models.LocationNode.parent_id.is_(None), @@ -38,7 +38,7 @@ async def location_node( select(models.LocationNode).where(models.LocationNode.id == id), ) location = result.scalars().first() - + if location: auth_service = AuthorizationService(info.context.db) accessible_location_ids = await auth_service.get_user_accessible_location_ids( @@ -49,7 +49,7 @@ async def location_node( "Forbidden: You do not have access to this location", extensions={"code": "FORBIDDEN"}, ) - + return location @strawberry.field @@ -63,12 +63,12 @@ async def location_nodes( order_by_name: bool = False, ) -> list[LocationNodeType]: db = info.context.db - + auth_service = AuthorizationService(db) accessible_location_ids = await auth_service.get_user_accessible_location_ids( info.context.user, info.context ) - + if not accessible_location_ids: return [] @@ -78,7 +78,7 @@ async def location_nodes( "Forbidden: You do not have access to this location", extensions={"code": "FORBIDDEN"}, ) - + cte = ( select(models.LocationNode) .where(models.LocationNode.id == parent_id) diff --git a/backend/api/resolvers/patient.py b/backend/api/resolvers/patient.py index 0f72ec3..70a42d2 100644 --- a/backend/api/resolvers/patient.py +++ b/backend/api/resolvers/patient.py @@ -162,7 +162,7 @@ async def create_patient( accessible_location_ids = await auth_service.get_user_accessible_location_ids( info.context.user, info.context ) - + if not accessible_location_ids: raise GraphQLError( "Forbidden: You do not have access to create patients", diff --git a/backend/api/resolvers/task.py b/backend/api/resolvers/task.py index b02345b..01a0ba5 100644 --- a/backend/api/resolvers/task.py +++ b/backend/api/resolvers/task.py @@ -6,7 +6,6 @@ from api.inputs import CreateTaskInput, UpdateTaskInput from api.resolvers.base import BaseMutationResolver, BaseSubscriptionResolver from api.services.authorization import AuthorizationService -from api.services.base import BaseRepository from api.services.checksum import validate_checksum from api.services.datetime import normalize_datetime_to_utc from api.services.property import PropertyService @@ -44,45 +43,45 @@ async def tasks( assignee_id: strawberry.ID | None = None, ) -> list[TaskType]: auth_service = AuthorizationService(info.context.db) - + if patient_id: if not await auth_service.can_access_patient_id(info.context.user, patient_id, info.context): raise GraphQLError( "Forbidden: You do not have access to this patient's tasks", extensions={"code": "FORBIDDEN"}, ) - + query = select(models.Task).options( selectinload(models.Task.patient).selectinload(models.Patient.assigned_locations) ).where(models.Task.patient_id == patient_id) - + if assignee_id: query = query.where(models.Task.assignee_id == assignee_id) - + result = await info.context.db.execute(query) return result.scalars().all() - + accessible_location_ids = await auth_service.get_user_accessible_location_ids( info.context.user, info.context ) - + if not accessible_location_ids: return [] - + patient_locations = aliased(models.patient_locations) patient_teams = aliased(models.patient_teams) - + cte = ( select(models.LocationNode.id) .where(models.LocationNode.id.in_(accessible_location_ids)) .cte(name="accessible_locations", recursive=True) ) - + children = select(models.LocationNode.id).join( cte, models.LocationNode.parent_id == cte.c.id ) cte = cte.union_all(children) - + query = ( select(models.Task) .options( @@ -112,10 +111,10 @@ async def tasks( ) .distinct() ) - + if assignee_id: query = query.where(models.Task.assignee_id == assignee_id) - + result = await info.context.db.execute(query) return result.scalars().all() @@ -129,24 +128,24 @@ async def recent_tasks( accessible_location_ids = await auth_service.get_user_accessible_location_ids( info.context.user, info.context ) - + if not accessible_location_ids: return [] - + patient_locations = aliased(models.patient_locations) patient_teams = aliased(models.patient_teams) - + cte = ( select(models.LocationNode.id) .where(models.LocationNode.id.in_(accessible_location_ids)) .cte(name="accessible_locations", recursive=True) ) - + children = select(models.LocationNode.id).join( cte, models.LocationNode.parent_id == cte.c.id ) cte = cte.union_all(children) - + query = ( select(models.Task) .options( @@ -178,7 +177,7 @@ async def recent_tasks( .limit(limit) .distinct() ) - + result = await info.context.db.execute(query) return result.scalars().all() diff --git a/backend/api/services/authorization.py b/backend/api/services/authorization.py index 062167e..8844967 100644 --- a/backend/api/services/authorization.py +++ b/backend/api/services/authorization.py @@ -1,4 +1,3 @@ -import asyncio from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import aliased, selectinload @@ -15,10 +14,10 @@ async def get_user_accessible_location_ids( ) -> set[str]: if context and hasattr(context, '_accessible_location_ids') and context._accessible_location_ids is not None: return context._accessible_location_ids - + if not context or not hasattr(context, '_accessible_location_ids_lock'): return await self._compute_accessible_location_ids(user, context) - + async with context._accessible_location_ids_lock: if context._accessible_location_ids is not None: return context._accessible_location_ids @@ -61,10 +60,10 @@ async def _compute_accessible_location_ids( result = await self.db.execute(select(cte.c.id)) rows = result.fetchall() accessible_ids = {row[0] for row in rows} - + if context: context._accessible_location_ids = accessible_ids - + return accessible_ids async def can_access_patient( diff --git a/web/i18n/translations.ts b/web/i18n/translations.ts index 9c7e666..a35b349 100644 --- a/web/i18n/translations.ts +++ b/web/i18n/translations.ts @@ -182,6 +182,8 @@ export type TasksTranslationEntries = { 'type': string, 'unassigned': string, 'updated': string, + 'updateLocation': string, + 'updateLocationConfirmation': string, 'visibility': string, 'waitingForPatient': string, 'waitingroom': string, @@ -474,6 +476,8 @@ export const tasksTranslation: Translation Date: Mon, 22 Dec 2025 14:21:07 +0100 Subject: [PATCH 03/16] add tests --- backend/tests/conftest.py | 24 ++++++++++++++++ .../integration/test_patient_resolver.py | 28 +++++++++---------- .../tests/integration/test_task_resolver.py | 28 +++++++++---------- 3 files changed, 52 insertions(+), 28 deletions(-) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 4877679..98fb584 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -99,3 +99,27 @@ async def sample_task( await db_session.commit() await db_session.refresh(task) return task + + +@pytest.fixture +async def sample_user_with_location_access( + db_session: AsyncSession, sample_user: User, sample_location: LocationNode +) -> User: + from database.models.user import user_root_locations + from sqlalchemy import select, insert + + result = await db_session.execute( + select(user_root_locations).where( + user_root_locations.c.user_id == sample_user.id, + user_root_locations.c.location_id == sample_location.id, + ) + ) + existing = result.first() + if not existing: + stmt = insert(user_root_locations).values( + user_id=sample_user.id, location_id=sample_location.id + ) + await db_session.execute(stmt) + await db_session.commit() + await db_session.refresh(sample_user) + return sample_user diff --git a/backend/tests/integration/test_patient_resolver.py b/backend/tests/integration/test_patient_resolver.py index be75266..15139e2 100644 --- a/backend/tests/integration/test_patient_resolver.py +++ b/backend/tests/integration/test_patient_resolver.py @@ -5,13 +5,13 @@ class MockInfo: - def __init__(self, db): - self.context = Context(db=db) + def __init__(self, db, user=None): + self.context = Context(db=db, user=user) @pytest.mark.asyncio -async def test_patient_query_get_patient(db_session, sample_patient): - info = MockInfo(db_session) +async def test_patient_query_get_patient(db_session, sample_patient, sample_user_with_location_access): + info = MockInfo(db_session, sample_user_with_location_access) query = PatientQuery() result = await query.patient(info, sample_patient.id) assert result is not None @@ -20,8 +20,8 @@ async def test_patient_query_get_patient(db_session, sample_patient): @pytest.mark.asyncio -async def test_patient_query_patients(db_session, sample_patient): - info = MockInfo(db_session) +async def test_patient_query_patients(db_session, sample_patient, sample_user_with_location_access): + info = MockInfo(db_session, sample_user_with_location_access) query = PatientQuery() results = await query.patients(info) assert len(results) >= 1 @@ -29,11 +29,11 @@ async def test_patient_query_patients(db_session, sample_patient): @pytest.mark.asyncio -async def test_patient_mutation_create_patient(db_session, sample_location): +async def test_patient_mutation_create_patient(db_session, sample_location, sample_user_with_location_access): from api.inputs import CreatePatientInput from datetime import date - info = MockInfo(db_session) + info = MockInfo(db_session, sample_user_with_location_access) mutation = PatientMutation() input_data = CreatePatientInput( firstname="Jane", @@ -50,10 +50,10 @@ async def test_patient_mutation_create_patient(db_session, sample_location): @pytest.mark.asyncio -async def test_patient_mutation_update_patient(db_session, sample_patient): +async def test_patient_mutation_update_patient(db_session, sample_patient, sample_user_with_location_access): from api.inputs import UpdatePatientInput - info = MockInfo(db_session) + info = MockInfo(db_session, sample_user_with_location_access) mutation = PatientMutation() input_data = UpdatePatientInput(firstname="Updated Name") result = await mutation.update_patient(info, sample_patient.id, input_data) @@ -62,16 +62,16 @@ async def test_patient_mutation_update_patient(db_session, sample_patient): @pytest.mark.asyncio -async def test_patient_mutation_admit_patient(db_session, sample_patient): - info = MockInfo(db_session) +async def test_patient_mutation_admit_patient(db_session, sample_patient, sample_user_with_location_access): + info = MockInfo(db_session, sample_user_with_location_access) mutation = PatientMutation() result = await mutation.admit_patient(info, sample_patient.id) assert result.state == PatientState.ADMITTED.value @pytest.mark.asyncio -async def test_patient_mutation_discharge_patient(db_session, sample_patient): - info = MockInfo(db_session) +async def test_patient_mutation_discharge_patient(db_session, sample_patient, sample_user_with_location_access): + info = MockInfo(db_session, sample_user_with_location_access) mutation = PatientMutation() result = await mutation.discharge_patient(info, sample_patient.id) assert result.state == PatientState.DISCHARGED.value diff --git a/backend/tests/integration/test_task_resolver.py b/backend/tests/integration/test_task_resolver.py index 3cf68c7..14a4d8d 100644 --- a/backend/tests/integration/test_task_resolver.py +++ b/backend/tests/integration/test_task_resolver.py @@ -5,13 +5,13 @@ class MockInfo: - def __init__(self, db): - self.context = Context(db=db) + def __init__(self, db, user=None): + self.context = Context(db=db, user=user) @pytest.mark.asyncio -async def test_task_query_get_task(db_session, sample_task): - info = MockInfo(db_session) +async def test_task_query_get_task(db_session, sample_task, sample_user_with_location_access): + info = MockInfo(db_session, sample_user_with_location_access) query = TaskQuery() result = await query.task(info, sample_task.id) assert result is not None @@ -20,8 +20,8 @@ async def test_task_query_get_task(db_session, sample_task): @pytest.mark.asyncio -async def test_task_query_tasks_by_patient(db_session, sample_patient): - info = MockInfo(db_session) +async def test_task_query_tasks_by_patient(db_session, sample_patient, sample_user_with_location_access): + info = MockInfo(db_session, sample_user_with_location_access) task1 = Task(title="Task 1", patient_id=sample_patient.id) task2 = Task(title="Task 2", patient_id=sample_patient.id) db_session.add(task1) @@ -37,10 +37,10 @@ async def test_task_query_tasks_by_patient(db_session, sample_patient): @pytest.mark.asyncio -async def test_task_mutation_create_task(db_session, sample_patient): +async def test_task_mutation_create_task(db_session, sample_patient, sample_user_with_location_access): from api.inputs import CreateTaskInput - info = MockInfo(db_session) + info = MockInfo(db_session, sample_user_with_location_access) mutation = TaskMutation() input_data = CreateTaskInput( title="New Task", @@ -54,10 +54,10 @@ async def test_task_mutation_create_task(db_session, sample_patient): @pytest.mark.asyncio -async def test_task_mutation_update_task(db_session, sample_task): +async def test_task_mutation_update_task(db_session, sample_task, sample_user_with_location_access): from api.inputs import UpdateTaskInput - info = MockInfo(db_session) + info = MockInfo(db_session, sample_user_with_location_access) mutation = TaskMutation() input_data = UpdateTaskInput(title="Updated Title") result = await mutation.update_task(info, sample_task.id, input_data) @@ -66,8 +66,8 @@ async def test_task_mutation_update_task(db_session, sample_task): @pytest.mark.asyncio -async def test_task_mutation_complete_task(db_session, sample_task): - info = MockInfo(db_session) +async def test_task_mutation_complete_task(db_session, sample_task, sample_user_with_location_access): + info = MockInfo(db_session, sample_user_with_location_access) mutation = TaskMutation() result = await mutation.complete_task(info, sample_task.id) assert result.done is True @@ -75,8 +75,8 @@ async def test_task_mutation_complete_task(db_session, sample_task): @pytest.mark.asyncio -async def test_task_mutation_delete_task(db_session, sample_task): - info = MockInfo(db_session) +async def test_task_mutation_delete_task(db_session, sample_task, sample_user_with_location_access): + info = MockInfo(db_session, sample_user_with_location_access) mutation = TaskMutation() task_id = sample_task.id result = await mutation.delete_task(info, task_id) From 1837974e57cfa1418c49bc850a840034815a06f8 Mon Sep 17 00:00:00 2001 From: Felix Evers Date: Tue, 23 Dec 2025 15:16:08 +0100 Subject: [PATCH 04/16] remove legacy organization management --- backend/api/context.py | 28 +++++----- backend/api/types/user.py | 1 - .../remove_organizations_from_users.py | 29 ++++++++++ backend/database/models/user.py | 1 - backend/scaffold.py | 54 ------------------- web/api/gql/generated.ts | 5 +- web/api/graphql/GlobalData.graphql | 1 - web/components/layout/Page.tsx | 10 ---- web/hooks/useTasksContext.tsx | 2 - 9 files changed, 43 insertions(+), 88 deletions(-) create mode 100644 backend/database/migrations/versions/remove_organizations_from_users.py diff --git a/backend/api/context.py b/backend/api/context.py index c0084cb..9a52396 100644 --- a/backend/api/context.py +++ b/backend/api/context.py @@ -101,7 +101,6 @@ async def get_context( lastname=lastname, title="User", avatar_url=picture, - organizations=organizations, ) session.add(new_user) await session.commit() @@ -120,7 +119,6 @@ async def get_context( or db_user.lastname != lastname or db_user.email != email or db_user.avatar_url != picture - or db_user.organizations != organizations ): db_user.username = username db_user.firstname = firstname @@ -128,7 +126,6 @@ async def get_context( db_user.email = email if picture: db_user.avatar_url = picture - db_user.organizations = organizations session.add(db_user) await session.commit() await session.refresh(db_user) @@ -142,6 +139,8 @@ async def get_context( async def _update_user_root_locations( session: AsyncSession, user: User, organizations: str | None ) -> None: + from database.models.location import location_organizations + organization_ids: list[str] = [] if organizations: organization_ids = [ @@ -152,23 +151,20 @@ async def _update_user_root_locations( if organization_ids: result = await session.execute( - select(LocationNode).where(LocationNode.id.in_(organization_ids)) + select(LocationNode) + .join( + location_organizations, + LocationNode.id == location_organizations.c.location_id, + ) + .where( + LocationNode.kind == "CLINIC", + location_organizations.c.organization_id.in_(organization_ids), + ) + .distinct() ) found_locations = result.scalars().all() root_location_ids = [loc.id for loc in found_locations] - found_ids = {loc.id for loc in found_locations} - for org_id in organization_ids: - if org_id not in found_ids: - new_location = LocationNode( - id=org_id, - title=f"Organization {org_id[:8]}", - kind="CLINIC", - parent_id=None, - ) - session.add(new_location) - root_location_ids.append(org_id) - if not root_location_ids: personal_org_title = f"{user.username}'s Organization" result = await session.execute( diff --git a/backend/api/types/user.py b/backend/api/types/user.py index 14d86ee..2284702 100644 --- a/backend/api/types/user.py +++ b/backend/api/types/user.py @@ -19,7 +19,6 @@ class UserType: lastname: str | None title: str | None avatar_url: str | None - organizations: str | None @strawberry.field def name(self) -> str: diff --git a/backend/database/migrations/versions/remove_organizations_from_users.py b/backend/database/migrations/versions/remove_organizations_from_users.py new file mode 100644 index 0000000..c063f80 --- /dev/null +++ b/backend/database/migrations/versions/remove_organizations_from_users.py @@ -0,0 +1,29 @@ +"""Remove organizations field from users table. + +Revision ID: remove_organizations_from_users +Revises: add_user_root_locations +Create Date: 2025-01-16 12:00:00.000000 +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "remove_organizations_from_users" +down_revision: Union[str, Sequence[str], None] = "add_user_root_locations" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.drop_column("users", "organizations") + + +def downgrade() -> None: + """Downgrade schema.""" + op.add_column("users", sa.Column("organizations", sa.String(), nullable=True)) + diff --git a/backend/database/models/user.py b/backend/database/models/user.py index c4231f8..67b9cac 100644 --- a/backend/database/models/user.py +++ b/backend/database/models/user.py @@ -37,7 +37,6 @@ class User(Base): nullable=True, default="https://cdn.helpwave.de/boringavatar.svg", ) - organizations: Mapped[str | None] = mapped_column(String, nullable=True) tasks: Mapped[list[Task]] = relationship("Task", back_populates="assignee") root_locations: Mapped[list[LocationNode]] = relationship( diff --git a/backend/scaffold.py b/backend/scaffold.py index 7431f57..f781c67 100644 --- a/backend/scaffold.py +++ b/backend/scaffold.py @@ -6,7 +6,6 @@ from api.inputs import LocationType from config import LOGGER, SCAFFOLD_DIRECTORY from database.models.location import LocationNode, location_organizations -from database.models.user import User from database.session import async_session from sqlalchemy import select @@ -72,7 +71,6 @@ async def load_scaffold_data() -> None: logger.info( f"Successfully loaded scaffold data from {json_file}" ) - await _assign_clinics_to_users(session) except json.JSONDecodeError as e: logger.error(f"Failed to parse JSON file {json_file}: {e}") await session.rollback() @@ -143,55 +141,3 @@ async def _create_location_tree( await _create_location_tree(session, child_data, location_id) return location_id - - -async def _assign_clinics_to_users(session: Any) -> None: - result = await session.execute(select(User)) - users = result.scalars().all() - - for user in users: - if not user.organizations: - continue - - org_ids = [ - org_id.strip() - for org_id in user.organizations.split(",") - if org_id.strip() - ] - - for org_id in org_ids: - clinic_result = await session.execute( - select(LocationNode) - .join( - location_organizations, - LocationNode.id == location_organizations.c.location_id, - ) - .where( - LocationNode.kind == "CLINIC", - location_organizations.c.organization_id == org_id, - ) - .limit(1) - ) - clinic = clinic_result.scalar_one_or_none() - - if clinic: - from database.models.user import user_root_locations - - existing_result = await session.execute( - select(user_root_locations).where( - user_root_locations.c.user_id == user.id, - user_root_locations.c.location_id == clinic.id, - ) - ) - existing = existing_result.first() - if not existing: - await session.execute( - user_root_locations.insert().values( - user_id=user.id, location_id=clinic.id - ) - ) - logger.info( - f"Assigned clinic '{clinic.title}' to user '{user.username}' based on organization '{org_id}'" - ) - - await session.commit() diff --git a/web/api/gql/generated.ts b/web/api/gql/generated.ts index 3bd3aca..4d2b898 100644 --- a/web/api/gql/generated.ts +++ b/web/api/gql/generated.ts @@ -67,6 +67,7 @@ export type LocationNodeType = { children: Array; id: Scalars['ID']['output']; kind: LocationType; + organizationIds: Array; parent?: Maybe; parentId?: Maybe; patients: Array; @@ -425,7 +426,6 @@ export type UserType = { id: Scalars['ID']['output']; lastname?: Maybe; name: Scalars['String']['output']; - organizations?: Maybe; rootLocations: Array; tasks: Array; title?: Maybe; @@ -484,7 +484,7 @@ export type GetUsersQuery = { __typename?: 'Query', users: Array<{ __typename?: export type GetGlobalDataQueryVariables = Exact<{ [key: string]: never; }>; -export type GetGlobalDataQuery = { __typename?: 'Query', me?: { __typename?: 'UserType', id: string, username: string, name: string, firstname?: string | null, lastname?: string | null, avatarUrl?: string | null, organizations?: string | null, rootLocations: Array<{ __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType }>, tasks: Array<{ __typename?: 'TaskType', id: string, done: boolean }> } | null, wards: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, teams: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, clinics: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, patients: Array<{ __typename?: 'PatientType', id: string, state: PatientState, assignedLocation?: { __typename?: 'LocationNodeType', id: string } | null }>, waitingPatients: Array<{ __typename?: 'PatientType', id: string, state: PatientState }> }; +export type GetGlobalDataQuery = { __typename?: 'Query', me?: { __typename?: 'UserType', id: string, username: string, name: string, firstname?: string | null, lastname?: string | null, avatarUrl?: string | null, rootLocations: Array<{ __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType }>, tasks: Array<{ __typename?: 'TaskType', id: string, done: boolean }> } | null, wards: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, teams: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, clinics: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, patients: Array<{ __typename?: 'PatientType', id: string, state: PatientState, assignedLocation?: { __typename?: 'LocationNodeType', id: string } | null }>, waitingPatients: Array<{ __typename?: 'PatientType', id: string, state: PatientState }> }; export type CreatePatientMutationVariables = Exact<{ data: CreatePatientInput; @@ -1219,7 +1219,6 @@ export const GetGlobalDataDocument = ` firstname lastname avatarUrl - organizations rootLocations { id title diff --git a/web/api/graphql/GlobalData.graphql b/web/api/graphql/GlobalData.graphql index 20b7a73..ada82c6 100644 --- a/web/api/graphql/GlobalData.graphql +++ b/web/api/graphql/GlobalData.graphql @@ -6,7 +6,6 @@ query GetGlobalData { firstname lastname avatarUrl - organizations rootLocations { id title diff --git a/web/components/layout/Page.tsx b/web/components/layout/Page.tsx index 04b1eb8..f0f2bf7 100644 --- a/web/components/layout/Page.tsx +++ b/web/components/layout/Page.tsx @@ -206,8 +206,6 @@ export const Header = ({ onMenuClick, isMenuOpen, ...props }: HeaderProps) => { const router = useRouter() const { user, rootLocations } = useTasksContext() - const organizations = user?.organizations ? user.organizations.split(',').map(org => org.trim()).filter(org => org.length > 0) : [] - return (
{
- {organizations.length > 0 && ( -
- - - {organizations.join(', ')} - -
- )} {rootLocations && rootLocations.length > 0 && (
diff --git a/web/hooks/useTasksContext.tsx b/web/hooks/useTasksContext.tsx index 5e0f32d..c2e06eb 100644 --- a/web/hooks/useTasksContext.tsx +++ b/web/hooks/useTasksContext.tsx @@ -8,7 +8,6 @@ type User = { id: string, name: string, avatarUrl?: string | null, - organizations?: string | null, } type LocationNode = { @@ -79,7 +78,6 @@ export const TasksContextProvider = ({ children }: PropsWithChildren) => { id: data.me.id, name: data.me.name, avatarUrl: data.me.avatarUrl, - organizations: data.me.organizations ?? null } : undefined, myTasksCount: data?.me?.tasks?.filter(t => !t.done).length ?? 0, totalPatientsCount, From e85b92f8bf646797e82ebeb89aedee495862d194 Mon Sep 17 00:00:00 2001 From: Felix Evers Date: Tue, 23 Dec 2025 15:20:57 +0100 Subject: [PATCH 05/16] fix migration chain --- .../migrations/versions/remove_organizations_from_users.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/database/migrations/versions/remove_organizations_from_users.py b/backend/database/migrations/versions/remove_organizations_from_users.py index c063f80..1d4f7bf 100644 --- a/backend/database/migrations/versions/remove_organizations_from_users.py +++ b/backend/database/migrations/versions/remove_organizations_from_users.py @@ -1,7 +1,7 @@ """Remove organizations field from users table. Revision ID: remove_organizations_from_users -Revises: add_user_root_locations +Revises: add_location_organizations_table Create Date: 2025-01-16 12:00:00.000000 """ @@ -13,7 +13,7 @@ # revision identifiers, used by Alembic. revision: str = "remove_organizations_from_users" -down_revision: Union[str, Sequence[str], None] = "add_user_root_locations" +down_revision: Union[str, Sequence[str], None] = "add_location_organizations_table" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None From 213514ba159c842f896ecfaf12c4d3795298afa5 Mon Sep 17 00:00:00 2001 From: Felix Evers Date: Tue, 23 Dec 2025 22:06:55 +0100 Subject: [PATCH 06/16] Make root location picker multiselect with LocationSelectionDialog filter by organizations and improve error handling --- backend/api/context.py | 53 ++++++++-- backend/api/inputs.py | 6 +- backend/api/resolvers/location.py | 16 ++- backend/api/resolvers/patient.py | 100 +++++++++++------- backend/api/resolvers/task.py | 45 +++++--- backend/api/types/user.py | 81 ++++++++++++-- .../versions/add_location_type_enum.py | 51 +++++++++ backend/database/models/location.py | 19 +++- backend/scaffold.py | 35 +++--- keycloak/tasks.json | 3 +- web/api/gql/generated.ts | 79 +++++++++++++- web/api/graphql/GetPatients.graphql | 4 +- web/api/graphql/GetTasks.graphql | 42 ++++++++ web/components/layout/Page.tsx | 52 +++++++-- .../locations/LocationSelectionDialog.tsx | 15 +-- web/components/patients/LocationChips.tsx | 12 ++- web/components/patients/PatientDetailView.tsx | 19 +++- web/components/patients/PatientList.tsx | 3 + web/hooks/useTasksContext.tsx | 82 ++++++++++---- web/pages/location/[id].tsx | 16 +-- web/pages/patients/index.tsx | 5 +- web/pages/settings/index.tsx | 24 ++++- web/pages/tasks/index.tsx | 15 ++- 23 files changed, 626 insertions(+), 151 deletions(-) create mode 100644 backend/database/migrations/versions/add_location_type_enum.py create mode 100644 web/api/graphql/GetTasks.graphql diff --git a/backend/api/context.py b/backend/api/context.py index 9a52396..e3b8cdd 100644 --- a/backend/api/context.py +++ b/backend/api/context.py @@ -3,10 +3,11 @@ import strawberry from auth import get_user_payload -from database.models.location import LocationNode +from database.models.location import LocationNode, location_organizations from database.models.user import User, user_root_locations from database.session import get_db_session from fastapi import Depends +from graphql import GraphQLError from sqlalchemy import delete, select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession @@ -101,6 +102,7 @@ async def get_context( lastname=lastname, title="User", avatar_url=picture, + organizations=organizations, ) session.add(new_user) await session.commit() @@ -112,6 +114,12 @@ async def get_context( select(User).where(User.id == user_id), ) db_user = result.scalars().first() + except Exception as e: + await session.rollback() + raise GraphQLError( + "Failed to create user. Please contact an administrator if you believe this is an error.", + extensions={"code": "INTERNAL_SERVER_ERROR"}, + ) from e if db_user and ( db_user.username != username @@ -119,6 +127,7 @@ async def get_context( or db_user.lastname != lastname or db_user.email != email or db_user.avatar_url != picture + or db_user.organizations != organizations ): db_user.username = username db_user.firstname = firstname @@ -126,12 +135,19 @@ async def get_context( db_user.email = email if picture: db_user.avatar_url = picture + db_user.organizations = organizations session.add(db_user) await session.commit() await session.refresh(db_user) if db_user: - await _update_user_root_locations(session, db_user, organizations) + try: + await _update_user_root_locations(session, db_user, organizations) + except Exception as e: + raise GraphQLError( + "Failed to update user root locations. Please contact an administrator if you believe this is an error.", + extensions={"code": "INTERNAL_SERVER_ERROR"}, + ) from e return Context(db=session, user=db_user) @@ -156,15 +172,38 @@ async def _update_user_root_locations( location_organizations, LocationNode.id == location_organizations.c.location_id, ) - .where( - LocationNode.kind == "CLINIC", - location_organizations.c.organization_id.in_(organization_ids), - ) - .distinct() + .where(location_organizations.c.organization_id.in_(organization_ids)) ) found_locations = result.scalars().all() root_location_ids = [loc.id for loc in found_locations] + found_org_ids = set() + for loc in found_locations: + org_result = await session.execute( + select(location_organizations.c.organization_id).where( + location_organizations.c.location_id == loc.id, + location_organizations.c.organization_id.in_(organization_ids) + ) + ) + found_org_ids.update(row[0] for row in org_result.all()) + + for org_id in organization_ids: + if org_id not in found_org_ids: + new_location = LocationNode( + title=f"Organization {org_id[:8]}", + kind="CLINIC", + parent_id=None, + ) + session.add(new_location) + await session.flush() + await session.refresh(new_location) + + await session.execute( + location_organizations.insert().values( + location_id=new_location.id, organization_id=org_id + ) + ) + root_location_ids.append(new_location.id) if not root_location_ids: personal_org_title = f"{user.username}'s Organization" result = await session.execute( diff --git a/backend/api/inputs.py b/backend/api/inputs.py index 03c7827..4936f82 100644 --- a/backend/api/inputs.py +++ b/backend/api/inputs.py @@ -69,12 +69,12 @@ class CreatePatientInput: sex: Sex assigned_location_id: strawberry.ID | None = None assigned_location_ids: list[strawberry.ID] | None = None - clinic_id: strawberry.ID # Required: location node from kind CLINIC + clinic_id: strawberry.ID position_id: strawberry.ID | None = ( - None # Optional: location node from type hospital, practice, clinic, ward, bed or room + None ) team_ids: list[strawberry.ID] | None = ( - None # Array: location nodes from type clinic, team, practice, hospital + None ) properties: list[PropertyValueInput] | None = None state: PatientState | None = None diff --git a/backend/api/resolvers/location.py b/backend/api/resolvers/location.py index b321a86..3cc8ef0 100644 --- a/backend/api/resolvers/location.py +++ b/backend/api/resolvers/location.py @@ -21,10 +21,16 @@ async def location_roots(self, info: Info) -> list[LocationNodeType]: return [] result = await info.context.db.execute( - select(models.LocationNode).where( + select(models.LocationNode) + .join( + models.location_organizations, + models.LocationNode.id == models.location_organizations.c.location_id, + ) + .where( models.LocationNode.parent_id.is_(None), models.LocationNode.id.in_(accessible_location_ids), - ), + ) + .distinct() ) return result.scalars().all() @@ -46,7 +52,7 @@ async def location_node( ) if location.id not in accessible_location_ids: raise GraphQLError( - "Forbidden: You do not have access to this location", + "Insufficient permission. Please contact an administrator if you believe this is an error.", extensions={"code": "FORBIDDEN"}, ) @@ -75,7 +81,7 @@ async def location_nodes( if recursive and parent_id: if parent_id not in accessible_location_ids: raise GraphQLError( - "Forbidden: You do not have access to this location", + "Insufficient permission. Please contact an administrator if you believe this is an error.", extensions={"code": "FORBIDDEN"}, ) @@ -98,7 +104,7 @@ async def location_nodes( if parent_id: if parent_id not in accessible_location_ids: raise GraphQLError( - "Forbidden: You do not have access to this location", + "Insufficient permission. Please contact an administrator if you believe this is an error.", extensions={"code": "FORBIDDEN"}, ) query = query.where(models.LocationNode.parent_id == parent_id) diff --git a/backend/api/resolvers/patient.py b/backend/api/resolvers/patient.py index 70a42d2..8f71fd1 100644 --- a/backend/api/resolvers/patient.py +++ b/backend/api/resolvers/patient.py @@ -38,7 +38,7 @@ async def patient( auth_service = AuthorizationService(info.context.db) if not await auth_service.can_access_patient(info.context.user, patient, info.context): raise GraphQLError( - "Forbidden: You do not have access to this patient", + "Insufficient permission. Please contact an administrator if you believe this is an error.", extensions={"code": "FORBIDDEN"}, ) return patient @@ -48,6 +48,7 @@ async def patients( self, info: Info, location_node_id: strawberry.ID | None = None, + root_location_ids: list[strawberry.ID] | None = None, states: list[PatientState] | None = None, ) -> list[PatientType]: query = select(models.Patient).options( @@ -63,49 +64,72 @@ async def patients( query = query.where( models.Patient.state == PatientState.ADMITTED.value ) - if location_node_id: - cte = ( + auth_service = AuthorizationService(info.context.db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + query = auth_service.filter_patients_by_access( + info.context.user, query, accessible_location_ids + ) + + filter_cte = None + if root_location_ids: + invalid_ids = [lid for lid in root_location_ids if lid not in accessible_location_ids] + if invalid_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + filter_cte = ( + select(models.LocationNode.id) + .where(models.LocationNode.id.in_(root_location_ids)) + .cte(name="root_location_descendants", recursive=True) + ) + root_children = select(models.LocationNode.id).join( + filter_cte, models.LocationNode.parent_id == filter_cte.c.id + ) + filter_cte = filter_cte.union_all(root_children) + elif location_node_id: + filter_cte = ( select(models.LocationNode.id) .where(models.LocationNode.id == location_node_id) .cte(name="location_descendants", recursive=True) ) - parent = select(models.LocationNode.id).join( - cte, - models.LocationNode.parent_id == cte.c.id, + filter_cte, + models.LocationNode.parent_id == filter_cte.c.id, ) - cte = cte.union_all(parent) - - patient_locations = aliased(models.patient_locations) - patient_teams = aliased(models.patient_teams) + filter_cte = filter_cte.union_all(parent) + if filter_cte: + patient_locations_filter = aliased(models.patient_locations) + patient_teams_filter = aliased(models.patient_teams) + query = ( query.outerjoin( - patient_locations, - models.Patient.id == patient_locations.c.patient_id, + patient_locations_filter, + models.Patient.id == patient_locations_filter.c.patient_id, ) .outerjoin( - patient_teams, - models.Patient.id == patient_teams.c.patient_id, + patient_teams_filter, + models.Patient.id == patient_teams_filter.c.patient_id, ) .where( - (models.Patient.assigned_location_id.in_(select(cte.c.id))) - | (patient_locations.c.location_id.in_(select(cte.c.id))) - | (models.Patient.clinic_id.in_(select(cte.c.id))) - | (models.Patient.position_id.in_(select(cte.c.id))) - | (patient_teams.c.location_id.in_(select(cte.c.id))), + (models.Patient.clinic_id.in_(select(filter_cte.c.id))) + | ( + models.Patient.position_id.isnot(None) + & models.Patient.position_id.in_(select(filter_cte.c.id)) + ) + | ( + models.Patient.assigned_location_id.isnot(None) + & models.Patient.assigned_location_id.in_(select(filter_cte.c.id)) + ) + | (patient_locations_filter.c.location_id.in_(select(filter_cte.c.id))) + | (patient_teams_filter.c.location_id.in_(select(filter_cte.c.id))) ) .distinct() ) - auth_service = AuthorizationService(info.context.db) - accessible_location_ids = await auth_service.get_user_accessible_location_ids( - info.context.user, info.context - ) - query = auth_service.filter_patients_by_access( - info.context.user, query, accessible_location_ids - ) - result = await info.context.db.execute(query) return result.scalars().all() @@ -165,13 +189,13 @@ async def create_patient( if not accessible_location_ids: raise GraphQLError( - "Forbidden: You do not have access to create patients", + "Insufficient permission. Please contact an administrator if you believe this is an error.", extensions={"code": "FORBIDDEN"}, ) if data.clinic_id not in accessible_location_ids: raise GraphQLError( - "Forbidden: You do not have access to this clinic", + "Insufficient permission. Please contact an administrator if you believe this is an error.", extensions={"code": "FORBIDDEN"}, ) @@ -180,7 +204,7 @@ async def create_patient( if data.position_id: if data.position_id not in accessible_location_ids: raise GraphQLError( - "Forbidden: You do not have access to this position", + "Insufficient permission. Please contact an administrator if you believe this is an error.", extensions={"code": "FORBIDDEN"}, ) await location_service.validate_and_get_position(data.position_id) @@ -190,7 +214,7 @@ async def create_patient( for team_id in data.team_ids: if team_id not in accessible_location_ids: raise GraphQLError( - "Forbidden: You do not have access to one or more teams", + "Insufficient permission. Please contact an administrator if you believe this is an error.", extensions={"code": "FORBIDDEN"}, ) teams = await location_service.validate_and_get_teams( @@ -215,7 +239,7 @@ async def create_patient( for loc_id in data.assigned_location_ids: if loc_id not in accessible_location_ids: raise GraphQLError( - "Forbidden: You do not have access to one or more assigned locations", + "Insufficient permission. Please contact an administrator if you believe this is an error.", extensions={"code": "FORBIDDEN"}, ) locations = await location_service.get_locations_by_ids( @@ -225,7 +249,7 @@ async def create_patient( elif data.assigned_location_id: if data.assigned_location_id not in accessible_location_ids: raise GraphQLError( - "Forbidden: You do not have access to this assigned location", + "Insufficient permission. Please contact an administrator if you believe this is an error.", extensions={"code": "FORBIDDEN"}, ) location = await location_service.get_location_by_id( @@ -295,7 +319,7 @@ async def update_patient( if data.clinic_id is not None: if data.clinic_id not in accessible_location_ids: raise GraphQLError( - "Forbidden: You do not have access to this clinic", + "Insufficient permission. Please contact an administrator if you believe this is an error.", extensions={"code": "FORBIDDEN"}, ) await location_service.validate_and_get_clinic(data.clinic_id) @@ -307,7 +331,7 @@ async def update_patient( else: if data.position_id not in accessible_location_ids: raise GraphQLError( - "Forbidden: You do not have access to this position", + "Insufficient permission. Please contact an administrator if you believe this is an error.", extensions={"code": "FORBIDDEN"}, ) await location_service.validate_and_get_position( @@ -322,7 +346,7 @@ async def update_patient( for team_id in data.team_ids: if team_id not in accessible_location_ids: raise GraphQLError( - "Forbidden: You do not have access to one or more teams", + "Insufficient permission. Please contact an administrator if you believe this is an error.", extensions={"code": "FORBIDDEN"}, ) patient.teams = await location_service.validate_and_get_teams( @@ -333,7 +357,7 @@ async def update_patient( for loc_id in data.assigned_location_ids: if loc_id not in accessible_location_ids: raise GraphQLError( - "Forbidden: You do not have access to one or more assigned locations", + "Insufficient permission. Please contact an administrator if you believe this is an error.", extensions={"code": "FORBIDDEN"}, ) locations = await location_service.get_locations_by_ids( @@ -343,7 +367,7 @@ async def update_patient( elif data.assigned_location_id is not None: if data.assigned_location_id not in accessible_location_ids: raise GraphQLError( - "Forbidden: You do not have access to this assigned location", + "Insufficient permission. Please contact an administrator if you believe this is an error.", extensions={"code": "FORBIDDEN"}, ) location = await location_service.get_location_by_id( diff --git a/backend/api/resolvers/task.py b/backend/api/resolvers/task.py index 01a0ba5..235533d 100644 --- a/backend/api/resolvers/task.py +++ b/backend/api/resolvers/task.py @@ -30,7 +30,7 @@ async def task(self, info: Info, id: strawberry.ID) -> TaskType | None: auth_service = AuthorizationService(info.context.db) if not await auth_service.can_access_patient(info.context.user, task.patient, info.context): raise GraphQLError( - "Forbidden: You do not have access to this task", + "Insufficient permission. Please contact an administrator if you believe this is an error.", extensions={"code": "FORBIDDEN"}, ) return task @@ -41,13 +41,14 @@ async def tasks( info: Info, patient_id: strawberry.ID | None = None, assignee_id: strawberry.ID | None = None, + root_location_ids: list[strawberry.ID] | None = None, ) -> list[TaskType]: auth_service = AuthorizationService(info.context.db) if patient_id: if not await auth_service.can_access_patient_id(info.context.user, patient_id, info.context): raise GraphQLError( - "Forbidden: You do not have access to this patient's tasks", + "Insufficient permission. Please contact an administrator if you believe this is an error.", extensions={"code": "FORBIDDEN"}, ) @@ -81,7 +82,27 @@ async def tasks( cte, models.LocationNode.parent_id == cte.c.id ) cte = cte.union_all(children) - + + if root_location_ids: + invalid_ids = [lid for lid in root_location_ids if lid not in accessible_location_ids] + if invalid_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + root_cte = ( + select(models.LocationNode.id) + .where(models.LocationNode.id.in_(root_location_ids)) + .cte(name="root_location_descendants", recursive=True) + ) + root_children = select(models.LocationNode.id).join( + root_cte, models.LocationNode.parent_id == root_cte.c.id + ) + root_cte = root_cte.union_all(root_children) + else: + root_cte = cte + +>>>>>>> 7eb41e8 (Make root location picker multiselect with LocationSelectionDialog filter by organizations and improve error handling) query = ( select(models.Task) .options( @@ -97,17 +118,17 @@ async def tasks( models.Patient.id == patient_teams.c.patient_id, ) .where( - (models.Patient.clinic_id.in_(select(cte.c.id))) + (models.Patient.clinic_id.in_(select(root_cte.c.id))) | ( models.Patient.position_id.isnot(None) - & models.Patient.position_id.in_(select(cte.c.id)) + & models.Patient.position_id.in_(select(root_cte.c.id)) ) | ( models.Patient.assigned_location_id.isnot(None) - & models.Patient.assigned_location_id.in_(select(cte.c.id)) + & models.Patient.assigned_location_id.in_(select(root_cte.c.id)) ) - | (patient_locations.c.location_id.in_(select(cte.c.id))) - | (patient_teams.c.location_id.in_(select(cte.c.id))) + | (patient_locations.c.location_id.in_(select(root_cte.c.id))) + | (patient_teams.c.location_id.in_(select(root_cte.c.id))) ) .distinct() ) @@ -194,7 +215,7 @@ async def create_task(self, info: Info, data: CreateTaskInput) -> TaskType: auth_service = AuthorizationService(info.context.db) if not await auth_service.can_access_patient_id(info.context.user, data.patient_id, info.context): raise GraphQLError( - "Forbidden: You do not have access to create tasks for this patient", + "Insufficient permission. Please contact an administrator if you believe this is an error.", extensions={"code": "FORBIDDEN"}, ) @@ -243,7 +264,7 @@ async def update_task( auth_service = AuthorizationService(db) if not await auth_service.can_access_patient(info.context.user, task.patient, info.context): raise GraphQLError( - "Forbidden: You do not have access to this task", + "Insufficient permission. Please contact an administrator if you believe this is an error.", extensions={"code": "FORBIDDEN"}, ) @@ -299,7 +320,7 @@ async def _update_task_field( auth_service = AuthorizationService(db) if not await auth_service.can_access_patient(info.context.user, task.patient, info.context): raise GraphQLError( - "Forbidden: You do not have access to this task", + "Insufficient permission. Please contact an administrator if you believe this is an error.", extensions={"code": "FORBIDDEN"}, ) @@ -367,7 +388,7 @@ async def delete_task(self, info: Info, id: strawberry.ID) -> bool: auth_service = AuthorizationService(db) if not await auth_service.can_access_patient(info.context.user, task.patient, info.context): raise GraphQLError( - "Forbidden: You do not have access to this task", + "Insufficient permission. Please contact an administrator if you believe this is an error.", extensions={"code": "FORBIDDEN"}, ) diff --git a/backend/api/types/user.py b/backend/api/types/user.py index 2284702..d5f1676 100644 --- a/backend/api/types/user.py +++ b/backend/api/types/user.py @@ -3,7 +3,7 @@ import strawberry from database import models from sqlalchemy import select -from sqlalchemy.orm import selectinload +from sqlalchemy.orm import aliased, selectinload if TYPE_CHECKING: from api.types.location import LocationNodeType @@ -31,10 +31,63 @@ async def tasks( self, info, ) -> list[Annotated["TaskType", strawberry.lazy("api.types.task")]]: - - result = await info.context.db.execute( - select(models.Task).where(models.Task.assignee_id == self.id), + from api.services.authorization import AuthorizationService + + auth_service = AuthorizationService(info.context.db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + + if not accessible_location_ids: + return [] + + from sqlalchemy.orm import aliased + patient_locations = aliased(models.patient_locations) + patient_teams = aliased(models.patient_teams) + + from sqlalchemy import select + cte = ( + select(models.LocationNode.id) + .where(models.LocationNode.id.in_(accessible_location_ids)) + .cte(name="accessible_locations", recursive=True) ) + + children = select(models.LocationNode.id).join( + cte, models.LocationNode.parent_id == cte.c.id + ) + cte = cte.union_all(children) + + query = ( + select(models.Task) + .join(models.Patient, models.Task.patient_id == models.Patient.id) + .outerjoin( + patient_locations, + models.Patient.id == patient_locations.c.patient_id, + ) + .outerjoin( + patient_teams, + models.Patient.id == patient_teams.c.patient_id, + ) + .where( + models.Task.assignee_id == self.id, + ( + (models.Patient.clinic_id.in_(select(cte.c.id))) + | ( + models.Patient.position_id.isnot(None) + & models.Patient.position_id.in_(select(cte.c.id)) + ) + | ( + models.Patient.assigned_location_id.isnot(None) + & models.Patient.assigned_location_id.in_(select(cte.c.id)) + ) + | (patient_locations.c.location_id.in_(select(cte.c.id))) + | (patient_teams.c.location_id.in_(select(cte.c.id))) + ) + ) + .distinct() + ) + + result = await info.context.db.execute(query) return result.scalars().all() @strawberry.field @@ -43,11 +96,17 @@ async def root_locations( info, ) -> list[Annotated["LocationNodeType", strawberry.lazy("api.types.location")]]: result = await info.context.db.execute( - select(models.User) - .where(models.User.id == self.id) - .options(selectinload(models.User.root_locations)) + select(models.LocationNode) + .join( + models.user_root_locations, + models.LocationNode.id == models.user_root_locations.c.location_id, + ) + .join( + models.location_organizations, + models.LocationNode.id == models.location_organizations.c.location_id, + ) + .where(models.user_root_locations.c.user_id == self.id) + .where(models.LocationNode.parent_id.is_(None)) + .distinct() ) - user = result.scalars().first() - if not user: - return [] - return user.root_locations or [] + return result.scalars().all() diff --git a/backend/database/migrations/versions/add_location_type_enum.py b/backend/database/migrations/versions/add_location_type_enum.py new file mode 100644 index 0000000..cda3918 --- /dev/null +++ b/backend/database/migrations/versions/add_location_type_enum.py @@ -0,0 +1,51 @@ +"""Add location type enum. + +Revision ID: add_location_type_enum +Revises: add_location_organizations_table +Create Date: 2025-01-17 00:00:00.000000 +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "add_location_type_enum" +down_revision: Union[str, Sequence[str], None] = "add_location_organizations_table" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.alter_column( + 'location_nodes', + 'kind', + type_=sa.Enum( + 'HOSPITAL', 'PRACTICE', 'CLINIC', 'TEAM', 'WARD', 'ROOM', 'BED', 'OTHER', + name='locationtypeenum', + native_enum=False, + length=50 + ), + existing_type=sa.String(), + existing_nullable=False + ) + + +def downgrade() -> None: + """Downgrade schema.""" + op.alter_column( + 'location_nodes', + 'kind', + type_=sa.String(), + existing_type=sa.Enum( + 'HOSPITAL', 'PRACTICE', 'CLINIC', 'TEAM', 'WARD', 'ROOM', 'BED', 'OTHER', + name='locationtypeenum', + native_enum=False, + length=50 + ), + existing_nullable=False + ) + diff --git a/backend/database/models/location.py b/backend/database/models/location.py index 7efaa51..1042953 100644 --- a/backend/database/models/location.py +++ b/backend/database/models/location.py @@ -1,16 +1,29 @@ from __future__ import annotations import uuid +from enum import Enum from typing import TYPE_CHECKING from database.models.base import Base -from sqlalchemy import Column, ForeignKey, String, Table +from sqlalchemy import Column, ForeignKey, String, Table, Enum as SQLEnum from sqlalchemy.orm import Mapped, mapped_column, relationship if TYPE_CHECKING: from .patient import Patient from .user import User + +class LocationTypeEnum(str, Enum): + HOSPITAL = "HOSPITAL" + PRACTICE = "PRACTICE" + CLINIC = "CLINIC" + TEAM = "TEAM" + WARD = "WARD" + ROOM = "ROOM" + BED = "BED" + OTHER = "OTHER" + + location_organizations = Table( "location_organizations", Base.metadata, @@ -28,7 +41,9 @@ class LocationNode(Base): default=lambda: str(uuid.uuid4()), ) title: Mapped[str] = mapped_column(String) - kind: Mapped[str] = mapped_column(String) + kind: Mapped[LocationTypeEnum] = mapped_column( + SQLEnum(LocationTypeEnum, native_enum=False, length=50) + ) parent_id: Mapped[str | None] = mapped_column( ForeignKey("location_nodes.id"), nullable=True, diff --git a/backend/scaffold.py b/backend/scaffold.py index f781c67..0016bf8 100644 --- a/backend/scaffold.py +++ b/backend/scaffold.py @@ -120,22 +120,29 @@ async def _create_location_tree( organization_ids = data.get("organization_ids", []) if organization_ids: - for org_id in organization_ids: - stmt = select(location_organizations).where( - location_organizations.c.location_id == location_id, - location_organizations.c.organization_id == org_id, + allowed_types_for_orgs = {"HOSPITAL", "CLINIC", "PRACTICE", "TEAM"} + if location_type.value not in allowed_types_for_orgs: + logger.warning( + f"Organization IDs can only be assigned to HOSPITAL, CLINIC, PRACTICE, or TEAM. " + f"Skipping organization assignment for location '{name}' (type: {location_type.value})" ) - result = await session.execute(stmt) - existing_org = result.first() - if not existing_org: - await session.execute( - location_organizations.insert().values( - location_id=location_id, organization_id=org_id - ) - ) - logger.debug( - f"Assigned organization '{org_id}' to location '{name}'" + else: + for org_id in organization_ids: + stmt = select(location_organizations).where( + location_organizations.c.location_id == location_id, + location_organizations.c.organization_id == org_id, ) + result = await session.execute(stmt) + existing_org = result.first() + if not existing_org: + await session.execute( + location_organizations.insert().values( + location_id=location_id, organization_id=org_id + ) + ) + logger.debug( + f"Assigned organization '{org_id}' to location '{name}'" + ) for child_data in children: await _create_location_tree(session, child_data, location_id) diff --git a/keycloak/tasks.json b/keycloak/tasks.json index c4cdb6b..a65251d 100644 --- a/keycloak/tasks.json +++ b/keycloak/tasks.json @@ -1312,12 +1312,13 @@ "id" : "32ae8303-353f-43ad-bc9d-130923bdea7d", "name" : "organization", "protocol" : "openid-connect", - "protocolMapper" : "oidc-organization-membership-mapper", + "protocolMapper" : "oidc-usermodel-attribute-mapper", "consentRequired" : false, "config" : { "introspection.token.claim" : "true", "multivalued" : "true", "userinfo.token.claim" : "true", + "user.attribute" : "organization", "id.token.claim" : "true", "access.token.claim" : "true", "claim.name" : "organization", diff --git a/web/api/gql/generated.ts b/web/api/gql/generated.ts index 4d2b898..2e606e7 100644 --- a/web/api/gql/generated.ts +++ b/web/api/gql/generated.ts @@ -308,6 +308,7 @@ export type QueryPatientArgs = { export type QueryPatientsArgs = { locationNodeId?: InputMaybe; + rootLocationIds?: InputMaybe>; states?: InputMaybe>; }; @@ -330,6 +331,7 @@ export type QueryTaskArgs = { export type QueryTasksArgs = { assigneeId?: InputMaybe; patientId?: InputMaybe; + rootLocationIds?: InputMaybe>; }; @@ -463,6 +465,7 @@ export type GetPatientQuery = { __typename?: 'Query', patient?: { __typename?: ' export type GetPatientsQueryVariables = Exact<{ locationId?: InputMaybe; + rootLocationIds?: InputMaybe | Scalars['ID']['input']>; states?: InputMaybe | PatientState>; }>; @@ -476,6 +479,14 @@ export type GetTaskQueryVariables = Exact<{ export type GetTaskQuery = { __typename?: 'Query', task?: { __typename?: 'TaskType', id: string, title: string, description?: string | null, done: boolean, dueDate?: any | null, checksum: string, patient: { __typename?: 'PatientType', id: string, name: string }, assignee?: { __typename?: 'UserType', id: string, name: string } | null, properties: Array<{ __typename?: 'PropertyValueType', textValue?: string | null, numberValue?: number | null, booleanValue?: boolean | null, dateValue?: any | null, dateTimeValue?: any | null, selectValue?: string | null, multiSelectValues?: Array | null, definition: { __typename?: 'PropertyDefinitionType', id: string, name: string, description?: string | null, fieldType: FieldType, isActive: boolean, allowedEntities: Array, options: Array } }> } | null }; +export type GetTasksQueryVariables = Exact<{ + rootLocationIds?: InputMaybe | Scalars['ID']['input']>; + assigneeId?: InputMaybe; +}>; + + +export type GetTasksQuery = { __typename?: 'Query', tasks: Array<{ __typename?: 'TaskType', id: string, title: string, description?: string | null, done: boolean, dueDate?: any | null, creationDate: any, updateDate?: any | null, patient: { __typename?: 'PatientType', id: string, name: string, assignedLocation?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null, assignedLocations: Array<{ __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType, parent?: { __typename?: 'LocationNodeType', id: string, title: string, parent?: { __typename?: 'LocationNodeType', id: string, title: string } | null } | null }> }, assignee?: { __typename?: 'UserType', id: string, name: string, avatarUrl?: string | null } | null }> }; + export type GetUsersQueryVariables = Exact<{ [key: string]: never; }>; @@ -992,8 +1003,12 @@ export const useGetPatientQuery = < )}; export const GetPatientsDocument = ` - query GetPatients($locationId: ID, $states: [PatientState!]) { - patients(locationNodeId: $locationId, states: $states) { + query GetPatients($locationId: ID, $rootLocationIds: [ID!], $states: [PatientState!]) { + patients( + locationNodeId: $locationId + rootLocationIds: $rootLocationIds + states: $states + ) { id name firstname @@ -1184,6 +1199,66 @@ export const useGetTaskQuery = < } )}; +export const GetTasksDocument = ` + query GetTasks($rootLocationIds: [ID!], $assigneeId: ID) { + tasks(rootLocationIds: $rootLocationIds, assigneeId: $assigneeId) { + id + title + description + done + dueDate + creationDate + updateDate + patient { + id + name + assignedLocation { + id + title + parent { + id + title + } + } + assignedLocations { + id + title + kind + parent { + id + title + parent { + id + title + } + } + } + } + assignee { + id + name + avatarUrl + } + } +} + `; + +export const useGetTasksQuery = < + TData = GetTasksQuery, + TError = unknown + >( + variables?: GetTasksQueryVariables, + options?: Omit, 'queryKey'> & { queryKey?: UseQueryOptions['queryKey'] } + ) => { + + return useQuery( + { + queryKey: variables === undefined ? ['GetTasks'] : ['GetTasks', variables], + queryFn: fetcher(GetTasksDocument, variables), + ...options + } + )}; + export const GetUsersDocument = ` query GetUsers { users { diff --git a/web/api/graphql/GetPatients.graphql b/web/api/graphql/GetPatients.graphql index e2be458..17bc1a0 100644 --- a/web/api/graphql/GetPatients.graphql +++ b/web/api/graphql/GetPatients.graphql @@ -1,5 +1,5 @@ -query GetPatients($locationId: ID, $states: [PatientState!]) { - patients(locationNodeId: $locationId, states: $states) { +query GetPatients($locationId: ID, $rootLocationIds: [ID!], $states: [PatientState!]) { + patients(locationNodeId: $locationId, rootLocationIds: $rootLocationIds, states: $states) { id name firstname diff --git a/web/api/graphql/GetTasks.graphql b/web/api/graphql/GetTasks.graphql new file mode 100644 index 0000000..ec6f7b8 --- /dev/null +++ b/web/api/graphql/GetTasks.graphql @@ -0,0 +1,42 @@ +query GetTasks($rootLocationIds: [ID!], $assigneeId: ID) { + tasks(rootLocationIds: $rootLocationIds, assigneeId: $assigneeId) { + id + title + description + done + dueDate + creationDate + updateDate + patient { + id + name + assignedLocation { + id + title + parent { + id + title + } + } + assignedLocations { + id + title + kind + parent { + id + title + parent { + id + title + } + } + } + } + assignee { + id + name + avatarUrl + } + } +} + diff --git a/web/components/layout/Page.tsx b/web/components/layout/Page.tsx index f0f2bf7..e5a45d7 100644 --- a/web/components/layout/Page.tsx +++ b/web/components/layout/Page.tsx @@ -12,6 +12,8 @@ import { Expandable, LoadingContainer, MarkdownInterpreter, + Select, + SelectOption, useLocalStorage } from '@helpwave/hightide' import { getConfig } from '@/utils/config' @@ -35,6 +37,8 @@ import { useRouter } from 'next/router' import { useTasksContext } from '@/hooks/useTasksContext' import { hashString } from '@/utils/hash' import { useSwipeGesture } from '@/hooks/useSwipeGesture' +import { LocationChips } from '@/components/patients/LocationChips' +import { LocationSelectionDialog } from '@/components/locations/LocationSelectionDialog' export const StagingDisclaimerDialog = () => { const config = getConfig() @@ -204,8 +208,21 @@ type HeaderProps = HTMLAttributes & { export const Header = ({ onMenuClick, isMenuOpen, ...props }: HeaderProps) => { const router = useRouter() - const { user, rootLocations } = useTasksContext() - + const { user, rootLocations, selectedRootLocationIds, update } = useTasksContext() + const translation = useTasksTranslation() + const [isLocationPickerOpen, setIsLocationPickerOpen] = useState(false) + + const selectedRootLocations = rootLocations?.filter(loc => selectedRootLocationIds?.includes(loc.id)) || [] + const firstSelectedRootLocation = selectedRootLocations[0] + + const handleRootLocationSelect = (locations: Array<{ id: string; title: string; kind?: string }>) => { + if (locations.length === 0) return + update(prevState => ({ + ...prevState, + selectedRootLocationIds: locations.map(loc => loc.id), + })) + setIsLocationPickerOpen(false) + } return (
{ {isMenuOpen ? : }
-
+
{rootLocations && rootLocations.length > 0 && (
- - - {rootLocations.map(loc => loc.title).join(', ')} - + + setIsLocationPickerOpen(false)} + onSelect={handleRootLocationSelect} + initialSelectedIds={selectedRootLocationIds || []} + multiSelect={true} + useCase="default" + /> +
+ )} + {selectedRootLocations.length > 0 && ( +
+
)}
diff --git a/web/components/locations/LocationSelectionDialog.tsx b/web/components/locations/LocationSelectionDialog.tsx index 6293a36..734e754 100644 --- a/web/components/locations/LocationSelectionDialog.tsx +++ b/web/components/locations/LocationSelectionDialog.tsx @@ -48,11 +48,13 @@ interface LocationTreeItemProps { const getKindStyles = (kind: string) => { const k = kind.toUpperCase() - if (k.includes('CLINIC')) return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-300' - if (k.includes('WARD')) return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-300' - if (k.includes('TEAM')) return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-300' - if (k.includes('ROOM')) return 'bg-gray-100 text-gray-700 dark:bg-gray-800 dark:text-gray-300' - if (k.includes('BED')) return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-300' + if (k === 'HOSPITAL') return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-300' + if (k === 'PRACTICE') return 'bg-indigo-100 text-indigo-700 dark:bg-indigo-900/30 dark:text-indigo-300' + if (k === 'CLINIC') return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-300' + if (k === 'TEAM') return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-300' + if (k === 'WARD') return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-300' + if (k === 'ROOM') return 'bg-gray-100 text-gray-700 dark:bg-gray-800 dark:text-gray-300' + if (k === 'BED') return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-300' return 'bg-surface-subdued text-text-tertiary' } @@ -347,6 +349,7 @@ export const LocationSelectionDialog = ({ const handleConfirm = () => { if (!data?.locationNodes) return + if (selectedIds.size === 0) return const nodes = data.locationNodes as LocationNodeType[] const selectedNodes = nodes.filter(n => selectedIds.has(n.id)) onSelect(selectedNodes) @@ -451,7 +454,7 @@ export const LocationSelectionDialog = ({ diff --git a/web/components/patients/LocationChips.tsx b/web/components/patients/LocationChips.tsx index 59a4eb8..513eacb 100644 --- a/web/components/patients/LocationChips.tsx +++ b/web/components/patients/LocationChips.tsx @@ -21,11 +21,13 @@ interface LocationChipsProps { const getKindStyles = (kind: string | undefined) => { if (!kind) return 'bg-surface-subdued text-text-tertiary' const k = kind.toUpperCase() - if (k.includes('CLINIC')) return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-300' - if (k.includes('WARD')) return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-300' - if (k.includes('TEAM')) return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-300' - if (k.includes('ROOM')) return 'bg-gray-100 text-gray-700 dark:bg-gray-800 dark:text-gray-300' - if (k.includes('BED')) return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-300' + if (k === 'HOSPITAL') return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-300' + if (k === 'PRACTICE') return 'bg-indigo-100 text-indigo-700 dark:bg-indigo-900/30 dark:text-indigo-300' + if (k === 'CLINIC') return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-300' + if (k === 'TEAM') return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-300' + if (k === 'WARD') return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-300' + if (k === 'ROOM') return 'bg-gray-100 text-gray-700 dark:bg-gray-800 dark:text-gray-300' + if (k === 'BED') return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-300' return 'bg-surface-subdued text-text-tertiary' } diff --git a/web/components/patients/PatientDetailView.tsx b/web/components/patients/PatientDetailView.tsx index 0bad2ad..7ab5509 100644 --- a/web/components/patients/PatientDetailView.tsx +++ b/web/components/patients/PatientDetailView.tsx @@ -171,7 +171,8 @@ export const PatientDetailView = ({ initialCreateData = {} }: PatientDetailViewProps) => { const translation = useTasksTranslation() - const { selectedLocationId } = useTasksContext() + const { selectedLocationId, selectedRootLocationIds, rootLocations } = useTasksContext() + const firstSelectedRootLocationId = selectedRootLocationIds && selectedRootLocationIds.length > 0 ? selectedRootLocationIds[0] : undefined const queryClient = useQueryClient() const [taskId, setTaskId] = useState(null) const [isCreatingTask, setIsCreatingTask] = useState(false) @@ -256,6 +257,22 @@ export const PatientDetailView = ({ } }, [patientData]) + useEffect(() => { + if (!isEditMode && firstSelectedRootLocationId && locationsData?.locationNodes && !formData.clinicId) { + const selectedRootLocation = locationsData.locationNodes.find( + loc => loc.id === firstSelectedRootLocationId && loc.kind === 'CLINIC' + ) + if (selectedRootLocation) { + const clinicLocation = selectedRootLocation as LocationNodeType + setSelectedClinic(clinicLocation) + setFormData(prev => ({ + ...prev, + clinicId: clinicLocation.id, + })) + } + } + }, [isEditMode, firstSelectedRootLocationId, locationsData, formData.clinicId]) + const { mutate: createPatient, isLoading: isCreating } = useCreatePatientMutation({ onSuccess: () => { queryClient.invalidateQueries({ queryKey: ['GetGlobalData'] }) diff --git a/web/components/patients/PatientList.tsx b/web/components/patients/PatientList.tsx index 067c3f0..02e6984 100644 --- a/web/components/patients/PatientList.tsx +++ b/web/components/patients/PatientList.tsx @@ -8,6 +8,7 @@ import { SmartDate } from '@/utils/date' import { LocationChips } from '@/components/patients/LocationChips' import { PatientStateChip } from '@/components/patients/PatientStateChip' import { useTasksTranslation } from '@/i18n/useTasksTranslation' +import { useTasksContext } from '@/hooks/useTasksContext' import type { ColumnDef } from '@tanstack/table-core' type PatientViewModel = { @@ -38,6 +39,7 @@ type PatientListProps = { export const PatientList = forwardRef(({ locationId, initialPatientId, onInitialPatientOpened, acceptedStates }, ref) => { const translation = useTasksTranslation() + const { selectedRootLocationIds } = useTasksContext() const [isPanelOpen, setIsPanelOpen] = useState(false) const [selectedPatient, setSelectedPatient] = useState(undefined) const [searchQuery, setSearchQuery] = useState('') @@ -46,6 +48,7 @@ export const PatientList = forwardRef(({ locat const { data: queryData, refetch } = useGetPatientsQuery( { locationId: locationId, + rootLocationIds: selectedRootLocationIds && selectedRootLocationIds.length > 0 ? selectedRootLocationIds : undefined, states: acceptedStates }, { diff --git a/web/hooks/useTasksContext.tsx b/web/hooks/useTasksContext.tsx index c2e06eb..996394f 100644 --- a/web/hooks/useTasksContext.tsx +++ b/web/hooks/useTasksContext.tsx @@ -3,16 +3,19 @@ import { createContext, type PropsWithChildren, useContext, useEffect, useState import { usePathname } from 'next/navigation' import { useGetGlobalDataQuery } from '@/api/gql/generated' import { useAuth } from './useAuth' +import { useLocalStorage } from '@helpwave/hightide' type User = { id: string, name: string, avatarUrl?: string | null, + organizations?: string | null, } type LocationNode = { id: string, title: string, + kind?: string, } type SidebarContextType = { @@ -30,6 +33,7 @@ export type TasksContextState = { wards?: LocationNode[], clinics?: LocationNode[], selectedLocationId?: string, + selectedRootLocationIds?: string[], sidebar: SidebarContextType, user?: User, rootLocations?: LocationNode[], @@ -54,12 +58,17 @@ export const useTasksContext = (): TasksContextType => { export const TasksContextProvider = ({ children }: PropsWithChildren) => { const pathName = usePathname() const { identity, isLoading: isAuthLoading } = useAuth() + const { + value: storedSelectedRootLocationIds, + setValue: setStoredSelectedRootLocationIds + } = useLocalStorage('selected-root-location-ids', []) const [state, setState] = useState({ sidebar: { isShowingTeams: false, isShowingWards: false, isShowingClinics: false, - } + }, + selectedRootLocationIds: storedSelectedRootLocationIds.length > 0 ? storedSelectedRootLocationIds : undefined, }) const { data } = useGetGlobalDataQuery(undefined, { @@ -72,31 +81,62 @@ export const TasksContextProvider = ({ children }: PropsWithChildren) => { useEffect(() => { const totalPatientsCount = data?.patients?.length ?? 0 const waitingPatientsCount = data?.waitingPatients?.length ?? 0 - setState(prevState => ({ - ...prevState, - user: data?.me ? { - id: data.me.id, - name: data.me.name, - avatarUrl: data.me.avatarUrl, - } : undefined, - myTasksCount: data?.me?.tasks?.filter(t => !t.done).length ?? 0, - totalPatientsCount, - waitingPatientsCount, - locationPatientsCount: prevState.selectedLocationId - ? data?.patients?.filter(p => p.assignedLocation?.id === prevState.selectedLocationId).length ?? 0 - : totalPatientsCount, - teams: data?.teams, - wards: data?.wards, - clinics: data?.clinics, - rootLocations: data?.me?.rootLocations?.map(loc => ({ id: loc.id, title: loc.title })) ?? [], - })) - }, [data]) + const rootLocations = data?.me?.rootLocations?.map(loc => ({ id: loc.id, title: loc.title, kind: loc.kind })) ?? [] + + setState(prevState => { + let selectedRootLocationIds = prevState.selectedRootLocationIds || [] + + if (rootLocations.length > 0) { + const validIds = selectedRootLocationIds.filter(id => rootLocations.find(loc => loc.id === id)) + if (validIds.length === 0) { + selectedRootLocationIds = [rootLocations[0].id] + setStoredSelectedRootLocationIds(selectedRootLocationIds) + } else { + selectedRootLocationIds = validIds + if (selectedRootLocationIds.length !== prevState.selectedRootLocationIds?.length) { + setStoredSelectedRootLocationIds(selectedRootLocationIds) + } + } + } + + return { + ...prevState, + user: data?.me ? { + id: data.me.id, + name: data.me.name, + avatarUrl: data.me.avatarUrl, + organizations: data.me.organizations ?? null + } : undefined, + myTasksCount: data?.me?.tasks?.filter(t => !t.done).length ?? 0, + totalPatientsCount, + waitingPatientsCount, + locationPatientsCount: prevState.selectedLocationId + ? data?.patients?.filter(p => p.assignedLocation?.id === prevState.selectedLocationId).length ?? 0 + : totalPatientsCount, + teams: data?.teams, + wards: data?.wards, + clinics: data?.clinics, + rootLocations, + selectedRootLocationIds, + } + }) + }, [data, setStoredSelectedRootLocationIds]) + + const updateState: Dispatch> = (updater) => { + setState(prevState => { + const newState = typeof updater === 'function' ? updater(prevState) : updater + if (newState.selectedRootLocationIds !== prevState.selectedRootLocationIds) { + setStoredSelectedRootLocationIds(newState.selectedRootLocationIds || []) + } + return newState + }) + } return ( diff --git a/web/pages/location/[id].tsx b/web/pages/location/[id].tsx index 3ab6c05..f0fa46b 100644 --- a/web/pages/location/[id].tsx +++ b/web/pages/location/[id].tsx @@ -9,22 +9,26 @@ import { TaskList, type TaskViewModel } from '@/components/tasks/TaskList' import { useGetLocationNodeQuery, useGetPatientsQuery, type LocationType } from '@/api/gql/generated' import { useMemo } from 'react' import { useRouter } from 'next/router' +import { useTasksContext } from '@/hooks/useTasksContext' import { LocationChips } from '@/components/patients/LocationChips' import { LOCATION_PATH_SEPARATOR } from '@/utils/location' const getKindStyles = (kind: string) => { const k = kind.toUpperCase() - if (k.includes('CLINIC')) return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-300' - if (k.includes('WARD')) return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-300' - if (k.includes('TEAM')) return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-300' - if (k.includes('ROOM')) return 'bg-gray-100 text-gray-700 dark:bg-gray-800 dark:text-gray-300' - if (k.includes('BED')) return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-300' + if (k === 'HOSPITAL') return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-300' + if (k === 'PRACTICE') return 'bg-indigo-100 text-indigo-700 dark:bg-indigo-900/30 dark:text-indigo-300' + if (k === 'CLINIC') return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-300' + if (k === 'TEAM') return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-300' + if (k === 'WARD') return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-300' + if (k === 'ROOM') return 'bg-gray-100 text-gray-700 dark:bg-gray-800 dark:text-gray-300' + if (k === 'BED') return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-300' return 'bg-surface-subdued text-text-tertiary' } const LocationPage: NextPage = () => { const translation = useTasksTranslation() const router = useRouter() + const { selectedRootLocationIds } = useTasksContext() const id = Array.isArray(router.query['id']) ? router.query['id'][0] : router.query['id'] const { data: locationData, isLoading: isLoadingLocation, isError: isLocationError } = useGetLocationNodeQuery( @@ -37,7 +41,7 @@ const LocationPage: NextPage = () => { ) const { data: patientsData, refetch: refetchPatients, isLoading: isLoadingPatients } = useGetPatientsQuery( - { locationId: id }, + { locationId: id, rootLocationIds: selectedRootLocationIds && selectedRootLocationIds.length > 0 ? selectedRootLocationIds : undefined }, { enabled: !!id, refetchInterval: 5000, diff --git a/web/pages/patients/index.tsx b/web/pages/patients/index.tsx index de11bd5..83a6065 100644 --- a/web/pages/patients/index.tsx +++ b/web/pages/patients/index.tsx @@ -10,14 +10,15 @@ import { useRouter } from 'next/router' const PatientsPage: NextPage = () => { const translation = useTasksTranslation() const router = useRouter() - const { selectedLocationId } = useTasksContext() + const { selectedRootLocationIds } = useTasksContext() const patientId = router.query['patientId'] as string | undefined + const firstSelectedRootLocationId = selectedRootLocationIds && selectedRootLocationIds.length > 0 ? selectedRootLocationIds[0] : undefined return ( router.replace('/patients', undefined, { shallow: true })} /> diff --git a/web/pages/settings/index.tsx b/web/pages/settings/index.tsx index de322ca..68143bc 100644 --- a/web/pages/settings/index.tsx +++ b/web/pages/settings/index.tsx @@ -16,7 +16,7 @@ import { import type { HightideTranslationLocales, ThemeType } from '@helpwave/hightide' import { useTasksContext } from '@/hooks/useTasksContext' import { useAuth } from '@/hooks/useAuth' -import { LogOut, MonitorCog, MoonIcon, SunIcon, Trash2, ClipboardList, Shield, TableProperties } from 'lucide-react' +import { LogOut, MonitorCog, MoonIcon, SunIcon, Trash2, ClipboardList, Shield, TableProperties, Building2 } from 'lucide-react' import { useRouter } from 'next/router' import clsx from 'clsx' import { removeUser } from '@/api/auth/authService' @@ -112,6 +112,28 @@ const SettingsPage: NextPage = () => {
+ {/* Organizations Section */} + {user?.organizations && ( +
+

{translation('organizations') || 'Organizations'}

+
+ +
+ {user.organizations.split(',').map((org, index) => { + const trimmedOrg = org.trim() + if (!trimmedOrg) return null + return ( + + {trimmedOrg} + {index < user.organizations!.split(',').length - 1 && ,} + + ) + })} +
+
+
+ )} + {/* System / Management */}

{translation('system')}

diff --git a/web/pages/tasks/index.tsx b/web/pages/tasks/index.tsx index f24e335..511eeae 100644 --- a/web/pages/tasks/index.tsx +++ b/web/pages/tasks/index.tsx @@ -5,14 +5,19 @@ import { useTasksTranslation } from '@/i18n/useTasksTranslation' import { ContentPanel } from '@/components/layout/ContentPanel' import { TaskList, type TaskViewModel } from '@/components/tasks/TaskList' import { useMemo } from 'react' -import { useGetMyTasksQuery } from '@/api/gql/generated' +import { useGetTasksQuery } from '@/api/gql/generated' import { useRouter } from 'next/router' +import { useTasksContext } from '@/hooks/useTasksContext' const TasksPage: NextPage = () => { const translation = useTasksTranslation() const router = useRouter() - const { data: queryData, refetch } = useGetMyTasksQuery( - undefined, + const { selectedRootLocationId, user } = useTasksContext() + const { data: queryData, refetch } = useGetTasksQuery( + { + rootLocationId: selectedRootLocationId, + assigneeId: user?.id, + }, { refetchInterval: 5000, refetchOnWindowFocus: true, @@ -22,9 +27,9 @@ const TasksPage: NextPage = () => { const taskId = router.query['taskId'] as string | undefined const tasks: TaskViewModel[] = useMemo(() => { - if (!queryData?.me?.tasks) return [] + if (!queryData?.tasks) return [] - return queryData.me.tasks.map((task) => ({ + return queryData.tasks.map((task) => ({ id: task.id, name: task.title, description: task.description || undefined, From da8bb45a8c7412b5be63ddca2deb76ba77bcae5d Mon Sep 17 00:00:00 2001 From: Felix Evers Date: Tue, 23 Dec 2025 22:12:13 +0100 Subject: [PATCH 07/16] Fix syntax error remove conflict marker --- backend/api/resolvers/task.py | 1 - 1 file changed, 1 deletion(-) diff --git a/backend/api/resolvers/task.py b/backend/api/resolvers/task.py index 235533d..800fd1e 100644 --- a/backend/api/resolvers/task.py +++ b/backend/api/resolvers/task.py @@ -102,7 +102,6 @@ async def tasks( else: root_cte = cte ->>>>>>> 7eb41e8 (Make root location picker multiselect with LocationSelectionDialog filter by organizations and improve error handling) query = ( select(models.Task) .options( From eafe203aa8c1bb56ed1a51c6f105988b0d68753e Mon Sep 17 00:00:00 2001 From: Felix Evers Date: Tue, 23 Dec 2025 22:12:58 +0100 Subject: [PATCH 08/16] Add merge migration for location type enum and remove organizations --- ...ba_merge_location_type_enum_and_remove_.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 backend/database/migrations/versions/0de3078888ba_merge_location_type_enum_and_remove_.py diff --git a/backend/database/migrations/versions/0de3078888ba_merge_location_type_enum_and_remove_.py b/backend/database/migrations/versions/0de3078888ba_merge_location_type_enum_and_remove_.py new file mode 100644 index 0000000..1b31f78 --- /dev/null +++ b/backend/database/migrations/versions/0de3078888ba_merge_location_type_enum_and_remove_.py @@ -0,0 +1,28 @@ +"""Merge location type enum and remove organizations + +Revision ID: 0de3078888ba +Revises: add_location_type_enum, remove_organizations_from_users +Create Date: 2025-12-23 22:12:52.604315 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '0de3078888ba' +down_revision: Union[str, Sequence[str], None] = ('add_location_type_enum', 'remove_organizations_from_users') +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + pass + + +def downgrade() -> None: + """Downgrade schema.""" + pass From d0714ba8bfc0649133ac305cee5d179aea0dc693 Mon Sep 17 00:00:00 2001 From: Felix Evers Date: Tue, 23 Dec 2025 22:17:20 +0100 Subject: [PATCH 09/16] Fix location roots query to show all user root locations and remove organizations field from User model --- backend/api/context.py | 3 --- backend/api/resolvers/location.py | 8 ++------ 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/backend/api/context.py b/backend/api/context.py index e3b8cdd..8461c7a 100644 --- a/backend/api/context.py +++ b/backend/api/context.py @@ -102,7 +102,6 @@ async def get_context( lastname=lastname, title="User", avatar_url=picture, - organizations=organizations, ) session.add(new_user) await session.commit() @@ -127,7 +126,6 @@ async def get_context( or db_user.lastname != lastname or db_user.email != email or db_user.avatar_url != picture - or db_user.organizations != organizations ): db_user.username = username db_user.firstname = firstname @@ -135,7 +133,6 @@ async def get_context( db_user.email = email if picture: db_user.avatar_url = picture - db_user.organizations = organizations session.add(db_user) await session.commit() await session.refresh(db_user) diff --git a/backend/api/resolvers/location.py b/backend/api/resolvers/location.py index 3cc8ef0..4e90dde 100644 --- a/backend/api/resolvers/location.py +++ b/backend/api/resolvers/location.py @@ -16,16 +16,12 @@ async def location_roots(self, info: Info) -> list[LocationNodeType]: accessible_location_ids = await auth_service.get_user_accessible_location_ids( info.context.user, info.context ) - + if not accessible_location_ids: return [] - + result = await info.context.db.execute( select(models.LocationNode) - .join( - models.location_organizations, - models.LocationNode.id == models.location_organizations.c.location_id, - ) .where( models.LocationNode.parent_id.is_(None), models.LocationNode.id.in_(accessible_location_ids), From 2471bd89b85dfe87ae4572f4030e6ab7a474e10a Mon Sep 17 00:00:00 2001 From: Felix Evers Date: Tue, 23 Dec 2025 22:17:32 +0100 Subject: [PATCH 10/16] Fix user root_locations to show all root locations not just those with organizations --- backend/api/types/user.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/backend/api/types/user.py b/backend/api/types/user.py index d5f1676..aaeb188 100644 --- a/backend/api/types/user.py +++ b/backend/api/types/user.py @@ -101,10 +101,6 @@ async def root_locations( models.user_root_locations, models.LocationNode.id == models.user_root_locations.c.location_id, ) - .join( - models.location_organizations, - models.LocationNode.id == models.location_organizations.c.location_id, - ) .where(models.user_root_locations.c.user_id == self.id) .where(models.LocationNode.parent_id.is_(None)) .distinct() From f2974b189341fc2430e96728f4a5ba1595661d83 Mon Sep 17 00:00:00 2001 From: Felix Evers Date: Wed, 24 Dec 2025 01:23:44 +0100 Subject: [PATCH 11/16] fix location picker initialization and button label display --- backend/api/context.py | 106 +++++++++--- backend/api/resolvers/__init__.py | 5 +- backend/api/resolvers/location.py | 163 +++++++++++++++++- backend/api/resolvers/patient.py | 19 +- backend/api/types/user.py | 37 +++- keycloak/tasks.json | 8 +- web/api/gql/generated.ts | 4 +- web/api/graphql/GlobalData.graphql | 1 + web/api/graphql/Subscriptions.graphql | 12 ++ web/components/layout/Page.tsx | 43 +++-- .../locations/LocationSelectionDialog.tsx | 129 +++++++++++++- web/components/patients/PatientDetailView.tsx | 34 ++-- web/hooks/useTasksContext.tsx | 84 ++++++++- web/i18n/translations.ts | 12 ++ web/locales/de-DE.arb | 4 + web/locales/en-US.arb | 4 + 16 files changed, 597 insertions(+), 68 deletions(-) diff --git a/backend/api/context.py b/backend/api/context.py index 8461c7a..b0d4737 100644 --- a/backend/api/context.py +++ b/backend/api/context.py @@ -1,4 +1,5 @@ import asyncio +import logging from typing import Any import strawberry @@ -9,6 +10,7 @@ from fastapi import Depends from graphql import GraphQLError from sqlalchemy import delete, select +from sqlalchemy.dialects.postgresql import insert from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from starlette.requests import HTTPConnection @@ -48,10 +50,11 @@ def __getattr__(self, name): class Context(BaseContext): - def __init__(self, db: AsyncSession, user: "User | None" = None): + def __init__(self, db: AsyncSession, user: "User | None" = None, organizations: str | None = None): super().__init__() self._db = db self.user = user + self.organizations = organizations self._accessible_location_ids: set[str] | None = None self._accessible_location_ids_lock = asyncio.Lock() self._db_lock = asyncio.Lock() @@ -67,6 +70,7 @@ async def get_context( ) -> Context: user_payload = get_user_payload(connection) db_user = None + organizations = None if user_payload: user_id = user_payload.get("sub") @@ -78,13 +82,33 @@ async def get_context( email = user_payload.get("email") picture = user_payload.get("picture") + # Debug: Log available keys in token to help diagnose missing organization claim + logger = logging.getLogger(__name__) organizations_raw = user_payload.get("organization") + + if organizations_raw is None: + # Check if organization scope is in the token + scope = user_payload.get("scope", "") + has_org_scope = "organization" in scope.split() if scope else False + # Use warning level so it's visible in logs + logger.warning( + f"Organization claim not found in token for user {user_payload.get('sub', 'unknown')}. " + f"Has organization scope: {has_org_scope}. " + f"Token scope: {scope}. " + f"Available claims: {sorted(user_payload.keys())}" + ) + # Also print to console for immediate visibility + print(f"WARNING: Organization claim missing. Scope: {scope}, Available claims: {sorted(user_payload.keys())}") + organizations = None if organizations_raw: if isinstance(organizations_raw, list): - organizations = ",".join(str(org) for org in organizations_raw) + # Filter out empty strings and None values + org_list = [str(org) for org in organizations_raw if org] + if org_list: + organizations = ",".join(org_list) else: - organizations = str(organizations_raw) + organizations = str(organizations_raw) if organizations_raw else None if user_id: result = await session.execute( @@ -139,48 +163,79 @@ async def get_context( if db_user: try: - await _update_user_root_locations(session, db_user, organizations) + # Debug output + if organizations is None: + print(f"WARNING: organizations is None for user {db_user.id} ({db_user.username})") + print(f"Token payload keys: {sorted(user_payload.keys())}") + print(f"Token scope: {user_payload.get('scope', 'N/A')}") + else: + print(f"Organizations for user {db_user.id}: {organizations}") + + await _update_user_root_locations( + session, + db_user, + organizations, + ) except Exception as e: raise GraphQLError( "Failed to update user root locations. Please contact an administrator if you believe this is an error.", extensions={"code": "INTERNAL_SERVER_ERROR"}, ) from e - return Context(db=session, user=db_user) + return Context(db=session, user=db_user, organizations=organizations) async def _update_user_root_locations( - session: AsyncSession, user: User, organizations: str | None + session: AsyncSession, + user: User, + organizations: str | None, ) -> None: - from database.models.location import location_organizations - organization_ids: list[str] = [] if organizations: organization_ids = [ - org_id.strip() for org_id in organizations.split(",") if org_id.strip() + org_id.strip() + for org_id in organizations.split(",") + if org_id.strip() ] + logger = logging.getLogger(__name__) + logger.info(f"Updating root locations for user {user.id} with organizations: {organization_ids}") + root_location_ids: list[str] = [] if organization_ids: + logger.info(f"Looking up locations for organization IDs: {organization_ids}") + # First check if any location_organizations entries exist + org_check = await session.execute( + select(location_organizations.c.organization_id, location_organizations.c.location_id) + .where(location_organizations.c.organization_id.in_(organization_ids)) + ) + org_entries = org_check.all() + logger.info(f"Found {len(org_entries)} location_organizations entries: {[(row[0], row[1]) for row in org_entries]}") + result = await session.execute( select(LocationNode) .join( location_organizations, LocationNode.id == location_organizations.c.location_id, ) - .where(location_organizations.c.organization_id.in_(organization_ids)) + .where( + location_organizations.c.organization_id.in_(organization_ids), + ), ) found_locations = result.scalars().all() root_location_ids = [loc.id for loc in found_locations] + logger.info(f"Found {len(found_locations)} existing locations for organizations: {[loc.id for loc in found_locations]}") found_org_ids = set() for loc in found_locations: org_result = await session.execute( select(location_organizations.c.organization_id).where( location_organizations.c.location_id == loc.id, - location_organizations.c.organization_id.in_(organization_ids) - ) + location_organizations.c.organization_id.in_( + organization_ids, + ), + ), ) found_org_ids.update(row[0] for row in org_result.all()) @@ -194,20 +249,25 @@ async def _update_user_root_locations( session.add(new_location) await session.flush() await session.refresh(new_location) - + await session.execute( location_organizations.insert().values( - location_id=new_location.id, organization_id=org_id - ) + location_id=new_location.id, + organization_id=org_id, + ), ) root_location_ids.append(new_location.id) + logger.info(f"Created new location {new_location.id} for organization {org_id}") + + logger.info(f"Total root location IDs: {root_location_ids}") + if not root_location_ids: personal_org_title = f"{user.username}'s Organization" result = await session.execute( select(LocationNode).where( LocationNode.title == personal_org_title, LocationNode.parent_id.is_(None), - ) + ), ) personal_location = result.scalars().first() @@ -221,20 +281,26 @@ async def _update_user_root_locations( await session.flush() root_location_ids = [personal_location.id] + logger.info(f"Using personal location: {personal_location.id}") await session.execute( - delete(user_root_locations).where(user_root_locations.c.user_id == user.id) + delete(user_root_locations).where( + user_root_locations.c.user_id == user.id, + ), ) if root_location_ids: - from sqlalchemy.dialects.postgresql import insert stmt = insert(user_root_locations).values( [ {"user_id": user.id, "location_id": loc_id} for loc_id in root_location_ids - ] + ], + ) + stmt = stmt.on_conflict_do_nothing( + index_elements=["user_id", "location_id"], ) - stmt = stmt.on_conflict_do_nothing(index_elements=["user_id", "location_id"]) await session.execute(stmt) + logger.info(f"Inserted {len(root_location_ids)} root locations for user {user.id}: {root_location_ids}") await session.commit() + logger.info(f"Root locations update completed for user {user.id}") diff --git a/backend/api/resolvers/__init__.py b/backend/api/resolvers/__init__.py index a8312ab..228d89e 100644 --- a/backend/api/resolvers/__init__.py +++ b/backend/api/resolvers/__init__.py @@ -1,6 +1,6 @@ import strawberry -from .location import LocationQuery +from .location import LocationMutation, LocationQuery, LocationSubscription from .patient import PatientMutation, PatientQuery, PatientSubscription from .property import PropertyDefinitionMutation, PropertyDefinitionQuery from .task import TaskMutation, TaskQuery, TaskSubscription @@ -23,10 +23,11 @@ class Mutation( PatientMutation, TaskMutation, PropertyDefinitionMutation, + LocationMutation, ): pass @strawberry.type -class Subscription(PatientSubscription, TaskSubscription): +class Subscription(PatientSubscription, TaskSubscription, LocationSubscription): pass diff --git a/backend/api/resolvers/location.py b/backend/api/resolvers/location.py index 4e90dde..794a4a1 100644 --- a/backend/api/resolvers/location.py +++ b/backend/api/resolvers/location.py @@ -1,6 +1,10 @@ +from collections.abc import AsyncGenerator + import strawberry +from api.audit import audit_log from api.context import Info -from api.inputs import LocationType +from api.inputs import CreateLocationNodeInput, LocationType, UpdateLocationNodeInput +from api.resolvers.base import BaseMutationResolver, BaseSubscriptionResolver from api.services.authorization import AuthorizationService from api.types.location import LocationNodeType from database import models @@ -116,3 +120,160 @@ async def location_nodes( result = await db.execute(query) return result.scalars().all() + + +@strawberry.type +class LocationMutation(BaseMutationResolver[models.LocationNode]): + @strawberry.mutation + @audit_log("create_location_node") + async def create_location_node( + self, + info: Info, + data: CreateLocationNodeInput, + ) -> LocationNodeType: + db = info.context.db + auth_service = AuthorizationService(db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + + if not accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + + if data.parent_id and data.parent_id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + + location = models.LocationNode( + title=data.title, + kind=data.kind.value, + parent_id=data.parent_id, + ) + + location = await BaseMutationResolver.create_and_notify( + info, location, models.LocationNode, "location_node" + ) + return location + + @strawberry.mutation + @audit_log("update_location_node") + async def update_location_node( + self, + info: Info, + id: strawberry.ID, + data: UpdateLocationNodeInput, + ) -> LocationNodeType: + db = info.context.db + auth_service = AuthorizationService(db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + + if not accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + + result = await db.execute( + select(models.LocationNode).where(models.LocationNode.id == id) + ) + location = result.scalars().first() + + if not location: + raise GraphQLError( + "Location not found.", + extensions={"code": "NOT_FOUND"}, + ) + + if location.id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + + if data.parent_id is not None and data.parent_id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + + if data.title is not None: + location.title = data.title + if data.kind is not None: + location.kind = data.kind.value + if data.parent_id is not None: + location.parent_id = data.parent_id + + location = await BaseMutationResolver.update_and_notify( + info, location, models.LocationNode, "location_node" + ) + return location + + @strawberry.mutation + @audit_log("delete_location_node") + async def delete_location_node(self, info: Info, id: strawberry.ID) -> bool: + db = info.context.db + auth_service = AuthorizationService(db) + accessible_location_ids = await auth_service.get_user_accessible_location_ids( + info.context.user, info.context + ) + + if not accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + + result = await db.execute( + select(models.LocationNode).where(models.LocationNode.id == id) + ) + location = result.scalars().first() + + if not location: + raise GraphQLError( + "Location not found.", + extensions={"code": "NOT_FOUND"}, + ) + + if location.id not in accessible_location_ids: + raise GraphQLError( + "Insufficient permission. Please contact an administrator if you believe this is an error.", + extensions={"code": "FORBIDDEN"}, + ) + + await BaseMutationResolver.delete_entity( + info, location, models.LocationNode, "location_node" + ) + return True + + +@strawberry.type +class LocationSubscription(BaseSubscriptionResolver): + @strawberry.subscription + async def location_node_created( + self, info: Info + ) -> AsyncGenerator[strawberry.ID, None]: + async for location_id in BaseSubscriptionResolver.entity_created(info, "location_node"): + yield location_id + + @strawberry.subscription + async def location_node_updated( + self, + info: Info, + location_id: strawberry.ID | None = None, + ) -> AsyncGenerator[strawberry.ID, None]: + async for updated_id in BaseSubscriptionResolver.entity_updated(info, "location_node", location_id): + yield updated_id + + @strawberry.subscription + async def location_node_deleted( + self, info: Info + ) -> AsyncGenerator[strawberry.ID, None]: + async for location_id in BaseSubscriptionResolver.entity_deleted(info, "location_node"): + yield location_id diff --git a/backend/api/resolvers/patient.py b/backend/api/resolvers/patient.py index 8f71fd1..72938de 100644 --- a/backend/api/resolvers/patient.py +++ b/backend/api/resolvers/patient.py @@ -68,18 +68,23 @@ async def patients( accessible_location_ids = await auth_service.get_user_accessible_location_ids( info.context.user, info.context ) + + # If user has no accessible locations, return empty list + if not accessible_location_ids: + return [] + query = auth_service.filter_patients_by_access( info.context.user, query, accessible_location_ids ) filter_cte = None if root_location_ids: - invalid_ids = [lid for lid in root_location_ids if lid not in accessible_location_ids] - if invalid_ids: - raise GraphQLError( - "Insufficient permission. Please contact an administrator if you believe this is an error.", - extensions={"code": "FORBIDDEN"}, - ) + # Filter to only include root_location_ids that the user has access to + valid_root_location_ids = [lid for lid in root_location_ids if lid in accessible_location_ids] + if not valid_root_location_ids: + # If none of the requested root_location_ids are accessible, return empty list + return [] + root_location_ids = valid_root_location_ids filter_cte = ( select(models.LocationNode.id) .where(models.LocationNode.id.in_(root_location_ids)) @@ -101,7 +106,7 @@ async def patients( ) filter_cte = filter_cte.union_all(parent) - if filter_cte: + if filter_cte is not None: patient_locations_filter = aliased(models.patient_locations) patient_teams_filter = aliased(models.patient_teams) diff --git a/backend/api/types/user.py b/backend/api/types/user.py index aaeb188..44d73dc 100644 --- a/backend/api/types/user.py +++ b/backend/api/types/user.py @@ -26,6 +26,11 @@ def name(self) -> str: return f"{self.firstname} {self.lastname}" return self.username + @strawberry.field + def organizations(self, info) -> str | None: + """Get organizations from the context""" + return info.context.organizations + @strawberry.field async def tasks( self, @@ -95,6 +100,18 @@ async def root_locations( self, info, ) -> list[Annotated["LocationNodeType", strawberry.lazy("api.types.location")]]: + import logging + logger = logging.getLogger(__name__) + + # First check what's in user_root_locations table + user_root_check = await info.context.db.execute( + select(models.user_root_locations.c.location_id).where( + models.user_root_locations.c.user_id == self.id + ) + ) + user_root_location_ids = [row[0] for row in user_root_check.all()] + logger.info(f"User {self.id} has {len(user_root_location_ids)} entries in user_root_locations: {user_root_location_ids}") + result = await info.context.db.execute( select(models.LocationNode) .join( @@ -102,7 +119,23 @@ async def root_locations( models.LocationNode.id == models.user_root_locations.c.location_id, ) .where(models.user_root_locations.c.user_id == self.id) - .where(models.LocationNode.parent_id.is_(None)) .distinct() ) - return result.scalars().all() + locations = result.scalars().all() + logger.info(f"User {self.id} root_locations query returned {len(locations)} locations: {[loc.id for loc in locations]}") + + # If we have user_root_locations entries but no locations returned, check if locations exist + if user_root_location_ids and not locations: + location_check = await info.context.db.execute( + select(models.LocationNode).where( + models.LocationNode.id.in_(user_root_location_ids) + ) + ) + existing_locations = location_check.scalars().all() + logger.warning( + f"User {self.id} has {len(user_root_location_ids)} root location IDs but query returned empty. " + f"Checking if locations exist: {[loc.id for loc in existing_locations]} " + f"with parent_ids: {[loc.parent_id for loc in existing_locations]}" + ) + + return locations diff --git a/keycloak/tasks.json b/keycloak/tasks.json index a65251d..217244b 100644 --- a/keycloak/tasks.json +++ b/keycloak/tasks.json @@ -690,8 +690,8 @@ "authenticationFlowBindingOverrides" : { }, "fullScopeAllowed" : true, "nodeReRegistrationTimeout" : -1, - "defaultClientScopes" : [ "web-origins", "acr", "roles", "profile", "basic", "email" ], - "optionalClientScopes" : [ "address", "phone", "organization", "offline_access", "microprofile-jwt" ] + "defaultClientScopes" : [ "web-origins", "acr", "roles", "profile", "basic", "email", "organization" ], + "optionalClientScopes" : [ "address", "phone", "offline_access", "microprofile-jwt" ] }, { "id" : "1c76a254-0a04-4dc1-b287-4e642ae3e9be", "clientId" : "tasks-web", @@ -1357,8 +1357,8 @@ } } ] } ], - "defaultDefaultClientScopes" : [ "role_list", "saml_organization", "profile", "email", "roles", "web-origins", "acr", "basic" ], - "defaultOptionalClientScopes" : [ "offline_access", "address", "phone", "microprofile-jwt", "organization" ], + "defaultDefaultClientScopes" : [ "role_list", "saml_organization", "profile", "email", "roles", "web-origins", "acr", "basic", "organization" ], + "defaultOptionalClientScopes" : [ "offline_access", "address", "phone", "microprofile-jwt" ], "browserSecurityHeaders" : { "contentSecurityPolicyReportOnly" : "", "xContentTypeOptions" : "nosniff", diff --git a/web/api/gql/generated.ts b/web/api/gql/generated.ts index 2e606e7..9dc0445 100644 --- a/web/api/gql/generated.ts +++ b/web/api/gql/generated.ts @@ -428,6 +428,7 @@ export type UserType = { id: Scalars['ID']['output']; lastname?: Maybe; name: Scalars['String']['output']; + organizations?: Maybe; rootLocations: Array; tasks: Array; title?: Maybe; @@ -495,7 +496,7 @@ export type GetUsersQuery = { __typename?: 'Query', users: Array<{ __typename?: export type GetGlobalDataQueryVariables = Exact<{ [key: string]: never; }>; -export type GetGlobalDataQuery = { __typename?: 'Query', me?: { __typename?: 'UserType', id: string, username: string, name: string, firstname?: string | null, lastname?: string | null, avatarUrl?: string | null, rootLocations: Array<{ __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType }>, tasks: Array<{ __typename?: 'TaskType', id: string, done: boolean }> } | null, wards: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, teams: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, clinics: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, patients: Array<{ __typename?: 'PatientType', id: string, state: PatientState, assignedLocation?: { __typename?: 'LocationNodeType', id: string } | null }>, waitingPatients: Array<{ __typename?: 'PatientType', id: string, state: PatientState }> }; +export type GetGlobalDataQuery = { __typename?: 'Query', me?: { __typename?: 'UserType', id: string, username: string, name: string, firstname?: string | null, lastname?: string | null, avatarUrl?: string | null, organizations?: string | null, rootLocations: Array<{ __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType }>, tasks: Array<{ __typename?: 'TaskType', id: string, done: boolean }> } | null, wards: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, teams: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, clinics: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, patients: Array<{ __typename?: 'PatientType', id: string, state: PatientState, assignedLocation?: { __typename?: 'LocationNodeType', id: string } | null }>, waitingPatients: Array<{ __typename?: 'PatientType', id: string, state: PatientState }> }; export type CreatePatientMutationVariables = Exact<{ data: CreatePatientInput; @@ -1294,6 +1295,7 @@ export const GetGlobalDataDocument = ` firstname lastname avatarUrl + organizations rootLocations { id title diff --git a/web/api/graphql/GlobalData.graphql b/web/api/graphql/GlobalData.graphql index ada82c6..20b7a73 100644 --- a/web/api/graphql/GlobalData.graphql +++ b/web/api/graphql/GlobalData.graphql @@ -6,6 +6,7 @@ query GetGlobalData { firstname lastname avatarUrl + organizations rootLocations { id title diff --git a/web/api/graphql/Subscriptions.graphql b/web/api/graphql/Subscriptions.graphql index bebc73b..4ecc7fe 100644 --- a/web/api/graphql/Subscriptions.graphql +++ b/web/api/graphql/Subscriptions.graphql @@ -22,3 +22,15 @@ subscription TaskDeleted { taskDeleted } +subscription LocationNodeUpdated($locationId: ID) { + locationNodeUpdated(locationId: $locationId) +} + +subscription LocationNodeCreated { + locationNodeCreated +} + +subscription LocationNodeDeleted { + locationNodeDeleted +} + diff --git a/web/components/layout/Page.tsx b/web/components/layout/Page.tsx index e5a45d7..343db63 100644 --- a/web/components/layout/Page.tsx +++ b/web/components/layout/Page.tsx @@ -12,8 +12,6 @@ import { Expandable, LoadingContainer, MarkdownInterpreter, - Select, - SelectOption, useLocalStorage } from '@helpwave/hightide' import { getConfig } from '@/utils/config' @@ -211,12 +209,40 @@ export const Header = ({ onMenuClick, isMenuOpen, ...props }: HeaderProps) => { const { user, rootLocations, selectedRootLocationIds, update } = useTasksContext() const translation = useTasksTranslation() const [isLocationPickerOpen, setIsLocationPickerOpen] = useState(false) + const [selectedLocationsCache, setSelectedLocationsCache] = useState>([]) - const selectedRootLocations = rootLocations?.filter(loc => selectedRootLocationIds?.includes(loc.id)) || [] + // Update cache when rootLocations change and contain the selected IDs + useEffect(() => { + if (rootLocations && selectedRootLocationIds && selectedRootLocationIds.length > 0) { + const foundInRoot = rootLocations.filter(loc => selectedRootLocationIds.includes(loc.id)) + if (foundInRoot.length > 0) { + // Update cache with locations from rootLocations if they match + const cacheIds = new Set(selectedLocationsCache.map(loc => loc.id)) + const selectedIds = new Set(selectedRootLocationIds) + if (cacheIds.size === selectedIds.size && Array.from(cacheIds).every(id => selectedIds.has(id))) { + // Cache matches, update it with rootLocations data + setSelectedLocationsCache(foundInRoot.map(loc => ({ id: loc.id, title: loc.title, kind: loc.kind }))) + } else if (selectedLocationsCache.length === 0) { + // No cache, use rootLocations + setSelectedLocationsCache(foundInRoot.map(loc => ({ id: loc.id, title: loc.title, kind: loc.kind }))) + } + } + } else if (!selectedRootLocationIds || selectedRootLocationIds.length === 0) { + // Clear cache when selection is cleared + setSelectedLocationsCache([]) + } + }, [rootLocations, selectedRootLocationIds]) + + // Use cached locations if available, otherwise fall back to rootLocations + const selectedRootLocations = selectedLocationsCache.length > 0 + ? selectedLocationsCache + : (rootLocations?.filter(loc => selectedRootLocationIds?.includes(loc.id)) || []) const firstSelectedRootLocation = selectedRootLocations[0] const handleRootLocationSelect = (locations: Array<{ id: string; title: string; kind?: string }>) => { if (locations.length === 0) return + // Cache the selected locations so we can display them even if they're not in rootLocations + setSelectedLocationsCache(locations) update(prevState => ({ ...prevState, selectedRootLocationIds: locations.map(loc => loc.id), @@ -254,7 +280,9 @@ export const Header = ({ onMenuClick, isMenuOpen, ...props }: HeaderProps) => { {selectedRootLocations.length > 0 ? selectedRootLocations.length === 1 ? firstSelectedRootLocation?.title - : `${selectedRootLocations.length} ${translation('locations') || 'locations'}` + : selectedRootLocations.length === 2 + ? `${selectedRootLocations[0].title}, ${selectedRootLocations[1].title}` + : `${selectedRootLocations[0].title} +${selectedRootLocations.length - 1}` : translation('selectLocation') || 'Select Location'} { onSelect={handleRootLocationSelect} initialSelectedIds={selectedRootLocationIds || []} multiSelect={true} - useCase="default" + useCase="root" />
)} - {selectedRootLocations.length > 0 && ( -
- -
- )}
diff --git a/web/components/locations/LocationSelectionDialog.tsx b/web/components/locations/LocationSelectionDialog.tsx index 734e754..b73524e 100644 --- a/web/components/locations/LocationSelectionDialog.tsx +++ b/web/components/locations/LocationSelectionDialog.tsx @@ -26,6 +26,7 @@ export type LocationPickerUseCase = | 'clinic' | 'position' | 'teams' + | 'root' interface LocationSelectionDialogProps { isOpen: boolean, @@ -178,22 +179,132 @@ export const LocationSelectionDialog = ({ const hasInitialized = useRef(false) + // Helper function to get all descendant IDs of a node (recursively) + const getAllDescendantIds = useMemo(() => { + if (!data?.locationNodes) return () => new Set() + const nodes = data.locationNodes as LocationNodeType[] + + return (nodeId: string): Set => { + const descendants = new Set() + const queue = [nodeId] + + while (queue.length > 0) { + const currentId = queue.shift()! + const children = nodes.filter(n => n.parentId === currentId) + children.forEach(child => { + descendants.add(child.id) + queue.push(child.id) + }) + } + + return descendants + } + }, [data?.locationNodes]) + + // Helper function to get all ancestor IDs of a node (recursively) + const getAllAncestorIds = useMemo(() => { + if (!data?.locationNodes) return () => new Set() + const nodes = data.locationNodes as LocationNodeType[] + + return (nodeId: string): Set => { + const ancestors = new Set() + let current: LocationNodeType | undefined = nodes.find(n => n.id === nodeId) + + while (current?.parentId) { + ancestors.add(current.parentId) + current = nodes.find(n => n.id === current.parentId) + } + + return ancestors + } + }, [data?.locationNodes]) + + // Simplify selection: prefer children over parents (most specific selection wins) + const simplifySelection = useMemo(() => { + if (!data?.locationNodes || useCase !== 'root') { + return (ids: string[]): string[] => ids + } + const nodes = data.locationNodes as LocationNodeType[] + + return (ids: string[]): string[] => { + if (ids.length === 0) return ids + + const idSet = new Set(ids) + const simplified = new Set() + + // First pass: prefer children over parents + // If both a parent and child are selected, keep only the child + for (const id of ids) { + let current: LocationNodeType | undefined = nodes.find(n => n.id === id) + let hasAncestorSelected = false + + // Check if any ancestor is also selected + while (current?.parentId) { + if (idSet.has(current.parentId)) { + hasAncestorSelected = true + break + } + current = nodes.find(n => n.id === current?.parentId) + } + + // Only add if no ancestor is selected (child wins over parent) + if (!hasAncestorSelected) { + simplified.add(id) + } + } + + // Second pass: remove descendants of selected nodes (parent includes children) + const finalSet = new Set() + for (const id of simplified) { + // Check if this node has any descendants in the simplified set + const descendants = getAllDescendantIds(id) + const hasDescendantInSimplified = Array.from(descendants).some(descId => simplified.has(descId)) + + // Only add if no descendant is in simplified set (child wins over parent) + if (!hasDescendantInSimplified) { + finalSet.add(id) + } + } + + return Array.from(finalSet) + } + }, [data?.locationNodes, useCase, getAllDescendantIds]) + useEffect(() => { if (isOpen) { - setSelectedIds(new Set(initialSelectedIds)) + // Simplify initial selection when dialog opens + const simplifiedIds = simplifySelection(initialSelectedIds) + setSelectedIds(new Set(simplifiedIds)) setExpandedIds(new Set()) hasInitialized.current = true } else { hasInitialized.current = false } - }, [isOpen, initialSelectedIds]) + }, [isOpen, initialSelectedIds, simplifySelection]) const matchesFilter = useMemo(() => { if (useCase === 'default') { return () => true } - if (useCase === 'clinic') { + if (useCase === 'root') { + // Only hospitals, practices, clinics, and teams are selectable for root locations + const allowedKinds = new Set([ + LocationType.Hospital, + LocationType.Practice, + LocationType.Clinic, + LocationType.Team, + 'HOSPITAL', + 'PRACTICE', + 'CLINIC', + 'TEAM', + ]) + return (node: LocationNodeType) => { + const kindStr = node.kind.toString().toUpperCase() + return allowedKinds.has(node.kind as LocationType) || + allowedKinds.has(kindStr) + } + } else if (useCase === 'clinic') { return (node: LocationNodeType) => { const kindStr = node.kind.toString().toUpperCase() return kindStr === 'CLINIC' || node.kind === LocationType.Clinic @@ -320,6 +431,18 @@ export const LocationSelectionDialog = ({ if (useCase === 'clinic' || !multiSelect) { newSet.clear() } + + // Simplification logic: only for root useCase + if (useCase === 'root') { + // Remove all descendants of this node (parent includes children) + const descendants = getAllDescendantIds(node.id) + descendants.forEach(descId => newSet.delete(descId)) + + // Remove all ancestors of this node (child replaces parent) + const ancestors = getAllAncestorIds(node.id) + ancestors.forEach(ancId => newSet.delete(ancId)) + } + newSet.add(node.id) } else { newSet.delete(node.id) diff --git a/web/components/patients/PatientDetailView.tsx b/web/components/patients/PatientDetailView.tsx index 7ab5509..87ff093 100644 --- a/web/components/patients/PatientDetailView.tsx +++ b/web/components/patients/PatientDetailView.tsx @@ -241,7 +241,8 @@ export const PatientDetailView = ({ if (patientData?.patient) { const patient = patientData.patient const { firstname, lastname, sex, birthdate, assignedLocations, clinic, position, teams } = patient - setFormData({ + setFormData(prev => ({ + ...prev, firstname, lastname, sex, @@ -250,7 +251,7 @@ export const PatientDetailView = ({ clinicId: clinic?.id || undefined, positionId: position?.id || undefined, teamIds: teams?.map(t => t.id) || undefined, - } as CreatePatientInput & { clinicId?: string, positionId?: string, teamIds?: string[] }) + } as CreatePatientInput & { clinicId?: string, positionId?: string, teamIds?: string[] })) setSelectedClinic(clinic ? (clinic as LocationNodeType) : null) setSelectedPosition(position ? (position as LocationNodeType) : null) setSelectedTeams((teams || []) as LocationNodeType[]) @@ -258,20 +259,33 @@ export const PatientDetailView = ({ }, [patientData]) useEffect(() => { - if (!isEditMode && firstSelectedRootLocationId && locationsData?.locationNodes && !formData.clinicId) { - const selectedRootLocation = locationsData.locationNodes.find( - loc => loc.id === firstSelectedRootLocationId && loc.kind === 'CLINIC' - ) - if (selectedRootLocation) { - const clinicLocation = selectedRootLocation as LocationNodeType + if (!isEditMode && locationsData?.locationNodes && !formData.clinicId) { + // Try to find a CLINIC in the selected root locations first + let clinicLocation: LocationNodeType | undefined + if (firstSelectedRootLocationId) { + const selectedRootLocation = locationsData.locationNodes.find( + loc => loc.id === firstSelectedRootLocationId && loc.kind === 'CLINIC' + ) + if (selectedRootLocation) { + clinicLocation = selectedRootLocation as LocationNodeType + } + } + // If no CLINIC found in selected, try first root location that is a CLINIC + if (!clinicLocation && rootLocations && rootLocations.length > 0) { + const firstClinic = rootLocations.find(loc => loc.kind === 'CLINIC') + if (firstClinic) { + clinicLocation = firstClinic as LocationNodeType + } + } + if (clinicLocation) { setSelectedClinic(clinicLocation) setFormData(prev => ({ ...prev, - clinicId: clinicLocation.id, + clinicId: clinicLocation!.id, })) } } - }, [isEditMode, firstSelectedRootLocationId, locationsData, formData.clinicId]) + }, [isEditMode, firstSelectedRootLocationId, locationsData, formData.clinicId, rootLocations]) const { mutate: createPatient, isLoading: isCreating } = useCreatePatientMutation({ onSuccess: () => { diff --git a/web/hooks/useTasksContext.tsx b/web/hooks/useTasksContext.tsx index 996394f..84e6841 100644 --- a/web/hooks/useTasksContext.tsx +++ b/web/hooks/useTasksContext.tsx @@ -1,7 +1,8 @@ import type { Dispatch, SetStateAction } from 'react' -import { createContext, type PropsWithChildren, useContext, useEffect, useState } from 'react' +import { createContext, type PropsWithChildren, useContext, useEffect, useRef, useState } from 'react' import { usePathname } from 'next/navigation' -import { useGetGlobalDataQuery } from '@/api/gql/generated' +import { useGetGlobalDataQuery, useLocationNodeUpdatedSubscription } from '@/api/gql/generated' +import { useQueryClient } from '@tanstack/react-query' import { useAuth } from './useAuth' import { useLocalStorage } from '@helpwave/hightide' @@ -58,6 +59,7 @@ export const useTasksContext = (): TasksContextType => { export const TasksContextProvider = ({ children }: PropsWithChildren) => { const pathName = usePathname() const { identity, isLoading: isAuthLoading } = useAuth() + const queryClient = useQueryClient() const { value: storedSelectedRootLocationIds, setValue: setStoredSelectedRootLocationIds @@ -78,25 +80,76 @@ export const TasksContextProvider = ({ children }: PropsWithChildren) => { refetchOnMount: true, }) + // Subscribe to location updates and invalidate all queries when locations change + // Note: This will be available after running codegen + // useLocationNodeUpdatedSubscription( + // { locationId: undefined }, + // { + // enabled: !isAuthLoading && !!identity, + // onData: () => { + // // Invalidate all queries when a location is updated + // queryClient.invalidateQueries() + // }, + // } + // ) + + // Track previous root location IDs to detect changes + const prevRootLocationIdsRef = useRef('') + + // Invalidate all queries when root locations change (this handles location node updates) + useEffect(() => { + if (data?.me?.rootLocations) { + const currentRootLocationIds = data.me.rootLocations.map(loc => loc.id).sort().join(',') + + if (prevRootLocationIdsRef.current && prevRootLocationIdsRef.current !== currentRootLocationIds) { + // Root locations changed, invalidate all queries to reload global state + queryClient.invalidateQueries() + } + prevRootLocationIdsRef.current = currentRootLocationIds + } + }, [data?.me?.rootLocations, queryClient]) + useEffect(() => { const totalPatientsCount = data?.patients?.length ?? 0 const waitingPatientsCount = data?.waitingPatients?.length ?? 0 const rootLocations = data?.me?.rootLocations?.map(loc => ({ id: loc.id, title: loc.title, kind: loc.kind })) ?? [] + // Debug logging - use console.log so it's always visible + console.log('[DEBUG] useTasksContext - data?.me:', data?.me) + if (data?.me?.organizations) { + console.log('[DEBUG] Organizations (raw):', data.me.organizations) + console.log('[DEBUG] Organizations (parsed):', data.me.organizations.split(',').map(org => org.trim())) + } + console.log('[DEBUG] Root Locations count:', rootLocations.length) + if (rootLocations.length > 0) { + console.log('[DEBUG] Root Locations:', rootLocations) + } else { + console.log('[DEBUG] No root locations found. data?.me?.rootLocations:', data?.me?.rootLocations) + } + setState(prevState => { let selectedRootLocationIds = prevState.selectedRootLocationIds || [] if (rootLocations.length > 0) { const validIds = selectedRootLocationIds.filter(id => rootLocations.find(loc => loc.id === id)) - if (validIds.length === 0) { + // If no valid IDs and no localStorage state, auto-select only the first root location + if (validIds.length === 0 && storedSelectedRootLocationIds.length === 0) { + // Auto-select first root location if none selected and no localStorage state + selectedRootLocationIds = [rootLocations[0].id] + console.log('[DEBUG] Auto-selected first root location (no localStorage):', rootLocations[0].id) + } else if (validIds.length === 0 && storedSelectedRootLocationIds.length > 0) { + // If localStorage has values but they're not valid, clear localStorage and use first selectedRootLocationIds = [rootLocations[0].id] - setStoredSelectedRootLocationIds(selectedRootLocationIds) + setStoredSelectedRootLocationIds([]) + console.log('[DEBUG] Cleared invalid localStorage, auto-selected first root location:', rootLocations[0].id) } else { selectedRootLocationIds = validIds - if (selectedRootLocationIds.length !== prevState.selectedRootLocationIds?.length) { - setStoredSelectedRootLocationIds(selectedRootLocationIds) - } } + } else if (selectedRootLocationIds.length > 0) { + // If we have selected IDs but no root locations, clear the selection + // This happens when locations are removed or user's organizations change + console.log('[DEBUG] Clearing selectedRootLocationIds because rootLocations is empty') + selectedRootLocationIds = [] } return { @@ -120,7 +173,22 @@ export const TasksContextProvider = ({ children }: PropsWithChildren) => { selectedRootLocationIds, } }) - }, [data, setStoredSelectedRootLocationIds]) + }, [data]) + + // Use refs to track what we last wrote to localStorage to avoid loops + const lastWrittenLocationIdsRef = useRef(undefined) + + // Separate effect to sync state changes to localStorage + useEffect(() => { + if (state.selectedRootLocationIds !== undefined) { + const currentIds = state.selectedRootLocationIds + const lastWritten = lastWrittenLocationIdsRef.current + if (JSON.stringify(currentIds) !== JSON.stringify(lastWritten)) { + lastWrittenLocationIdsRef.current = currentIds + setStoredSelectedRootLocationIds(currentIds) + } + } + }, [state.selectedRootLocationIds, setStoredSelectedRootLocationIds]) const updateState: Dispatch> = (updater) => { setState(prevState => { diff --git a/web/i18n/translations.ts b/web/i18n/translations.ts index a35b349..f70fdf7 100644 --- a/web/i18n/translations.ts +++ b/web/i18n/translations.ts @@ -74,6 +74,7 @@ export type TasksTranslationEntries = { 'lastUpdate': string, 'loading': string, 'location': string, + 'locations': string, 'locationType': (values: { type: string }) => string, 'login': string, 'loginRequired': string, @@ -110,6 +111,7 @@ export type TasksTranslationEntries = { 'openSurvey': string, 'openTasks': string, 'option': string, + 'organizations': string, 'overview': string, 'pages.404.notFound': string, 'pages.404.notFoundDescription1': string, @@ -155,6 +157,8 @@ export type TasksTranslationEntries = { 'selectLocation': string, 'selectLocationDescription': string, 'selectOptions': string, + 'selectOrganization': string, + 'selectOrganizations': string, 'selectPatient': string, 'selectPosition': string, 'selectTeams': string, @@ -261,6 +265,7 @@ export const tasksTranslation: Translation { return TranslationGen.resolveSelect(type, { 'CLINIC': `Klinik`, @@ -359,6 +364,7 @@ export const tasksTranslation: Translation { return TranslationGen.resolveSelect(type, { 'CLINIC': `Clinic`, @@ -650,6 +659,7 @@ export const tasksTranslation: Translation Date: Wed, 24 Dec 2025 02:03:45 +0100 Subject: [PATCH 12/16] remove comments and debug statements fix linting errors --- backend/api/context.py | 23 -- web/api/gql/generated.ts | 86 +++++- web/api/graphql/GlobalData.graphql | 9 +- web/components/layout/Page.tsx | 244 ++++++++++-------- .../locations/LocationSelectionDialog.tsx | 92 ++----- web/hooks/useTasksContext.tsx | 171 +++++++----- web/i18n/translations.ts | 6 + web/locales/de-DE.arb | 2 + web/locales/en-US.arb | 2 + web/pages/tasks/index.tsx | 4 +- 10 files changed, 353 insertions(+), 286 deletions(-) diff --git a/backend/api/context.py b/backend/api/context.py index b0d4737..8cac67a 100644 --- a/backend/api/context.py +++ b/backend/api/context.py @@ -82,28 +82,22 @@ async def get_context( email = user_payload.get("email") picture = user_payload.get("picture") - # Debug: Log available keys in token to help diagnose missing organization claim logger = logging.getLogger(__name__) organizations_raw = user_payload.get("organization") if organizations_raw is None: - # Check if organization scope is in the token scope = user_payload.get("scope", "") has_org_scope = "organization" in scope.split() if scope else False - # Use warning level so it's visible in logs logger.warning( f"Organization claim not found in token for user {user_payload.get('sub', 'unknown')}. " f"Has organization scope: {has_org_scope}. " f"Token scope: {scope}. " f"Available claims: {sorted(user_payload.keys())}" ) - # Also print to console for immediate visibility - print(f"WARNING: Organization claim missing. Scope: {scope}, Available claims: {sorted(user_payload.keys())}") organizations = None if organizations_raw: if isinstance(organizations_raw, list): - # Filter out empty strings and None values org_list = [str(org) for org in organizations_raw if org] if org_list: organizations = ",".join(org_list) @@ -163,13 +157,6 @@ async def get_context( if db_user: try: - # Debug output - if organizations is None: - print(f"WARNING: organizations is None for user {db_user.id} ({db_user.username})") - print(f"Token payload keys: {sorted(user_payload.keys())}") - print(f"Token scope: {user_payload.get('scope', 'N/A')}") - else: - print(f"Organizations for user {db_user.id}: {organizations}") await _update_user_root_locations( session, @@ -199,19 +186,15 @@ async def _update_user_root_locations( ] logger = logging.getLogger(__name__) - logger.info(f"Updating root locations for user {user.id} with organizations: {organization_ids}") root_location_ids: list[str] = [] if organization_ids: - logger.info(f"Looking up locations for organization IDs: {organization_ids}") - # First check if any location_organizations entries exist org_check = await session.execute( select(location_organizations.c.organization_id, location_organizations.c.location_id) .where(location_organizations.c.organization_id.in_(organization_ids)) ) org_entries = org_check.all() - logger.info(f"Found {len(org_entries)} location_organizations entries: {[(row[0], row[1]) for row in org_entries]}") result = await session.execute( select(LocationNode) @@ -225,7 +208,6 @@ async def _update_user_root_locations( ) found_locations = result.scalars().all() root_location_ids = [loc.id for loc in found_locations] - logger.info(f"Found {len(found_locations)} existing locations for organizations: {[loc.id for loc in found_locations]}") found_org_ids = set() for loc in found_locations: @@ -257,9 +239,7 @@ async def _update_user_root_locations( ), ) root_location_ids.append(new_location.id) - logger.info(f"Created new location {new_location.id} for organization {org_id}") - logger.info(f"Total root location IDs: {root_location_ids}") if not root_location_ids: personal_org_title = f"{user.username}'s Organization" @@ -281,7 +261,6 @@ async def _update_user_root_locations( await session.flush() root_location_ids = [personal_location.id] - logger.info(f"Using personal location: {personal_location.id}") await session.execute( delete(user_root_locations).where( @@ -300,7 +279,5 @@ async def _update_user_root_locations( index_elements=["user_id", "location_id"], ) await session.execute(stmt) - logger.info(f"Inserted {len(root_location_ids)} root locations for user {user.id}: {root_location_ids}") await session.commit() - logger.info(f"Root locations update completed for user {user.id}") diff --git a/web/api/gql/generated.ts b/web/api/gql/generated.ts index 9dc0445..f6d0c5f 100644 --- a/web/api/gql/generated.ts +++ b/web/api/gql/generated.ts @@ -18,6 +18,12 @@ export type Scalars = { DateTime: { input: any; output: any; } }; +export type CreateLocationNodeInput = { + kind: LocationType; + parentId?: InputMaybe; + title: Scalars['String']['input']; +}; + export type CreatePatientInput = { assignedLocationId?: InputMaybe; assignedLocationIds?: InputMaybe>; @@ -90,9 +96,11 @@ export type Mutation = { admitPatient: PatientType; assignTask: TaskType; completeTask: TaskType; + createLocationNode: LocationNodeType; createPatient: PatientType; createPropertyDefinition: PropertyDefinitionType; createTask: TaskType; + deleteLocationNode: Scalars['Boolean']['output']; deletePatient: Scalars['Boolean']['output']; deletePropertyDefinition: Scalars['Boolean']['output']; deleteTask: Scalars['Boolean']['output']; @@ -100,6 +108,7 @@ export type Mutation = { markPatientDead: PatientType; reopenTask: TaskType; unassignTask: TaskType; + updateLocationNode: LocationNodeType; updatePatient: PatientType; updatePropertyDefinition: PropertyDefinitionType; updateTask: TaskType; @@ -123,6 +132,11 @@ export type MutationCompleteTaskArgs = { }; +export type MutationCreateLocationNodeArgs = { + data: CreateLocationNodeInput; +}; + + export type MutationCreatePatientArgs = { data: CreatePatientInput; }; @@ -138,6 +152,11 @@ export type MutationCreateTaskArgs = { }; +export type MutationDeleteLocationNodeArgs = { + id: Scalars['ID']['input']; +}; + + export type MutationDeletePatientArgs = { id: Scalars['ID']['input']; }; @@ -173,6 +192,12 @@ export type MutationUnassignTaskArgs = { }; +export type MutationUpdateLocationNodeArgs = { + data: UpdateLocationNodeInput; + id: Scalars['ID']['input']; +}; + + export type MutationUpdatePatientArgs = { data: UpdatePatientInput; id: Scalars['ID']['input']; @@ -347,6 +372,9 @@ export enum Sex { export type Subscription = { __typename?: 'Subscription'; + locationNodeCreated: Scalars['ID']['output']; + locationNodeDeleted: Scalars['ID']['output']; + locationNodeUpdated: Scalars['ID']['output']; patientCreated: Scalars['ID']['output']; patientStateChanged: Scalars['ID']['output']; patientUpdated: Scalars['ID']['output']; @@ -356,6 +384,11 @@ export type Subscription = { }; +export type SubscriptionLocationNodeUpdatedArgs = { + locationId?: InputMaybe; +}; + + export type SubscriptionPatientStateChangedArgs = { patientId?: InputMaybe; }; @@ -387,6 +420,12 @@ export type TaskType = { updateDate?: Maybe; }; +export type UpdateLocationNodeInput = { + kind?: InputMaybe; + parentId?: InputMaybe; + title?: InputMaybe; +}; + export type UpdatePatientInput = { assignedLocationId?: InputMaybe; assignedLocationIds?: InputMaybe>; @@ -493,10 +532,12 @@ export type GetUsersQueryVariables = Exact<{ [key: string]: never; }>; export type GetUsersQuery = { __typename?: 'Query', users: Array<{ __typename?: 'UserType', id: string, name: string, avatarUrl?: string | null }> }; -export type GetGlobalDataQueryVariables = Exact<{ [key: string]: never; }>; +export type GetGlobalDataQueryVariables = Exact<{ + rootLocationIds?: InputMaybe | Scalars['ID']['input']>; +}>; -export type GetGlobalDataQuery = { __typename?: 'Query', me?: { __typename?: 'UserType', id: string, username: string, name: string, firstname?: string | null, lastname?: string | null, avatarUrl?: string | null, organizations?: string | null, rootLocations: Array<{ __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType }>, tasks: Array<{ __typename?: 'TaskType', id: string, done: boolean }> } | null, wards: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, teams: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, clinics: Array<{ __typename?: 'LocationNodeType', id: string, title: string }>, patients: Array<{ __typename?: 'PatientType', id: string, state: PatientState, assignedLocation?: { __typename?: 'LocationNodeType', id: string } | null }>, waitingPatients: Array<{ __typename?: 'PatientType', id: string, state: PatientState }> }; +export type GetGlobalDataQuery = { __typename?: 'Query', me?: { __typename?: 'UserType', id: string, username: string, name: string, firstname?: string | null, lastname?: string | null, avatarUrl?: string | null, organizations?: string | null, rootLocations: Array<{ __typename?: 'LocationNodeType', id: string, title: string, kind: LocationType }>, tasks: Array<{ __typename?: 'TaskType', id: string, done: boolean }> } | null, wards: Array<{ __typename?: 'LocationNodeType', id: string, title: string, parentId?: string | null }>, teams: Array<{ __typename?: 'LocationNodeType', id: string, title: string, parentId?: string | null }>, clinics: Array<{ __typename?: 'LocationNodeType', id: string, title: string, parentId?: string | null }>, patients: Array<{ __typename?: 'PatientType', id: string, state: PatientState, assignedLocation?: { __typename?: 'LocationNodeType', id: string } | null }>, waitingPatients: Array<{ __typename?: 'PatientType', id: string, state: PatientState }> }; export type CreatePatientMutationVariables = Exact<{ data: CreatePatientInput; @@ -612,6 +653,23 @@ export type TaskDeletedSubscriptionVariables = Exact<{ [key: string]: never; }>; export type TaskDeletedSubscription = { __typename?: 'Subscription', taskDeleted: string }; +export type LocationNodeUpdatedSubscriptionVariables = Exact<{ + locationId?: InputMaybe; +}>; + + +export type LocationNodeUpdatedSubscription = { __typename?: 'Subscription', locationNodeUpdated: string }; + +export type LocationNodeCreatedSubscriptionVariables = Exact<{ [key: string]: never; }>; + + +export type LocationNodeCreatedSubscription = { __typename?: 'Subscription', locationNodeCreated: string }; + +export type LocationNodeDeletedSubscriptionVariables = Exact<{ [key: string]: never; }>; + + +export type LocationNodeDeletedSubscription = { __typename?: 'Subscription', locationNodeDeleted: string }; + export type CreateTaskMutationVariables = Exact<{ data: CreateTaskInput; }>; @@ -1287,7 +1345,7 @@ export const useGetUsersQuery = < )}; export const GetGlobalDataDocument = ` - query GetGlobalData { + query GetGlobalData($rootLocationIds: [ID!]) { me { id username @@ -1309,23 +1367,26 @@ export const GetGlobalDataDocument = ` wards: locationNodes(kind: WARD) { id title + parentId } teams: locationNodes(kind: TEAM) { id title + parentId } clinics: locationNodes(kind: CLINIC) { id title + parentId } - patients { + patients(rootLocationIds: $rootLocationIds) { id state assignedLocation { id } } - waitingPatients: patients(states: [WAIT]) { + waitingPatients: patients(states: [WAIT], rootLocationIds: $rootLocationIds) { id state } @@ -1700,6 +1761,21 @@ export const TaskDeletedDocument = ` taskDeleted } `; +export const LocationNodeUpdatedDocument = ` + subscription LocationNodeUpdated($locationId: ID) { + locationNodeUpdated(locationId: $locationId) +} + `; +export const LocationNodeCreatedDocument = ` + subscription LocationNodeCreated { + locationNodeCreated +} + `; +export const LocationNodeDeletedDocument = ` + subscription LocationNodeDeleted { + locationNodeDeleted +} + `; export const CreateTaskDocument = ` mutation CreateTask($data: CreateTaskInput!) { createTask(data: $data) { diff --git a/web/api/graphql/GlobalData.graphql b/web/api/graphql/GlobalData.graphql index 20b7a73..319978f 100644 --- a/web/api/graphql/GlobalData.graphql +++ b/web/api/graphql/GlobalData.graphql @@ -1,4 +1,4 @@ -query GetGlobalData { +query GetGlobalData($rootLocationIds: [ID!]) { me { id username @@ -20,23 +20,26 @@ query GetGlobalData { wards: locationNodes(kind: WARD) { id title + parentId } teams: locationNodes(kind: TEAM) { id title + parentId } clinics: locationNodes(kind: CLINIC) { id title + parentId } - patients { + patients(rootLocationIds: $rootLocationIds) { id state assignedLocation { id } } - waitingPatients: patients(states: [WAIT]) { + waitingPatients: patients(states: [WAIT], rootLocationIds: $rootLocationIds) { id state } diff --git a/web/components/layout/Page.tsx b/web/components/layout/Page.tsx index 343db63..4132741 100644 --- a/web/components/layout/Page.tsx +++ b/web/components/layout/Page.tsx @@ -33,9 +33,9 @@ import { Notifications } from '@/components/Notifications' import { TasksLogo } from '@/components/TasksLogo' import { useRouter } from 'next/router' import { useTasksContext } from '@/hooks/useTasksContext' +import { useGetLocationsQuery } from '@/api/gql/generated' import { hashString } from '@/utils/hash' import { useSwipeGesture } from '@/hooks/useSwipeGesture' -import { LocationChips } from '@/components/patients/LocationChips' import { LocationSelectionDialog } from '@/components/locations/LocationSelectionDialog' export const StagingDisclaimerDialog = () => { @@ -209,44 +209,64 @@ export const Header = ({ onMenuClick, isMenuOpen, ...props }: HeaderProps) => { const { user, rootLocations, selectedRootLocationIds, update } = useTasksContext() const translation = useTasksTranslation() const [isLocationPickerOpen, setIsLocationPickerOpen] = useState(false) - const [selectedLocationsCache, setSelectedLocationsCache] = useState>([]) + const [selectedLocationsCache, setSelectedLocationsCache] = useState>([]) + + const { data: locationsData } = useGetLocationsQuery( + {}, + { + enabled: !!selectedRootLocationIds && selectedRootLocationIds.length > 0, + refetchInterval: 30000, + refetchOnWindowFocus: true, + } + ) - // Update cache when rootLocations change and contain the selected IDs useEffect(() => { - if (rootLocations && selectedRootLocationIds && selectedRootLocationIds.length > 0) { - const foundInRoot = rootLocations.filter(loc => selectedRootLocationIds.includes(loc.id)) - if (foundInRoot.length > 0) { - // Update cache with locations from rootLocations if they match - const cacheIds = new Set(selectedLocationsCache.map(loc => loc.id)) - const selectedIds = new Set(selectedRootLocationIds) - if (cacheIds.size === selectedIds.size && Array.from(cacheIds).every(id => selectedIds.has(id))) { - // Cache matches, update it with rootLocations data - setSelectedLocationsCache(foundInRoot.map(loc => ({ id: loc.id, title: loc.title, kind: loc.kind }))) - } else if (selectedLocationsCache.length === 0) { - // No cache, use rootLocations - setSelectedLocationsCache(foundInRoot.map(loc => ({ id: loc.id, title: loc.title, kind: loc.kind }))) + if (selectedRootLocationIds && selectedRootLocationIds.length > 0) { + const foundInRoot = rootLocations?.filter(loc => selectedRootLocationIds.includes(loc.id)) || [] + + if (foundInRoot.length === selectedRootLocationIds.length) { + setSelectedLocationsCache(foundInRoot.map(loc => ({ id: loc.id, title: loc.title, kind: loc.kind }))) + } else if (locationsData?.locationNodes) { + const allLocations = locationsData.locationNodes + const foundLocations: Array<{ id: string, title: string, kind?: string }> = [] + for (const id of selectedRootLocationIds) { + const inRoot = rootLocations?.find(loc => loc.id === id) + if (inRoot) { + foundLocations.push({ id: inRoot.id, title: inRoot.title, kind: inRoot.kind }) + } else { + const inAll = allLocations.find(loc => loc.id === id) + if (inAll) { + foundLocations.push({ id: inAll.id, title: inAll.title, kind: inAll.kind }) + } + } + } + + if (foundLocations.length > 0) { + setSelectedLocationsCache(foundLocations) } + } else if (foundInRoot.length > 0) { + setSelectedLocationsCache(foundInRoot.map(loc => ({ id: loc.id, title: loc.title, kind: loc.kind }))) } - } else if (!selectedRootLocationIds || selectedRootLocationIds.length === 0) { - // Clear cache when selection is cleared + } else { setSelectedLocationsCache([]) } - }, [rootLocations, selectedRootLocationIds]) + }, [rootLocations, selectedRootLocationIds, locationsData]) - // Use cached locations if available, otherwise fall back to rootLocations - const selectedRootLocations = selectedLocationsCache.length > 0 - ? selectedLocationsCache + const selectedRootLocations = selectedLocationsCache.length > 0 + ? selectedLocationsCache : (rootLocations?.filter(loc => selectedRootLocationIds?.includes(loc.id)) || []) const firstSelectedRootLocation = selectedRootLocations[0] - const handleRootLocationSelect = (locations: Array<{ id: string; title: string; kind?: string }>) => { + const handleRootLocationSelect = (locations: Array<{ id: string, title: string, kind?: string }>) => { if (locations.length === 0) return - // Cache the selected locations so we can display them even if they're not in rootLocations + const locationIds = locations.map(loc => loc.id) setSelectedLocationsCache(locations) - update(prevState => ({ - ...prevState, - selectedRootLocationIds: locations.map(loc => loc.id), - })) + update(prevState => { + return { + ...prevState, + selectedRootLocationIds: locationIds, + } + }) setIsLocationPickerOpen(false) } return ( @@ -281,8 +301,8 @@ export const Header = ({ onMenuClick, isMenuOpen, ...props }: HeaderProps) => { ? selectedRootLocations.length === 1 ? firstSelectedRootLocation?.title : selectedRootLocations.length === 2 - ? `${selectedRootLocations[0].title}, ${selectedRootLocations[1].title}` - : `${selectedRootLocations[0].title} +${selectedRootLocations.length - 1}` + ? `${selectedRootLocations[0]?.title ?? ''}, ${selectedRootLocations[1]?.title ?? ''}` + : `${selectedRootLocations[0]?.title ?? ''} +${selectedRootLocations.length - 1}` : translation('selectLocation') || 'Select Location'} { {context?.totalPatientsCount !== undefined && ({context.totalPatientsCount})} - - - {translation('teams')} -
- )} - headerClassName="!px-2.5 !py-1.5 hover:bg-black/30" - contentClassName="!px-0 !pb-0 gap-y-0" - className="!shadow-none" - isExpanded={context.sidebar.isShowingTeams} - onChange={isExpanded => context.update(prevState => ({ - ...prevState, - sidebar: { - ...prevState.sidebar, - isShowingTeams: isExpanded, - } - }))} - > - {!context?.teams ? ( - - ) : context.teams.map(team => ( - - {team.title} - - ))} - - - - - {translation('wards')} -
- )} - headerClassName="!px-2.5 !py-1.5 hover:bg-black/30" - contentClassName="!px-0 !pb-0 gap-y-0" - className="!shadow-none" - isExpanded={context.sidebar.isShowingWards} - onChange={isExpanded => context.update(prevState => ({ - ...prevState, - sidebar: { - ...prevState.sidebar, - isShowingWards: isExpanded, - } - }))} - > - {!context?.wards ? ( - - ) : context.wards.map(ward => ( - - {ward.title} - - ))} - - - - - {translation('clinics')} -
- )} - headerClassName="!px-2.5 !py-1.5 hover:bg-black/30" - contentClassName="!px-0 !pb-0 gap-y-0" - className="!shadow-none" - isExpanded={context.sidebar.isShowingClinics} - onChange={isExpanded => context.update(prevState => ({ - ...prevState, - sidebar: { - ...prevState.sidebar, - isShowingClinics: isExpanded, - } - }))} - > - {!context?.clinics ? ( - - ) : context.clinics.map(clinic => ( - - {clinic.title} - - ))} - + {context?.teams && context.teams.length > 0 && ( + + + {translation('teams')} + + )} + headerClassName="!px-2.5 !py-1.5 hover:bg-black/30" + contentClassName="!px-0 !pb-0 gap-y-0" + className="!shadow-none" + isExpanded={context.sidebar.isShowingTeams} + onChange={isExpanded => context.update(prevState => ({ + ...prevState, + sidebar: { + ...prevState.sidebar, + isShowingTeams: isExpanded, + } + }))} + > + {context.teams.map(team => ( + + {team.title} + + ))} + + )} + + {context?.wards && context.wards.length > 0 && ( + + + {translation('wards')} + + )} + headerClassName="!px-2.5 !py-1.5 hover:bg-black/30" + contentClassName="!px-0 !pb-0 gap-y-0" + className="!shadow-none" + isExpanded={context.sidebar.isShowingWards} + onChange={isExpanded => context.update(prevState => ({ + ...prevState, + sidebar: { + ...prevState.sidebar, + isShowingWards: isExpanded, + } + }))} + > + {context.wards.map(ward => ( + + {ward.title} + + ))} + + )} + + {context?.clinics && context.clinics.length > 0 && ( + + + {translation('clinics')} + + )} + headerClassName="!px-2.5 !py-1.5 hover:bg-black/30" + contentClassName="!px-0 !pb-0 gap-y-0" + className="!shadow-none" + isExpanded={context.sidebar.isShowingClinics} + onChange={isExpanded => context.update(prevState => ({ + ...prevState, + sidebar: { + ...prevState.sidebar, + isShowingClinics: isExpanded, + } + }))} + > + {context.clinics.map(clinic => ( + + {clinic.title} + + ))} + + )} diff --git a/web/components/locations/LocationSelectionDialog.tsx b/web/components/locations/LocationSelectionDialog.tsx index b73524e..ae2fbf7 100644 --- a/web/components/locations/LocationSelectionDialog.tsx +++ b/web/components/locations/LocationSelectionDialog.tsx @@ -179,7 +179,6 @@ export const LocationSelectionDialog = ({ const hasInitialized = useRef(false) - // Helper function to get all descendant IDs of a node (recursively) const getAllDescendantIds = useMemo(() => { if (!data?.locationNodes) return () => new Set() const nodes = data.locationNodes as LocationNodeType[] @@ -201,8 +200,7 @@ export const LocationSelectionDialog = ({ } }, [data?.locationNodes]) - // Helper function to get all ancestor IDs of a node (recursively) - const getAllAncestorIds = useMemo(() => { + const _getAllAncestorIds = useMemo(() => { if (!data?.locationNodes) return () => new Set() const nodes = data.locationNodes as LocationNodeType[] @@ -212,75 +210,29 @@ export const LocationSelectionDialog = ({ while (current?.parentId) { ancestors.add(current.parentId) - current = nodes.find(n => n.id === current.parentId) + const parentId = current.parentId + const parent = nodes.find(n => n.id === parentId) + if (!parent) break + current = parent } return ancestors } }, [data?.locationNodes]) - // Simplify selection: prefer children over parents (most specific selection wins) - const simplifySelection = useMemo(() => { - if (!data?.locationNodes || useCase !== 'root') { - return (ids: string[]): string[] => ids - } - const nodes = data.locationNodes as LocationNodeType[] - - return (ids: string[]): string[] => { - if (ids.length === 0) return ids - - const idSet = new Set(ids) - const simplified = new Set() - - // First pass: prefer children over parents - // If both a parent and child are selected, keep only the child - for (const id of ids) { - let current: LocationNodeType | undefined = nodes.find(n => n.id === id) - let hasAncestorSelected = false - - // Check if any ancestor is also selected - while (current?.parentId) { - if (idSet.has(current.parentId)) { - hasAncestorSelected = true - break - } - current = nodes.find(n => n.id === current?.parentId) - } - - // Only add if no ancestor is selected (child wins over parent) - if (!hasAncestorSelected) { - simplified.add(id) - } - } - - // Second pass: remove descendants of selected nodes (parent includes children) - const finalSet = new Set() - for (const id of simplified) { - // Check if this node has any descendants in the simplified set - const descendants = getAllDescendantIds(id) - const hasDescendantInSimplified = Array.from(descendants).some(descId => simplified.has(descId)) - - // Only add if no descendant is in simplified set (child wins over parent) - if (!hasDescendantInSimplified) { - finalSet.add(id) - } - } - - return Array.from(finalSet) - } - }, [data?.locationNodes, useCase, getAllDescendantIds]) + const _simplifySelection = useMemo(() => { + return (ids: string[]): string[] => ids + }, []) useEffect(() => { if (isOpen) { - // Simplify initial selection when dialog opens - const simplifiedIds = simplifySelection(initialSelectedIds) - setSelectedIds(new Set(simplifiedIds)) + setSelectedIds(new Set(initialSelectedIds)) setExpandedIds(new Set()) hasInitialized.current = true } else { hasInitialized.current = false } - }, [isOpen, initialSelectedIds, simplifySelection]) + }, [isOpen, initialSelectedIds]) const matchesFilter = useMemo(() => { if (useCase === 'default') { @@ -288,7 +240,6 @@ export const LocationSelectionDialog = ({ } if (useCase === 'root') { - // Only hospitals, practices, clinics, and teams are selectable for root locations const allowedKinds = new Set([ LocationType.Hospital, LocationType.Practice, @@ -427,22 +378,11 @@ export const LocationSelectionDialog = ({ const handleToggleSelect = (node: LocationNodeType, checked: boolean) => { const newSet = new Set(selectedIds) if (checked) { - // For clinic useCase, enforce exactly one selection (clear all first) if (useCase === 'clinic' || !multiSelect) { newSet.clear() } - - // Simplification logic: only for root useCase - if (useCase === 'root') { - // Remove all descendants of this node (parent includes children) - const descendants = getAllDescendantIds(node.id) - descendants.forEach(descId => newSet.delete(descId)) - - // Remove all ancestors of this node (child replaces parent) - const ancestors = getAllAncestorIds(node.id) - ancestors.forEach(ancId => newSet.delete(ancId)) - } - + + newSet.add(node.id) } else { newSet.delete(node.id) @@ -474,7 +414,9 @@ export const LocationSelectionDialog = ({ if (!data?.locationNodes) return if (selectedIds.size === 0) return const nodes = data.locationNodes as LocationNodeType[] - const selectedNodes = nodes.filter(n => selectedIds.has(n.id)) + + const finalSelectedIds = Array.from(selectedIds) + const selectedNodes = nodes.filter(n => finalSelectedIds.includes(n.id)) onSelect(selectedNodes) onClose() } @@ -499,6 +441,7 @@ export const LocationSelectionDialog = ({ {useCase === 'clinic' ? translation('pickClinic') : useCase === 'position' ? translation('pickPosition') : useCase === 'teams' ? translation('pickTeams') : + useCase === 'root' ? translation('selectRootLocation') : translation('selectLocation')} )} @@ -506,6 +449,7 @@ export const LocationSelectionDialog = ({ useCase === 'clinic' ? translation('pickClinicDescription') : useCase === 'position' ? translation('pickPositionDescription') : useCase === 'teams' ? translation('pickTeamsDescription') : + useCase === 'root' ? translation('selectRootLocationDescription') : translation('selectLocationDescription') } className="w-[600px] h-[80vh] flex flex-col max-w-full" @@ -531,7 +475,7 @@ export const LocationSelectionDialog = ({ - {multiSelect && ( + {multiSelect && useCase !== 'root' && (