diff --git a/src/dodal/devices/robot.py b/src/dodal/devices/robot.py index f886a6d835..623f9c99e8 100644 --- a/src/dodal/devices/robot.py +++ b/src/dodal/devices/robot.py @@ -1,5 +1,5 @@ import asyncio -from asyncio import FIRST_COMPLETED, CancelledError, Task, wait_for +from asyncio import ALL_COMPLETED, FIRST_COMPLETED, CancelledError, Task, wait_for from dataclasses import dataclass from enum import IntEnum @@ -113,6 +113,8 @@ class BartRobot(StandardReadable, Movable[SampleLocation]): CRYO_MODE_WARM = 0.0 CRYO_MODE_CRYO = 1.0 + NO_PIN_LOCATION = SampleLocation(0, 0) + def __init__(self, prefix: str, name: str = "") -> None: with self.add_children_as_readables(StandardReadableFormat.HINTED_SIGNAL): self.barcode = epics_signal_r(str, prefix + "BARCODE") @@ -153,10 +155,19 @@ def __init__(self, prefix: str, name: str = "") -> None: ) super().__init__(name=name) - async def beamline_status_or_error(self, expected_state: BeamlineStatus): + async def beamline_status_or_error( + self, + expected_state: BeamlineStatus, + sample_location: SampleLocation | None = None, + ): """This co-routine will finish when either the beamline reaches the specified state or the robot gives an error (whichever happens first). In the case where there is an error a RobotLoadError error is raised. + + Args: + expected_state (BeamlineStatus): The beamline state to wait for + sample_location (SampleLocation): The loaded puck and pin to wait for, or None to not + wait. """ async def raise_if_prog_error(): @@ -182,14 +193,32 @@ async def raise_if_ctl_error(): async def wait_for_expected_state(): await wait_for_value(self.beamline_disabled, expected_state.value, None) + async def wait_for_puck(_sample_location: SampleLocation): + await wait_for_value(self.current_puck, _sample_location.puck, None) + + async def wait_for_pin(_sample_location: SampleLocation): + await wait_for_value(self.current_pin, _sample_location.pin, None) + + check_for_prog_error_task = Task(raise_if_prog_error()) + check_for_ctl_error_task = Task(raise_if_ctl_error()) tasks = [ - (Task(raise_if_prog_error())), - (Task(raise_if_ctl_error())), (Task(wait_for_expected_state())), ] + if sample_location: + tasks += [ + Task(wait_for_puck(sample_location)), + Task(wait_for_pin(sample_location)), + ] + check_for_completion_conditions = asyncio.create_task( + asyncio.wait(tasks, return_when=ALL_COMPLETED) + ) try: finished, unfinished = await asyncio.wait( - tasks, + [ + check_for_prog_error_task, + check_for_ctl_error_task, + check_for_completion_conditions, + ], return_when=FIRST_COMPLETED, ) for task in unfinished: @@ -201,7 +230,8 @@ async def wait_for_expected_state(): # in the current task, when it propagates to here we should cancel all pending tasks before bubbling up for task in tasks: task.cancel() - + check_for_prog_error_task.cancel() + check_for_ctl_error_task.cancel() raise async def _check_errors_and_clear_if_retryable(self): @@ -243,15 +273,17 @@ async def _load_pin_and_puck(self, sample_location: SampleLocation): set_and_wait_for_value(self.next_pin, sample_location.pin), ) await self.load.trigger() - await self._wait_for_beamline_enabled_after_load_or_unload() + await self._wait_for_beamline_enabled_after_load_or_unload(sample_location) - async def _wait_for_beamline_enabled_after_load_or_unload(self): + async def _wait_for_beamline_enabled_after_load_or_unload( + self, sample_location: SampleLocation + ): if await self.beamline_disabled.get_value() == BeamlineStatus.ENABLED.value: LOGGER.info(WAIT_FOR_BEAMLINE_DISABLE_MSG) await self.beamline_status_or_error(BeamlineStatus.DISABLED) LOGGER.info(WAIT_FOR_BEAMLINE_ENABLE_MSG) - await self.beamline_status_or_error(BeamlineStatus.ENABLED) + await self.beamline_status_or_error(BeamlineStatus.ENABLED, sample_location) @AsyncStatus.wrap async def set(self, value: SampleLocation): @@ -275,7 +307,9 @@ async def set(self, value: SampleLocation): else: await self.unload.trigger(timeout=self.LOAD_TIMEOUT) await wait_for( - self._wait_for_beamline_enabled_after_load_or_unload(), + self._wait_for_beamline_enabled_after_load_or_unload( + self.NO_PIN_LOCATION + ), timeout=self.LOAD_TIMEOUT + self.NOT_BUSY_TIMEOUT, ) except TimeoutError as e: diff --git a/tests/devices/test_bart_robot.py b/tests/devices/test_bart_robot.py index b49a4e4d82..8c376dd078 100644 --- a/tests/devices/test_bart_robot.py +++ b/tests/devices/test_bart_robot.py @@ -1,12 +1,12 @@ import asyncio import traceback from asyncio import Event, create_task +from collections.abc import Callable from functools import partial from unittest.mock import ANY, AsyncMock, MagicMock, call, patch import pytest from ophyd_async.core import ( - AsyncStatus, callback_on_mock_put, get_mock, get_mock_put, @@ -54,16 +54,24 @@ async def fake_unload(*args, **kwargs): return device, trigger_complete, drying_complete +# Use log info messages to determine when to set the beamline enable, so we don't have to use any sleeps during testing @pytest.fixture() -async def robot_for_load(): - device = await _get_bart_robot() - set_mock_value(device.program_running, False) - set_mock_value(device.beamline_disabled, BeamlineStatus.ENABLED.value) +async def robot_for_load(bart_robot: BartRobot): + sample_location = SampleLocation(15, 10) + set_mock_value(bart_robot.beamline_disabled, BeamlineStatus.ENABLED.value) + + def _enable_beamline_and_update_pin(device: BartRobot): + _beamline_enable(device) + _set_pin_and_puck(sample_location, device) + with patch("dodal.devices.robot.LOGGER.info") as mock_log_info: mock_log_info.side_effect = partial( - _set_beamline_enabled_on_log_messages, device + _set_beamline_enabled_on_log_messages, + bart_robot, + _enable_beamline_and_update_pin, + _beamline_disable, ) - yield device + yield bart_robot @pytest.fixture() @@ -93,12 +101,13 @@ def clear_errors(device: BartRobot, *args, **kwargs): set_mock_value(device.prog_error.code, 0) -async def _get_bart_robot() -> BartRobot: +@pytest.fixture +async def bart_robot() -> BartRobot: device = BartRobot("robot", "-MO-ROBOT-01:") - device.LOAD_TIMEOUT = 1 # type: ignore - device.NOT_BUSY_TIMEOUT = 1 # type: ignore + device.NOT_BUSY_TIMEOUT = 0.3 # type: ignore + device.LOAD_TIMEOUT = 0.3 # type: ignore await device.connect(mock=True) - + set_mock_value(device.program_running, False) callback_on_mock_put(device.reset, partial(clear_errors, device)) return device @@ -109,31 +118,23 @@ def _set_fast_robot_timeouts(robot: BartRobot): robot.NOT_BUSY_TIMEOUT = 0.01 # type: ignore -async def test_bart_robot_can_be_connected_in_sim_mode(): - device = await _get_bart_robot() - await device.connect(mock=True) - - -async def test_given_robot_load_times_out_when_load_called_then_exception_contains_error_info(): - device = await _get_bart_robot() - _set_fast_robot_timeouts(device) - device._load_pin_and_puck = AsyncMock(side_effect=TimeoutError) - - set_mock_value(device.prog_error.code, (expected_error_code := 10)) - set_mock_value(device.prog_error.str, (expected_error_string := "BAD")) +async def test_given_robot_load_times_out_when_load_called_then_exception_contains_error_info( + robot_with_late_error: BartRobot, +): + device = robot_with_late_error with pytest.raises(RobotLoadError) as e: await device.set(SampleLocation(0, 0)) - assert e.value.error_code == expected_error_code - assert e.value.error_string == expected_error_string - assert str(e.value) == expected_error_string + assert e.value.error_code == EXPECTED_ERROR_CODE + assert e.value.error_string == EXPECTED_ERROR_STRING + assert str(e.value) == EXPECTED_ERROR_STRING @patch("dodal.devices.robot.LOGGER") async def test_given_program_running_when_load_pin_then_logs_the_program_name_and_times_out( - patch_logger: MagicMock, + patch_logger: MagicMock, bart_robot: BartRobot ): - device = await _get_bart_robot() + device = bart_robot _set_fast_robot_timeouts(device) program_name = "BAD_PROGRAM" set_mock_value(device.program_running, True) @@ -146,11 +147,10 @@ async def test_given_program_running_when_load_pin_then_logs_the_program_name_an @patch("dodal.devices.robot.LOGGER") async def test_given_program_not_running_but_pin_not_unmounting_when_load_pin_then_timeout( - patch_logger: MagicMock, + patch_logger: MagicMock, bart_robot: BartRobot ): - device = await _get_bart_robot() + device = bart_robot _set_fast_robot_timeouts(device) - set_mock_value(device.program_running, False) set_mock_value(device.gonio_pin_sensor, PinMounted.PIN_MOUNTED) device.load = AsyncMock(side_effect=device.load) with pytest.raises(RobotLoadError): @@ -163,10 +163,10 @@ async def test_given_program_not_running_but_pin_not_unmounting_when_load_pin_th @patch("dodal.devices.robot.LOGGER") async def test_given_program_not_running_and_pin_unmounting_but_new_pin_not_mounting_when_load_pin_then_timeout( patch_logger: MagicMock, + bart_robot: BartRobot, ): - device = await _get_bart_robot() + device = bart_robot _set_fast_robot_timeouts(device) - set_mock_value(device.program_running, False) set_mock_value(device.gonio_pin_sensor, PinMounted.NO_PIN_MOUNTED) device.load = AsyncMock(side_effect=device.load) with pytest.raises(RobotLoadError) as exc_info: @@ -181,11 +181,29 @@ async def test_given_program_not_running_and_pin_unmounting_but_new_pin_not_moun raise -def _set_beamline_enabled_on_log_messages(device: BartRobot, msg: str): +def _beamline_enable(device: BartRobot): + set_mock_value(device.beamline_disabled, BeamlineStatus.ENABLED.value) + + +def _beamline_disable(device: BartRobot): + set_mock_value(device.beamline_disabled, BeamlineStatus.DISABLED.value) + + +def _set_pin_and_puck(sample_location: SampleLocation, device: BartRobot): + set_mock_value(device.current_puck, sample_location.puck) + set_mock_value(device.current_pin, sample_location.pin) + + +def _set_beamline_enabled_on_log_messages( + device: BartRobot, + on_beamline_enable: Callable[[BartRobot], None], + on_beamline_disable: Callable[[BartRobot], None], + msg: str, +): if msg == WAIT_FOR_BEAMLINE_DISABLE_MSG: - set_mock_value(device.beamline_disabled, BeamlineStatus.DISABLED.value) + on_beamline_disable(device) elif msg == WAIT_FOR_BEAMLINE_ENABLE_MSG: - set_mock_value(device.beamline_disabled, BeamlineStatus.ENABLED.value) + on_beamline_enable(device) def _prog_error_on_unload_log_messages(device: BartRobot, msg: str): @@ -202,24 +220,59 @@ def _controller_error_on_unload_log_messages(device: BartRobot, msg: str): set_mock_value(device.prog_error.str, "Test error") -# Use log info messages to determine when to set the beamline enable, so we don't have to use any sleeps during testing -async def set_with_happy_path( - device: BartRobot, mock_log_info: MagicMock -) -> AsyncStatus: +@pytest.fixture +async def robot_with_early_error( + bart_robot: BartRobot, +): """Mocks the logic that the robot would do on a successful load.""" - mock_log_info.side_effect = partial(_set_beamline_enabled_on_log_messages, device) - set_mock_value(device.program_running, False) - set_mock_value(device.beamline_disabled, BeamlineStatus.ENABLED.value) - status = device.set(SampleLocation(15, 10)) - return status + with patch("dodal.devices.robot.LOGGER.info") as mock_log_info: + mock_log_info.side_effect = partial( + _prog_error_on_unload_log_messages, bart_robot + ) + set_mock_value(bart_robot.beamline_disabled, BeamlineStatus.ENABLED.value) + yield bart_robot + + +@pytest.fixture +def robot_which_never_reports_new_pin(bart_robot: BartRobot): + with patch("dodal.devices.robot.LOGGER.info") as mock_log_info: + set_mock_value(bart_robot.beamline_disabled, BeamlineStatus.ENABLED.value) + mock_log_info.side_effect = partial( + _set_beamline_enabled_on_log_messages, + bart_robot, + _beamline_enable, + _beamline_disable, + ) + yield bart_robot + + +EXPECTED_ERROR_CODE = 10 +EXPECTED_ERROR_STRING = "BAD" + + +@pytest.fixture +def robot_with_late_error(bart_robot: BartRobot): + _set_fast_robot_timeouts(bart_robot) + bart_robot._load_pin_and_puck = AsyncMock(side_effect=TimeoutError) + + set_mock_value(bart_robot.prog_error.code, EXPECTED_ERROR_CODE) + set_mock_value(bart_robot.prog_error.str, EXPECTED_ERROR_STRING) + with patch("dodal.devices.robot.LOGGER.info") as mock_log_info: + set_mock_value(bart_robot.beamline_disabled, BeamlineStatus.ENABLED.value) + mock_log_info.side_effect = partial( + _set_beamline_enabled_on_log_messages, + bart_robot, + _beamline_enable, + _beamline_disable, + ) + yield bart_robot -@patch("dodal.devices.robot.LOGGER.info") async def test_given_program_not_running_and_pin_unmounts_then_mounts_when_load_pin_then_pin_loaded( - mock_log_info: MagicMock, + robot_for_load: BartRobot, ): - device = await _get_bart_robot() - status = await set_with_happy_path(device, mock_log_info) + device = robot_for_load + status = device.set(SampleLocation(15, 10)) await status assert status.success assert (await device.next_puck.get_value()) == 15 @@ -227,8 +280,10 @@ async def test_given_program_not_running_and_pin_unmounts_then_mounts_when_load_ get_mock_put(device.load).assert_called_once() -async def test_waiting_for_beamline_status_raises_error_when_prog_error(): - device = await _get_bart_robot() +async def test_waiting_for_beamline_status_raises_error_when_prog_error( + bart_robot: BartRobot, +): + device = bart_robot set_mock_value(device.prog_error.code, 25) set_mock_value(device.beamline_disabled, BeamlineStatus.DISABLED.value) status = device.beamline_status_or_error(BeamlineStatus.ENABLED) @@ -236,8 +291,10 @@ async def test_waiting_for_beamline_status_raises_error_when_prog_error(): await status -async def test_waiting_for_beamline_status_raises_error_when_controller_error(): - device = await _get_bart_robot() +async def test_waiting_for_beamline_status_raises_error_when_controller_error( + bart_robot: BartRobot, +): + device = bart_robot set_mock_value(device.controller_error.code, 25) set_mock_value(device.beamline_disabled, BeamlineStatus.DISABLED.value) status = device.beamline_status_or_error(BeamlineStatus.ENABLED) @@ -245,16 +302,28 @@ async def test_waiting_for_beamline_status_raises_error_when_controller_error(): await status -async def test_given_waiting_for_beamline_to_enable_when_beamline_enabled_then_no_error_raised(): - device = await _get_bart_robot() +async def test_given_waiting_for_beamline_to_enable_when_beamline_enabled_then_no_error_raised( + bart_robot: BartRobot, +): + device = bart_robot status = create_task(device.beamline_status_or_error(BeamlineStatus.ENABLED)) set_mock_value(device.beamline_disabled, BeamlineStatus.ENABLED.value) await status +async def test_robot_load_fails_if_new_puck_and_pin_not_reported( + robot_which_never_reports_new_pin: BartRobot, +): + robot = robot_which_never_reports_new_pin + with pytest.raises(RobotLoadError, match="Robot timed out"): + await robot.set(SampleLocation(15, 10)) + + @patch("dodal.devices.robot.wait_for") -async def test_set_waits_for_both_timeouts(mock_wait_for: AsyncMock): - device = await _get_bart_robot() +async def test_set_waits_for_both_timeouts( + mock_wait_for: AsyncMock, bart_robot: BartRobot +): + device = bart_robot _set_fast_robot_timeouts(device) device._load_pin_and_puck = MagicMock() # type: ignore await device.set(SampleLocation(1, 2)) @@ -265,9 +334,9 @@ async def test_set_waits_for_both_timeouts(mock_wait_for: AsyncMock): "sample_location", [SAMPLE_LOCATION_EMPTY, SampleLocation(1, 2)] ) async def test_moving_the_robot_will_reset_controller_error_and_throw_if_error_not_cleared( - sample_location: SampleLocation, + sample_location: SampleLocation, bart_robot: BartRobot ): - device = await _get_bart_robot() + device = bart_robot _set_fast_robot_timeouts(device) set_mock_value( device.controller_error.code, ControllerErrorCode.LIGHT_CURTAIN_TRIPPED.value @@ -299,7 +368,7 @@ async def test_robot_load_resets_controller_error_and_succeeds_if_error_cleared( device.controller_error.code, ControllerErrorCode.LIGHT_CURTAIN_TRIPPED.value ) - await device.set(SampleLocation(1, 2)) + await device.set(SampleLocation(15, 10)) get_mock(device).assert_has_calls( [ @@ -320,7 +389,7 @@ async def test_robot_load_resets_prog_error_and_succeeds_if_error_cleared( device.prog_error.code, ProgErrorCode.SAMPLE_POSITION_NOT_READY.value ) - await device.set(SampleLocation(1, 2)) + await device.set(SampleLocation(15, 10)) get_mock(device).assert_has_calls( [