Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions src/dodal/devices/robot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from asyncio import FIRST_COMPLETED, CancelledError, Task, wait_for
from asyncio import ALL_COMPLETED, FIRST_COMPLETED, CancelledError, Task, wait_for
from dataclasses import dataclass
from enum import IntEnum

Expand Down Expand Up @@ -113,6 +113,8 @@ class BartRobot(StandardReadable, Movable[SampleLocation]):
CRYO_MODE_WARM = 0.0
CRYO_MODE_CRYO = 1.0

NO_PIN_LOCATION = SampleLocation(0, 0)

def __init__(self, prefix: str, name: str = "") -> None:
with self.add_children_as_readables(StandardReadableFormat.HINTED_SIGNAL):
self.barcode = epics_signal_r(str, prefix + "BARCODE")
Expand Down Expand Up @@ -153,10 +155,19 @@ def __init__(self, prefix: str, name: str = "") -> None:
)
super().__init__(name=name)

async def beamline_status_or_error(self, expected_state: BeamlineStatus):
async def beamline_status_or_error(
self,
expected_state: BeamlineStatus,
sample_location: SampleLocation | None = None,
):
"""This co-routine will finish when either the beamline reaches the specified
state or the robot gives an error (whichever happens first). In the case where
there is an error a RobotLoadError error is raised.

Args:
expected_state (BeamlineStatus): The beamline state to wait for
sample_location (SampleLocation): The loaded puck and pin to wait for, or None to not
wait.
"""

async def raise_if_prog_error():
Expand All @@ -182,14 +193,32 @@ async def raise_if_ctl_error():
async def wait_for_expected_state():
await wait_for_value(self.beamline_disabled, expected_state.value, None)

async def wait_for_puck(_sample_location: SampleLocation):
await wait_for_value(self.current_puck, _sample_location.puck, None)

async def wait_for_pin(_sample_location: SampleLocation):
await wait_for_value(self.current_pin, _sample_location.pin, None)

check_for_prog_error_task = Task(raise_if_prog_error())
check_for_ctl_error_task = Task(raise_if_ctl_error())
tasks = [
(Task(raise_if_prog_error())),
(Task(raise_if_ctl_error())),
(Task(wait_for_expected_state())),
]
if sample_location:
tasks += [
Task(wait_for_puck(sample_location)),
Task(wait_for_pin(sample_location)),
]
check_for_completion_conditions = asyncio.create_task(
asyncio.wait(tasks, return_when=ALL_COMPLETED)
)
try:
finished, unfinished = await asyncio.wait(
tasks,
[
check_for_prog_error_task,
check_for_ctl_error_task,
check_for_completion_conditions,
],
return_when=FIRST_COMPLETED,
)
for task in unfinished:
Expand All @@ -201,7 +230,8 @@ async def wait_for_expected_state():
# in the current task, when it propagates to here we should cancel all pending tasks before bubbling up
for task in tasks:
task.cancel()

check_for_prog_error_task.cancel()
check_for_ctl_error_task.cancel()
raise

async def _check_errors_and_clear_if_retryable(self):
Expand Down Expand Up @@ -243,15 +273,17 @@ async def _load_pin_and_puck(self, sample_location: SampleLocation):
set_and_wait_for_value(self.next_pin, sample_location.pin),
)
await self.load.trigger()
await self._wait_for_beamline_enabled_after_load_or_unload()
await self._wait_for_beamline_enabled_after_load_or_unload(sample_location)

async def _wait_for_beamline_enabled_after_load_or_unload(self):
async def _wait_for_beamline_enabled_after_load_or_unload(
self, sample_location: SampleLocation
):
if await self.beamline_disabled.get_value() == BeamlineStatus.ENABLED.value:
LOGGER.info(WAIT_FOR_BEAMLINE_DISABLE_MSG)
await self.beamline_status_or_error(BeamlineStatus.DISABLED)

LOGGER.info(WAIT_FOR_BEAMLINE_ENABLE_MSG)
await self.beamline_status_or_error(BeamlineStatus.ENABLED)
await self.beamline_status_or_error(BeamlineStatus.ENABLED, sample_location)

@AsyncStatus.wrap
async def set(self, value: SampleLocation):
Expand All @@ -275,7 +307,9 @@ async def set(self, value: SampleLocation):
else:
await self.unload.trigger(timeout=self.LOAD_TIMEOUT)
await wait_for(
self._wait_for_beamline_enabled_after_load_or_unload(),
self._wait_for_beamline_enabled_after_load_or_unload(
self.NO_PIN_LOCATION
),
timeout=self.LOAD_TIMEOUT + self.NOT_BUSY_TIMEOUT,
)
except TimeoutError as e:
Expand Down
Loading
Loading