Skip to content

Commit b7c04a5

Browse files
committed
Refactored alert db interactions to repository, with lru cache for location alerts checking
Signed-off-by: Aaron Chong <[email protected]>
1 parent deca732 commit b7c04a5

File tree

9 files changed

+356
-109
lines changed

9 files changed

+356
-109
lines changed

packages/api-server/api_server/gateway.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,14 @@
5151
LiftState,
5252
)
5353
from .models.delivery_alerts import action_from_msg, category_from_msg, tier_from_msg
54-
from .repositories import CachedFilesRepository, cached_files_repo
54+
from .repositories import (
55+
CachedFilesRepository,
56+
LocationAlertFailResponse,
57+
LocationAlertSuccessResponse,
58+
cached_files_repo,
59+
is_final_location_alert_check,
60+
task_id_to_all_locations_success_cache,
61+
)
5562
from .rmf_io import alert_events, rmf_events
5663
from .ros import ros_node
5764

@@ -274,8 +281,6 @@ def convert_fleet_alert(fleet_alert: RmfFleetAlert):
274281
for p in fleet_alert.alert_parameters:
275282
parameters.append(AlertParameter(name=p.name, value=p.value))
276283

277-
# check task phases to find out what waypoint it is?
278-
279284
return AlertRequest(
280285
id=fleet_alert.id,
281286
unix_millis_alert_time=round(datetime.now().timestamp() * 1000),
@@ -292,6 +297,19 @@ def convert_fleet_alert(fleet_alert: RmfFleetAlert):
292297
def handle_fleet_alert(fleet_alert: AlertRequest):
293298
logging.info("Received fleet alert:")
294299
logging.info(fleet_alert)
300+
301+
# Handle request for checking all location completion success for
302+
# this task
303+
is_final_check = is_final_location_alert_check(fleet_alert)
304+
if is_final_check:
305+
successful_so_far = task_id_to_all_locations_success_cache.lookup(
306+
fleet_alert.task_id
307+
)
308+
if successful_so_far is None or not successful_so_far:
309+
self.respond_to_alert(fleet_alert.id, LocationAlertFailResponse)
310+
else:
311+
self.respond_to_alert(fleet_alert.id, LocationAlertSuccessResponse)
312+
295313
alert_events.alert_requests.on_next(fleet_alert)
296314

297315
fleet_alert_sub = ros_node().create_subscription(

packages/api-server/api_server/repositories/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
from .alerts import (
2+
AlertRepository,
3+
LocationAlertFailResponse,
4+
LocationAlertSuccessResponse,
5+
is_final_location_alert_check,
6+
task_id_to_all_locations_success_cache,
7+
)
18
from .cached_files import CachedFilesRepository, cached_files_repo
29
from .fleets import FleetRepository
310
from .rmf import RmfRepository
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
import logging
2+
from collections import deque
3+
from datetime import datetime
4+
from typing import List, Optional
5+
6+
from api_server.models import AlertRequest, AlertResponse
7+
from api_server.models import tortoise_models as ttm
8+
9+
# from api_server.gateway import rmf_gateway
10+
11+
12+
# TODO: not hardcode all these expected values
13+
LocationAlertSuccessResponse = "success"
14+
LocationAlertFailResponse = "fail"
15+
LocationAlertTypeParameterName = "type"
16+
LocationAlertTypeParameterValue = "location_result"
17+
LocationAlertLocationParameterName = "location_name"
18+
LocationAlertFinalCheckTypeParameterValue = "check_all_task_location_alerts"
19+
20+
21+
def get_location_from_location_alert(alert: AlertRequest) -> Optional[str]:
22+
"""
23+
Returns the location name from a location alert when possible, otherwise
24+
returns None.
25+
Note: This is an experimental feature and may be subjected to
26+
modifications often.
27+
"""
28+
if (
29+
len(alert.alert_parameters) < 2
30+
or LocationAlertSuccessResponse not in alert.responses_available
31+
or LocationAlertFailResponse not in alert.responses_available
32+
):
33+
return None
34+
35+
# Check type
36+
alert_type = None
37+
for param in alert.alert_parameters:
38+
if param.name == LocationAlertTypeParameterName:
39+
alert_type = param.value
40+
break
41+
if alert_type != LocationAlertTypeParameterValue:
42+
return None
43+
44+
# Check location name
45+
# TODO: make sure that there are no duplicated locations that have
46+
# not been responded to yet
47+
for param in alert.alert_parameters:
48+
if param.name == LocationAlertLocationParameterName:
49+
return param.value
50+
return None
51+
52+
53+
def is_final_location_alert_check(alert: AlertRequest) -> bool:
54+
"""
55+
Checks if the alert request requires a check on all location alerts of this
56+
task.
57+
Note: This is an experimental feature and may be subjected to
58+
modifications often.
59+
"""
60+
if (
61+
alert.task_id is None
62+
or len(alert.alert_parameters) < 1
63+
or LocationAlertSuccessResponse not in alert.responses_available
64+
or LocationAlertFailResponse not in alert.responses_available
65+
):
66+
return False
67+
68+
# Check type
69+
for param in alert.alert_parameters:
70+
if param.name == LocationAlertTypeParameterName:
71+
if param.value == LocationAlertFinalCheckTypeParameterValue:
72+
return True
73+
return False
74+
return False
75+
76+
77+
class LRUCache:
78+
def __init__(self, capacity: int):
79+
self._cache = deque(maxlen=capacity)
80+
self._lookup = {}
81+
82+
def add(self, key, value):
83+
if key in self._lookup:
84+
self._cache.remove(key)
85+
elif len(self._cache) == self._cache.maxlen:
86+
oldest_key = self._cache.popleft()
87+
del self._lookup[oldest_key]
88+
89+
self._cache.append(key)
90+
self._lookup[key] = value
91+
92+
def remove(self, key):
93+
if key in self._lookup:
94+
self._cache.remove(key)
95+
del self._lookup[key]
96+
97+
def lookup(self, key):
98+
if key in self._lookup:
99+
self._cache.remove(key)
100+
self._cache.append(key)
101+
return self._lookup[key]
102+
return None
103+
104+
105+
task_id_to_all_locations_success_cache: LRUCache = LRUCache(20)
106+
107+
108+
class AlertRepository:
109+
async def create_new_alert(self, alert: AlertRequest) -> Optional[AlertRequest]:
110+
exists = await ttm.AlertRequest.exists(id=alert.id)
111+
if exists:
112+
logging.error(f"Alert with ID {alert.id} already exists")
113+
return None
114+
115+
await ttm.AlertRequest.create(
116+
id=alert.id,
117+
data=alert.json(),
118+
response_expected=(len(alert.responses_available) > 0),
119+
task_id=alert.task_id,
120+
)
121+
return alert
122+
123+
async def get_alert(self, alert_id: str) -> Optional[AlertRequest]:
124+
alert = await ttm.AlertRequest.get_or_none(id=alert_id)
125+
if alert is None:
126+
logging.error(f"Alert with ID {alert_id} does not exists")
127+
return None
128+
129+
alert_model = AlertRequest(**alert.data)
130+
return alert_model
131+
132+
async def create_response(
133+
self, alert_id: str, response: str
134+
) -> Optional[AlertResponse]:
135+
alert = await ttm.AlertRequest.get_or_none(id=alert_id)
136+
if alert is None:
137+
logging.error(f"Alert with ID {alert_id} does not exists")
138+
return None
139+
140+
alert_model = AlertRequest(**alert.data)
141+
if response not in alert_model.responses_available:
142+
logging.error(
143+
f"Alert with ID {alert_model.id} does not have allow response of {response}"
144+
)
145+
return None
146+
147+
alert_response_model = AlertResponse(
148+
id=alert_id,
149+
unix_millis_response_time=round(datetime.now().timestamp() * 1000),
150+
response=response,
151+
)
152+
await ttm.AlertResponse.create(
153+
id=alert_id, alert_request=alert, data=alert_response_model.json()
154+
)
155+
return alert_response_model
156+
157+
async def get_alert_response(self, alert_id: str) -> Optional[AlertResponse]:
158+
response = await ttm.AlertResponse.get_or_none(id=alert_id)
159+
if response is None:
160+
logging.error(f"Response to alert with ID {alert_id} does not exists")
161+
return None
162+
163+
response_model = AlertResponse(**response.data)
164+
return response_model
165+
166+
async def get_alerts_of_task(
167+
self, task_id: str, unresponded: bool = True
168+
) -> List[AlertRequest]:
169+
if unresponded:
170+
task_id_alerts = await ttm.AlertRequest.filter(
171+
response_expected=True,
172+
task_id=task_id,
173+
alert_response=None,
174+
)
175+
else:
176+
task_id_alerts = await ttm.AlertRequest.filter(task_id=task_id)
177+
178+
alert_models = [AlertRequest(**alert.data) for alert in task_id_alerts]
179+
return alert_models
180+
181+
async def get_unresponded_alerts(self) -> List[AlertRequest]:
182+
unresponded_alerts = await ttm.AlertRequest.filter(
183+
alert_response=None, response_expected=True
184+
)
185+
return [AlertRequest(**alert.data) for alert in unresponded_alerts]
186+
187+
async def create_location_alert_response(
188+
self,
189+
task_id: str,
190+
location: str,
191+
success: bool,
192+
) -> Optional[AlertResponse]:
193+
"""
194+
Creates an alert response for a location alert of the task.
195+
Note: This is an experimental feature and may be subjected to
196+
modifications often.
197+
"""
198+
alerts = await self.get_alerts_of_task(task_id=task_id, unresponded=True)
199+
if len(alerts) == 0:
200+
logging.error(
201+
f"There are no location alerts awaiting response for task {task_id}"
202+
)
203+
return None
204+
205+
for alert in alerts:
206+
location_alert_location = get_location_from_location_alert(alert)
207+
if location_alert_location is None:
208+
continue
209+
210+
if location_alert_location == location:
211+
response = (
212+
LocationAlertSuccessResponse
213+
if success
214+
else LocationAlertFailResponse
215+
)
216+
alert_response_model = await self.create_response(alert.id, response)
217+
if alert_response_model is None:
218+
logging.error(
219+
f"Failed to create response {response} to alert with ID {alert.id}"
220+
)
221+
return None
222+
223+
# Cache if all locations of this task has been successful so far
224+
cache = task_id_to_all_locations_success_cache.lookup(task_id)
225+
if cache is None:
226+
task_id_to_all_locations_success_cache.add(task_id, success)
227+
else:
228+
task_id_to_all_locations_success_cache.add(
229+
task_id, cache and success
230+
)
231+
232+
return alert_response_model
233+
234+
logging.error(
235+
f"Task {task_id} is not awaiting completion of location {location}"
236+
)
237+
return None
238+
239+
async def check_all_task_location_alerts_if_succeeded(self, task_id: str) -> bool:
240+
"""
241+
Checks if all location alert reponses for the task were successful.
242+
Note: This is an experimental feature and may be subjected to
243+
modifications often.
244+
"""
245+
task_id_alerts = await ttm.AlertRequest.filter(task_id=task_id)
246+
if len(task_id_alerts) == 0:
247+
logging.info(f"There were no location alerts for task {task_id}")
248+
return False
249+
250+
for alert in task_id_alerts:
251+
alert_model = AlertRequest(**alert.data)
252+
location_alert_location = get_location_from_location_alert(alert_model)
253+
if location_alert_location is None:
254+
continue
255+
256+
if alert.alert_response is None:
257+
logging.info(
258+
f"Alert {alert_model.id} does not have a response, check return False"
259+
)
260+
return False
261+
262+
alert_response_model = AlertResponse(**alert.alert_response.data)
263+
if alert_response_model.response != LocationAlertSuccessResponse:
264+
logging.info(
265+
f"Alert {alert_model.id} has a response {alert_response_model.response}, check return False"
266+
)
267+
return False
268+
269+
logging.info(f"All location alerts for task {task_id} succeeded")
270+
return True

packages/api-server/api_server/rmf_io/book_keeper.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,14 @@
2121
LiftState,
2222
)
2323
from api_server.models.health import BaseBasicHealth
24+
from api_server.repositories import (
25+
AlertRepository, # , is_final_location_alert_check, LocationAlertFailResponse, LocationAlertSuccessResponse
26+
)
2427

2528
from .events import AlertEvents, RmfEvents
2629

30+
# from api_server.gateway import rmf_gateway
31+
2732

2833
class RmfBookKeeperEvents:
2934
def __init__(self):
@@ -38,6 +43,7 @@ def __init__(
3843
):
3944
self.rmf_events = rmf_events
4045
self.alert_events = alert_events
46+
self.alert_repository = AlertRepository()
4147
self.bookkeeper_events = RmfBookKeeperEvents()
4248
self._loop: asyncio.AbstractEventLoop
4349
self._pending_tasks = set()

0 commit comments

Comments
 (0)