diff --git a/packages/api-server/api_server/models/tortoise_models/__init__.py b/packages/api-server/api_server/models/tortoise_models/__init__.py index 0028a5e78..384dd978d 100644 --- a/packages/api-server/api_server/models/tortoise_models/__init__.py +++ b/packages/api-server/api_server/models/tortoise_models/__init__.py @@ -26,6 +26,7 @@ TaskEventLogPhasesEventsLog, TaskEventLogPhasesLog, TaskFavorite, + TaskLabel, TaskRequest, TaskState, ) diff --git a/packages/api-server/api_server/models/tortoise_models/tasks.py b/packages/api-server/api_server/models/tortoise_models/tasks.py index a3c28394f..3111fc95a 100644 --- a/packages/api-server/api_server/models/tortoise_models/tasks.py +++ b/packages/api-server/api_server/models/tortoise_models/tasks.py @@ -1,4 +1,5 @@ from tortoise.fields import ( + BigIntField, CharField, DatetimeField, ForeignKeyField, @@ -29,6 +30,14 @@ class TaskState(Model): unix_millis_warn_time = DatetimeField(null=True, index=True) pickup = CharField(255, null=True, index=True) destination = CharField(255, null=True, index=True) + labels = ReverseRelation["TaskLabel"] + + +class TaskLabel(Model): + state = ForeignKeyField("models.TaskState", null=True, related_name="labels") + label_name = CharField(255, null=False, index=True) + label_value_str = CharField(255, null=True, index=True) + label_value_num = BigIntField(null=True, index=True) class TaskEventLog(Model): diff --git a/packages/api-server/api_server/repositories/tasks.py b/packages/api-server/api_server/repositories/tasks.py index 3b3872e25..0bb993788 100644 --- a/packages/api-server/api_server/repositories/tasks.py +++ b/packages/api-server/api_server/repositories/tasks.py @@ -59,15 +59,6 @@ async def query_task_requests(self, task_ids: List[str]) -> List[DbTaskRequest]: raise HTTPException(422, str(e)) from e async def save_task_state(self, task_state: TaskState) -> None: - labels = task_state.booking.labels - booking_label = None - if labels is not None: - for l in labels: - validated_booking_label = TaskBookingLabel.from_json_string(l) - if validated_booking_label is not None: - booking_label = validated_booking_label - break - task_state_dict = { "data": task_state.json(), "category": task_state.category.__root__ if task_state.category else None, @@ -86,23 +77,10 @@ async def save_task_state(self, task_state: TaskState) -> None: "requester": task_state.booking.requester if task_state.booking.requester else None, - "pickup": booking_label.description.pickup - if booking_label is not None - and booking_label.description.pickup is not None - else None, - "destination": booking_label.description.destination - if booking_label is not None - and booking_label.description.destination is not None - else None, } - if task_state.unix_millis_warn_time is not None: - task_state_dict["unix_millis_warn_time"] = datetime.fromtimestamp( - task_state.unix_millis_warn_time / 1000 - ) - try: - await ttm.TaskState.update_or_create( + state, created = await ttm.TaskState.update_or_create( task_state_dict, id_=task_state.booking.id ) except Exception as e: # pylint: disable=W0703 @@ -119,6 +97,43 @@ async def save_task_state(self, task_state: TaskState) -> None: self.logger.error( f"Failed to save task state of id [{task_state.booking.id}] [{e}]" ) + return + + if not created: + return + + # Save the labels that we want + labels = task_state.booking.labels + booking_label = None + if labels is not None: + for l in labels: + validated_booking_label = TaskBookingLabel.from_json_string(l) + if validated_booking_label is not None: + booking_label = validated_booking_label + break + if booking_label is None: + return + + # Here we generate the labels required for server-side sorting and + # filtering. + if booking_label.description.pickup is not None: + await ttm.TaskLabel.create( + state=state, + label_name="pickup", + label_value_str=booking_label.description.pickup, + ) + if booking_label.description.destination is not None: + await ttm.TaskLabel.create( + state=state, + label_name="destination", + label_value_str=booking_label.description.destination, + ) + if booking_label.description.unix_millis_warn_time is not None: + await ttm.TaskLabel.create( + state=state, + label_name="unix_millis_warn_time", + label_value_num=booking_label.description.unix_millis_warn_time, + ) async def query_task_states( self, query: QuerySet[DbTaskState], pagination: Optional[Pagination] = None diff --git a/packages/api-server/api_server/routes/tasks/tasks.py b/packages/api-server/api_server/routes/tasks/tasks.py index b9549382d..2a8030c91 100644 --- a/packages/api-server/api_server/routes/tasks/tasks.py +++ b/packages/api-server/api_server/routes/tasks/tasks.py @@ -126,10 +126,6 @@ async def query_task_states( filters["unix_millis_request_time__lte"] = request_time_between[1] if requester is not None: filters["requester__in"] = requester.split(",") - if pickup is not None: - filters["pickup__in"] = pickup.split(",") - if destination is not None: - filters["destination__in"] = destination.split(",") if assigned_to is not None: filters["assigned_to__in"] = assigned_to.split(",") if start_time_between is not None: @@ -146,6 +142,31 @@ async def query_task_states( continue filters["status__in"].append(mdl.Status(status_string)) + # NOTE: in order to perform filtering based on the values in labels, a + # filter on the label_name will need to be applied as well as a filter on + # the label_value. + if pickup is not None: + filters["labels__label_name"] = "pickup" + filters["labels__label_value_str__in"] = pickup.split(",") + if destination is not None: + filters["labels__label_name"] = "destination" + filters["labels__label_value_str__in"] = destination.split(",") + + # NOTE: In order to perform sorting based on the values in labels, a filter + # on the label_name has to be performed first. A side-effect of this would + # be that states that do not contain this field will not be returned. + if pagination.order_by is not None: + labels_fields = ["pickup", "destination"] + new_order = pagination.order_by + for field in labels_fields: + if field in pagination.order_by: + filters["labels__label_name"] = field + new_order = pagination.order_by.replace( + field, "labels__label_value_str" + ) + break + pagination.order_by = new_order + return await task_repo.query_task_states(DbTaskState.filter(**filters), pagination) diff --git a/packages/api-server/migrations/migrate_db_912.py b/packages/api-server/migrations/migrate_db_912.py index 1722ce1ae..968196b7a 100644 --- a/packages/api-server/migrations/migrate_db_912.py +++ b/packages/api-server/migrations/migrate_db_912.py @@ -18,11 +18,17 @@ # Before migration: # - Pickup, destination, cart ID, task definition id information will be # unavailable on the Task Queue Table on the dashboard, as we no longer gather -# those fields from the TaskRequest +# those fields from the TaskRequest. +# - TaskState database model contains optional CharFields for pickup and +# destination, to facilitate server-side sorting and filtering. # After migration: # - Dashboard will behave the same as before #912, however it is no longer # dependent on TaskRequest to fill out those fields. It gathers those fields # from the json string in TaskState.booking.labels. +# - In the database, we create a new generic key-value pair model, that allow +# us to encode all this information and tie them to a task state, and be used +# for sorting and filtering, using reverse relations, as opposed to fully +# filled out columns for TaskState. # This script performs the following: # - Construct TaskBookingLabel from its TaskRequest if it is available. # - Update the respective TaskState.data json TaskState.booking.labels field @@ -206,12 +212,17 @@ async def migrate(): ) # print(state_model) + if pickup is not None: + await ttm.TaskLabel.create( + state.state, label_name="pickup", label_value_str=pickup + ) + if destination is not None: + await ttm.TaskLabel.create( + state.state, label_name="destination", label_value_str=destination + ) + state.update_from_dict( - { - "data": state_model.json(exclude_none=True, separators=(",", ":")), - "pickup": pickup, - "destination": destination, - } + {"data": state_model.json(exclude_none=True, separators=(",", ":"))} ) await state.save()