From c2594b22476d9fd7fabc9dcd9d328d9b492ab4ce Mon Sep 17 00:00:00 2001 From: Teo Koon Peng Date: Tue, 28 May 2024 15:56:01 +0800 Subject: [PATCH] document problems with tortoise-orm Signed-off-by: Teo Koon Peng --- packages/api-server/api_server/query.py | 3 +- .../api_server/routes/tasks/tasks.py | 32 +++++++++++++++++-- .../api_server/routes/tasks/test_tasks.py | 31 ++++++++++++++++-- .../api-server/api_server/test/test_data.py | 16 ++++++---- 4 files changed, 70 insertions(+), 12 deletions(-) diff --git a/packages/api-server/api_server/query.py b/packages/api-server/api_server/query.py index b04c3d565..8ae4347d8 100644 --- a/packages/api-server/api_server/query.py +++ b/packages/api-server/api_server/query.py @@ -8,7 +8,7 @@ def add_pagination( query: QuerySet[MODEL], pagination: Pagination, - field_mappings: Optional[Dict[str, str]] = None, + field_mappings: Dict[str, str] = {}, ) -> QuerySet[MODEL]: """ Adds pagination and ordering to a query. @@ -17,7 +17,6 @@ def add_pagination( query. e.g. a url of `?order_by=order_field` and a field mapping of `{"order_field": "db_field"}` will order the query result according to `db_field`. """ - field_mappings = field_mappings or {} query = query.limit(pagination.limit).offset(pagination.offset) if pagination.order_by is not None: order_fields = [] diff --git a/packages/api-server/api_server/routes/tasks/tasks.py b/packages/api-server/api_server/routes/tasks/tasks.py index 2a8030c91..1704ce312 100644 --- a/packages/api-server/api_server/routes/tasks/tasks.py +++ b/packages/api-server/api_server/routes/tasks/tasks.py @@ -3,6 +3,8 @@ from fastapi import Body, Depends, HTTPException, Path, Query from rx import operators as rxops +from tortoise.expressions import Case, F, Q, RawSQL, Subquery, When +from tortoise.functions import Max from api_server import models as mdl from api_server.dependencies import ( @@ -15,6 +17,7 @@ ) from api_server.fast_io import FastIORouter, SubscriptionRequest from api_server.logging import LoggerAdapter, get_logger +from api_server.models.tortoise_models import TaskLabel as DbTaskLabel from api_server.models.tortoise_models import TaskState as DbTaskState from api_server.repositories import RmfRepository, TaskRepository from api_server.response import RawJSONResponse @@ -99,10 +102,10 @@ async def query_task_states( None, description="comma separated list of requester names" ), pickup: Optional[str] = Query( - None, description="comma separated list of pickup names" + None, description="comma separated list of pickup names", deprecated=True ), destination: Optional[str] = Query( - None, description="comma separated list of destination names" + None, description="comma separated list of destination names", deprecated=True ), assigned_to: Optional[str] = Query( None, description="comma separated list of assigned robot names" @@ -116,6 +119,10 @@ async def query_task_states( status: Optional[str] = Query(None, description="comma separated list of statuses"), pagination: mdl.Pagination = Depends(pagination_query), ): + """ + Note that sorting by `pickup` and `destination` is mutually exclusive and sorting + by either of them will filter only tasks which has those labels. + """ filters = {} if task_id is not None: filters["id___in"] = task_id.split(",") @@ -155,6 +162,27 @@ async def query_task_states( # NOTE: In order to perform sorting based on the values in labels, a filter # on the label_name has to be performed first. A side-effect of this would # be that states that do not contain this field will not be returned. + # + # tortoise-orm lacks too many features to implement a proper sort logic, for + # reference, these are some solutions in sql + # + # Solution 1 (can't do multiple joins): + # SELECT t.* + # FROM tasks t + # LEFT JOIN tasklabels tl_foo ON t.task_id = tl_foo.task_id AND tl_foo.label_name = 'foo' + # LEFT JOIN tasklabels tl_bar ON t.task_id = tl_bar.task_id AND tl_bar.label_name = 'bar' + # ORDER BY + # tl_foo.label_value, -- Primary sort by 'foo' label + # tl_bar.label_value; -- Secondary sort by 'bar' label + # + # Solution 2 (can't do MAX on CASE WHEN): + # SELECT t.task_id, + # MAX(CASE WHEN tl.label_name = 'foo' THEN tl.label_value END) AS foo_value, + # MAX(CASE WHEN tl.label_name = 'bar' THEN tl.label_value END) AS bar_value + # FROM tasks t + # LEFT JOIN tasklabels tl ON t.task_id = tl.task_id + # GROUP BY t.task_id + # ORDER BY foo_value, bar_value; if pagination.order_by is not None: labels_fields = ["pickup", "destination"] new_order = pagination.order_by diff --git a/packages/api-server/api_server/routes/tasks/test_tasks.py b/packages/api-server/api_server/routes/tasks/test_tasks.py index d15d4da0f..fda0a5dcf 100644 --- a/packages/api-server/api_server/routes/tasks/test_tasks.py +++ b/packages/api-server/api_server/routes/tasks/test_tasks.py @@ -15,8 +15,17 @@ class TestTasksRoute(AppFixture): @classmethod def setUpClass(cls): super().setUpClass() - task_ids = [uuid4()] - cls.task_states = [make_task_state(task_id=f"test_{x}") for x in task_ids] + booking_labels_2 = make_task_booking_label() + booking_labels_2.description["pickup"] = "AAA" + booking_labels_2.description["destination"] = "BBB" + task_ids = [uuid4(), uuid4()] + cls.task_states = [ + make_task_state(task_id=f"test_{task_ids[0]}"), + make_task_state( + task_id=f"test_{task_ids[1]}", + booking_labels=[booking_labels_2.json()], + ), + ] cls.task_logs = [make_task_log(task_id=f"test_{x}") for x in task_ids] with cls.client.websocket_connect("/_internal") as ws: @@ -60,6 +69,24 @@ def test_query_task_states_by_label(self): self.task_states[0].booking.id, results[0]["booking"]["id"] ) + # FIXME(koonpeng): This does not work because of tortoise-orm limitations + # def test_query_task_states_sort_by_label(self): + # """Checks that sorting by `pickup` for `destination` does not filter out tasks""" + # test_cases = { + # "pickup": "Kitchen", + # "destination": "room_203", + # } + # for k, v in test_cases.items(): + # resp = self.client.get( + # f"/tasks?task_id={self.task_states[0].booking.id}&order_by=pickup" + # ) + # self.assertEqual(200, resp.status_code) + # results = resp.json() + # self.assertEqual(2, len(results)) + # self.assertEqual( + # self.task_states[1].booking.id, results[0]["booking"]["id"] + # ) + def test_sub_task_state(self): task_id = self.task_states[0].booking.id gen = self.subscribe_sio(f"/tasks/{task_id}/state") diff --git a/packages/api-server/api_server/test/test_data.py b/packages/api-server/api_server/test/test_data.py index 4e43386c4..3a1970c5c 100644 --- a/packages/api-server/api_server/test/test_data.py +++ b/packages/api-server/api_server/test/test_data.py @@ -143,7 +143,10 @@ def make_task_booking_label() -> TaskBookingLabel: ) -def make_task_state(task_id: str = "test_task") -> TaskState: +def make_task_state( + task_id: str = "test_task", + booking_labels: list[str] | None = None, +) -> TaskState: # from https://raw.githubusercontent.com/open-rmf/rmf_api_msgs/960b286d9849fc716a3043b8e1f5fb341bdf5778/rmf_api_msgs/samples/task_state/multi_dropoff_delivery.json sample_task = json.loads( """ @@ -443,11 +446,12 @@ def make_task_state(task_id: str = "test_task") -> TaskState: ) sample_task["booking"]["id"] = task_id - booking_labels = [ - "dummy_label_1", - "dummy_label_2", - make_task_booking_label().json(), - ] + if booking_labels is None: + booking_labels = [ + "dummy_label_1", + "dummy_label_2", + make_task_booking_label().json(), + ] sample_task["booking"]["labels"] = booking_labels return TaskState(**sample_task)