Skip to content

Commit

Permalink
document problems with tortoise-orm
Browse files Browse the repository at this point in the history
Signed-off-by: Teo Koon Peng <[email protected]>
  • Loading branch information
koonpeng committed May 28, 2024
1 parent d3e1362 commit c2594b2
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 12 deletions.
3 changes: 1 addition & 2 deletions packages/api-server/api_server/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
def add_pagination(
query: QuerySet[MODEL],
pagination: Pagination,
field_mappings: Optional[Dict[str, str]] = None,
field_mappings: Dict[str, str] = {},
) -> QuerySet[MODEL]:
"""
Adds pagination and ordering to a query.
Expand All @@ -17,7 +17,6 @@ def add_pagination(
query. e.g. a url of `?order_by=order_field` and a field mapping of `{"order_field": "db_field"}`
will order the query result according to `db_field`.
"""
field_mappings = field_mappings or {}
query = query.limit(pagination.limit).offset(pagination.offset)
if pagination.order_by is not None:
order_fields = []
Expand Down
32 changes: 30 additions & 2 deletions packages/api-server/api_server/routes/tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from fastapi import Body, Depends, HTTPException, Path, Query
from rx import operators as rxops
from tortoise.expressions import Case, F, Q, RawSQL, Subquery, When
from tortoise.functions import Max

from api_server import models as mdl
from api_server.dependencies import (
Expand All @@ -15,6 +17,7 @@
)
from api_server.fast_io import FastIORouter, SubscriptionRequest
from api_server.logging import LoggerAdapter, get_logger
from api_server.models.tortoise_models import TaskLabel as DbTaskLabel
from api_server.models.tortoise_models import TaskState as DbTaskState
from api_server.repositories import RmfRepository, TaskRepository
from api_server.response import RawJSONResponse
Expand Down Expand Up @@ -99,10 +102,10 @@ async def query_task_states(
None, description="comma separated list of requester names"
),
pickup: Optional[str] = Query(
None, description="comma separated list of pickup names"
None, description="comma separated list of pickup names", deprecated=True
),
destination: Optional[str] = Query(
None, description="comma separated list of destination names"
None, description="comma separated list of destination names", deprecated=True
),
assigned_to: Optional[str] = Query(
None, description="comma separated list of assigned robot names"
Expand All @@ -116,6 +119,10 @@ async def query_task_states(
status: Optional[str] = Query(None, description="comma separated list of statuses"),
pagination: mdl.Pagination = Depends(pagination_query),
):
"""
Note that sorting by `pickup` and `destination` is mutually exclusive and sorting
by either of them will filter only tasks which has those labels.
"""
filters = {}
if task_id is not None:
filters["id___in"] = task_id.split(",")
Expand Down Expand Up @@ -155,6 +162,27 @@ async def query_task_states(
# NOTE: In order to perform sorting based on the values in labels, a filter
# on the label_name has to be performed first. A side-effect of this would
# be that states that do not contain this field will not be returned.
#
# tortoise-orm lacks too many features to implement a proper sort logic, for
# reference, these are some solutions in sql
#
# Solution 1 (can't do multiple joins):
# SELECT t.*
# FROM tasks t
# LEFT JOIN tasklabels tl_foo ON t.task_id = tl_foo.task_id AND tl_foo.label_name = 'foo'
# LEFT JOIN tasklabels tl_bar ON t.task_id = tl_bar.task_id AND tl_bar.label_name = 'bar'
# ORDER BY
# tl_foo.label_value, -- Primary sort by 'foo' label
# tl_bar.label_value; -- Secondary sort by 'bar' label
#
# Solution 2 (can't do MAX on CASE WHEN):
# SELECT t.task_id,
# MAX(CASE WHEN tl.label_name = 'foo' THEN tl.label_value END) AS foo_value,
# MAX(CASE WHEN tl.label_name = 'bar' THEN tl.label_value END) AS bar_value
# FROM tasks t
# LEFT JOIN tasklabels tl ON t.task_id = tl.task_id
# GROUP BY t.task_id
# ORDER BY foo_value, bar_value;
if pagination.order_by is not None:
labels_fields = ["pickup", "destination"]
new_order = pagination.order_by
Expand Down
31 changes: 29 additions & 2 deletions packages/api-server/api_server/routes/tasks/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,17 @@ class TestTasksRoute(AppFixture):
@classmethod
def setUpClass(cls):
super().setUpClass()
task_ids = [uuid4()]
cls.task_states = [make_task_state(task_id=f"test_{x}") for x in task_ids]
booking_labels_2 = make_task_booking_label()
booking_labels_2.description["pickup"] = "AAA"
booking_labels_2.description["destination"] = "BBB"
task_ids = [uuid4(), uuid4()]
cls.task_states = [
make_task_state(task_id=f"test_{task_ids[0]}"),
make_task_state(
task_id=f"test_{task_ids[1]}",
booking_labels=[booking_labels_2.json()],
),
]
cls.task_logs = [make_task_log(task_id=f"test_{x}") for x in task_ids]

with cls.client.websocket_connect("/_internal") as ws:
Expand Down Expand Up @@ -60,6 +69,24 @@ def test_query_task_states_by_label(self):
self.task_states[0].booking.id, results[0]["booking"]["id"]
)

# FIXME(koonpeng): This does not work because of tortoise-orm limitations
# def test_query_task_states_sort_by_label(self):
# """Checks that sorting by `pickup` for `destination` does not filter out tasks"""
# test_cases = {
# "pickup": "Kitchen",
# "destination": "room_203",
# }
# for k, v in test_cases.items():
# resp = self.client.get(
# f"/tasks?task_id={self.task_states[0].booking.id}&order_by=pickup"
# )
# self.assertEqual(200, resp.status_code)
# results = resp.json()
# self.assertEqual(2, len(results))
# self.assertEqual(
# self.task_states[1].booking.id, results[0]["booking"]["id"]
# )

def test_sub_task_state(self):
task_id = self.task_states[0].booking.id
gen = self.subscribe_sio(f"/tasks/{task_id}/state")
Expand Down
16 changes: 10 additions & 6 deletions packages/api-server/api_server/test/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,10 @@ def make_task_booking_label() -> TaskBookingLabel:
)


def make_task_state(task_id: str = "test_task") -> TaskState:
def make_task_state(
task_id: str = "test_task",
booking_labels: list[str] | None = None,
) -> TaskState:
# from https://raw.githubusercontent.com/open-rmf/rmf_api_msgs/960b286d9849fc716a3043b8e1f5fb341bdf5778/rmf_api_msgs/samples/task_state/multi_dropoff_delivery.json
sample_task = json.loads(
"""
Expand Down Expand Up @@ -443,11 +446,12 @@ def make_task_state(task_id: str = "test_task") -> TaskState:
)
sample_task["booking"]["id"] = task_id

booking_labels = [
"dummy_label_1",
"dummy_label_2",
make_task_booking_label().json(),
]
if booking_labels is None:
booking_labels = [
"dummy_label_1",
"dummy_label_2",
make_task_booking_label().json(),
]
sample_task["booking"]["labels"] = booking_labels
return TaskState(**sample_task)

Expand Down

0 comments on commit c2594b2

Please sign in to comment.