Skip to content

Commit

Permalink
Tests for alert routes
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Chong <[email protected]>
  • Loading branch information
aaronchongth committed May 24, 2024
1 parent a18d89b commit 3e4a4ef
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 20 deletions.
2 changes: 2 additions & 0 deletions packages/api-server/api_server/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,8 @@ def convert_fleet_alert(fleet_alert: RmfFleetAlert):
for p in fleet_alert.alert_parameters:
parameters.append(AlertParameter(name=p.name, value=p.value))

# check task phases to find out what waypoint it is?

return AlertRequest(
id=fleet_alert.id,
unix_millis_alert_time=round(datetime.now().timestamp() * 1000),
Expand Down
40 changes: 21 additions & 19 deletions packages/api-server/api_server/routes/alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
router = FastIORouter(tags=["Alerts"])


@router.sub("", response_model=AlertRequest)
@router.sub("/requests", response_model=AlertRequest)
async def sub_alerts(_req: SubscriptionRequest):
return alert_events.alert_requests.pipe(rxops.filter(lambda x: x is not None))

Expand All @@ -40,12 +40,27 @@ async def create_new_alert(alert: AlertRequest):
return alert


@router.get("/request/{alert_id}", response_model=AlertRequest)
async def get_alert(alert_id: str):
"""
Gets an alert based on the alert ID.
"""
alert = await ttm.AlertRequest.get_or_none(id=alert_id)
if alert is None:
raise HTTPException(404, f"Alert with ID {alert_id} does not exists")

alert_model = AlertRequest(**alert.data)
return alert_model


@router.sub("/responses", response_model=AlertResponse)
async def sub_alert_responses(_req: SubscriptionRequest):
return alert_events.alert_responses.pipe(rxops.filter(lambda x: x is not None))


@router.post("/{alert_id}/respond", status_code=201, response_model=AlertResponse)
@router.post(
"/request/{alert_id}/respond", status_code=201, response_model=AlertResponse
)
async def respond_to_alert(alert_id: str, response: str):
"""
Responds to an existing alert. The response must be one of the available
Expand Down Expand Up @@ -76,20 +91,7 @@ async def respond_to_alert(alert_id: str, response: str):
return alert_response_model


@router.get("/{alert_id}", response_model=AlertRequest)
async def get_alert(alert_id: str):
"""
Gets an alert based on the alert ID.
"""
alert = await ttm.AlertRequest.get_or_none(id=alert_id)
if alert is None:
raise HTTPException(404, f"Alert with ID {alert_id} does not exists")

alert_model = AlertRequest(**alert.data)
return alert_model


@router.get("/{alert_id}/response", response_model=AlertResponse)
@router.get("/request/{alert_id}/response", response_model=AlertResponse)
async def get_alert_response(alert_id: str):
"""
Gets the response to the alert based on the alert ID.
Expand All @@ -104,7 +106,7 @@ async def get_alert_response(alert_id: str):
return response_model


@router.get("/task/{task_id}", response_model=List[AlertRequest])
@router.get("/requests/task/{task_id}", response_model=List[AlertRequest])
async def get_alerts_of_task(task_id: str, unresponded: bool = True):
"""
Returns all the alerts associated to a task ID. Provides the option to only
Expand All @@ -123,8 +125,8 @@ async def get_alerts_of_task(task_id: str, unresponded: bool = True):
return alert_models


@router.get("/unresponded", response_model=List[AlertRequest])
async def get_unresponded_alert_ids():
@router.get("/unresponded_requests", response_model=List[AlertRequest])
async def get_unresponded_alerts():
"""
Returns the list of alert IDs that have yet to be responded to, while a
response was required.
Expand Down
215 changes: 215 additions & 0 deletions packages/api-server/api_server/routes/test_alerts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
from unittest.mock import patch
from urllib.parse import urlencode
from uuid import uuid4

from api_server import models as mdl
from api_server.rmf_io import tasks_service
from api_server.test import AppFixture, make_alert_request


class TestAlertsRoute(AppFixture):
@classmethod
def setUpClass(cls):
super().setUpClass()

def test_create_new_alert(self):
id = str(uuid4())
alert = make_alert_request(id=id, responses=["resume", "cancel"])
resp = self.client.post("/alerts/request", data=alert.json(exclude_none=True))
self.assertEqual(201, resp.status_code, resp.content)
self.assertEqual(alert, resp.json(), resp.content)

# repeated creation with same ID will fail
alert = make_alert_request(id=id, responses=["resume", "cancel"])
resp = self.client.post("/alerts/request", data=alert.json(exclude_none=True))
self.assertEqual(409, resp.status_code, resp.content)

def respond_to_alert(self):
id = str(uuid4())
alert = make_alert_request(id=id, responses=["resume", "cancel"])
resp = self.client.post("/alerts/request", data=alert.json(exclude_none=True))
self.assertEqual(201, resp.status_code, resp.content)
self.assertEqual(alert, resp.json(), resp.content)

# respond to alert that does not exist
params = {"response": "resume"}
resp = self.client.post(
f"/alerts/request/wrong_alert/respond?{urlencode(params)}"
)
self.assertEqual(404, resp.status_code, resp.content)

# response that is unavailable
params = {"response": "wrong"}
resp = self.client.post(f"/alerts/{id}/respond?{urlencode(params)}")
self.assertEqual(422, resp.status_code, resp.content)

# respond correctly
response = "resume"
params = {"response": response}
resp = self.client.post(f"/alerts/request/{id}/respond?{urlencode(params)}")
self.assertEqual(201, resp.status_code, resp.content)
self.assertEqual(id, resp.json()["id"], resp.content)
self.assertEqual(response, resp.json()["response"], resp.content)

def test_get_alert(self):
id = str(uuid4())

# alert does not exist
resp = self.client.get(f"/alerts/request/{id}")
self.assertEqual(404, resp.status_code, resp.content)

# create alert
alert = make_alert_request(id=id, responses=["resume", "cancel"])
resp = self.client.post("/alerts/request", data=alert.json(exclude_none=True))
self.assertEqual(201, resp.status_code, resp.content)
self.assertEqual(alert, resp.json(), resp.content)

# alert exists now
resp = self.client.get(f"/alerts/request/{id}")
self.assertEqual(200, resp.status_code, resp.content)
self.assertEqual(alert, resp.json(), resp.content)

def test_get_alert_response(self):
id = str(uuid4())
alert = make_alert_request(id=id, responses=["resume", "cancel"])
resp = self.client.post("/alerts/request", data=alert.json(exclude_none=True))
self.assertEqual(201, resp.status_code, resp.content)
self.assertEqual(alert, resp.json(), resp.content)

# respond
response = "resume"
params = {"response": response}
resp = self.client.post(f"/alerts/request/{id}/respond?{urlencode(params)}")
self.assertEqual(201, resp.status_code, resp.content)
self.assertEqual(id, resp.json()["id"], resp.content)
self.assertEqual(response, resp.json()["response"], resp.content)

# response exists
resp = self.client.get(f"/alerts/request/{id}/response")
self.assertEqual(200, resp.status_code, resp.content)
self.assertEqual(id, resp.json()["id"], resp.content)
self.assertEqual(response, resp.json()["response"], resp.content)

def test_sub_alert(self):
gen = self.subscribe_sio("/alerts/requests")

id = str(uuid4())
alert = make_alert_request(id=id, responses=["resume", "cancel"])
resp = self.client.post("/alerts/request", data=alert.json(exclude_none=True))
self.assertEqual(201, resp.status_code, resp.content)
self.assertEqual(alert, resp.json(), resp.content)

# check subscribed alert
subbed_alert = next(gen)
self.assertEqual(subbed_alert, alert, subbed_alert)

def test_sub_alert_response(self):
gen = self.subscribe_sio("/alerts/responses")

id = str(uuid4())
alert = make_alert_request(id=id, responses=["resume", "cancel"])
resp = self.client.post("/alerts/request", data=alert.json(exclude_none=True))
self.assertEqual(201, resp.status_code, resp.content)
self.assertEqual(alert, resp.json(), resp.content)

# respond
response = "resume"
params = {"response": response}
resp = self.client.post(f"/alerts/request/{id}/respond?{urlencode(params)}")
self.assertEqual(201, resp.status_code, resp.content)
self.assertEqual(id, resp.json()["id"], resp.content)
self.assertEqual(response, resp.json()["response"], resp.content)

# check subscribed alert response
subbed_alert_response = next(gen)
self.assertEqual(subbed_alert_response, resp.json(), subbed_alert_response)

def test_get_alerts_of_task(self):
id = str(uuid4())
alert = make_alert_request(id=id, responses=["resume", "cancel"])
alert.task_id = "test_task_id"
resp = self.client.post("/alerts/request", data=alert.json(exclude_none=True))
self.assertEqual(201, resp.status_code, resp.content)
self.assertEqual(alert, resp.json(), resp.content)

# check for non-existent alert for a wrong task ID
resp = self.client.get("/alerts/requests/task/wrong_task_id")
self.assertEqual(200, resp.status_code, resp.content)
self.assertEqual(0, len(resp.json()), resp.content)

# check for correct task ID
resp = self.client.get(f"/alerts/requests/task/{alert.task_id}")
self.assertEqual(200, resp.status_code, resp.content)
self.assertEqual(1, len(resp.json()), resp.content)
self.assertEqual(resp.json()[0], alert, resp.content)

# respond to alert
response = "resume"
params = {"response": response}
resp = self.client.post(f"/alerts/request/{id}/respond?{urlencode(params)}")
self.assertEqual(201, resp.status_code, resp.content)
self.assertEqual(id, resp.json()["id"], resp.content)
self.assertEqual(response, resp.json()["response"], resp.content)

# check for alert of correct task ID again (will only return
# unresponded by default)
resp = self.client.get(f"/alerts/requests/task/{alert.task_id}")
self.assertEqual(200, resp.status_code, resp.content)
self.assertEqual(0, len(resp.json()), resp.content)

# check for alert of correct task ID again with unresponded False
params = {"unresponded": False}
resp = self.client.get(
f"/alerts/requests/task/{alert.task_id}?{urlencode(params)}"
)
self.assertEqual(200, resp.status_code, resp.content)
self.assertEqual(1, len(resp.json()), resp.content)
self.assertEqual(resp.json()[0], alert, resp.content)

def test_get_unresponded_alert_ids(self):
first_id = str(uuid4())
first_alert = make_alert_request(id=first_id, responses=["resume", "cancel"])
resp = self.client.post(
"/alerts/request", data=first_alert.json(exclude_none=True)
)
self.assertEqual(201, resp.status_code, resp.content)
self.assertEqual(first_alert, resp.json(), resp.content)

second_id = str(uuid4())
second_alert = make_alert_request(id=second_id, responses=["resume", "cancel"])
resp = self.client.post(
"/alerts/request", data=second_alert.json(exclude_none=True)
)
self.assertEqual(201, resp.status_code, resp.content)
self.assertEqual(second_alert, resp.json(), resp.content)

# both alerts unresponded
resp = self.client.get("/alerts/unresponded_requests")
self.assertEqual(200, resp.status_code, resp.content)
unresponded_num = len(resp.json())
self.assertTrue(unresponded_num > 0, resp.content)
returned_alerts = resp.json()
returned_alert_ids = [a["id"] for a in returned_alerts]
self.assertTrue(first_id in returned_alert_ids)
self.assertTrue(second_id in returned_alert_ids)

# respond to first
response = "resume"
params = {"response": response}
resp = self.client.post(
f"/alerts/request/{first_id}/respond?{urlencode(params)}"
)
self.assertEqual(201, resp.status_code, resp.content)
self.assertEqual(first_id, resp.json()["id"], resp.content)
self.assertEqual(response, resp.json()["response"], resp.content)

# first is no longer returned
resp = self.client.get("/alerts/unresponded_requests")
self.assertEqual(200, resp.status_code, resp.content)
new_unresponded_num = len(resp.json())
self.assertTrue(new_unresponded_num > 0, resp.content)
self.assertTrue(unresponded_num - new_unresponded_num == 1, resp.content)
returned_alerts = resp.json()
returned_alert_ids = [a["id"] for a in returned_alerts]
self.assertTrue(first_id not in returned_alert_ids)
self.assertTrue(second_id in returned_alert_ids)
18 changes: 17 additions & 1 deletion packages/api-server/api_server/test/test_data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Optional
from typing import List, Optional
from uuid import uuid4

from rmf_building_map_msgs.msg import Door as RmfDoor
Expand All @@ -10,6 +10,7 @@

from api_server.models import (
AffineImage,
AlertRequest,
BuildingMap,
DispenserState,
Door,
Expand Down Expand Up @@ -740,3 +741,18 @@ def make_task_log(task_id: str) -> TaskEventLog:
)
sample.task_id = task_id
return sample


def make_alert_request(id: str, responses: List[str]) -> AlertRequest:
return AlertRequest(
id=id,
unix_millis_alert_time=0,
title="test_title",
subtitle="test_subtitle",
message="test_message",
display=True,
tier=AlertRequest.Tier.Info,
responses_available=responses,
alert_parameters=[],
task_id=None,
)

0 comments on commit 3e4a4ef

Please sign in to comment.