11import asyncio
2- from asyncio import FIRST_COMPLETED , CancelledError , Task , wait_for
2+ from asyncio import ALL_COMPLETED , FIRST_COMPLETED , CancelledError , Task , wait_for
33from dataclasses import dataclass
44from enum import IntEnum
55
@@ -113,6 +113,8 @@ class BartRobot(StandardReadable, Movable[SampleLocation]):
113113 CRYO_MODE_WARM = 0.0
114114 CRYO_MODE_CRYO = 1.0
115115
116+ NO_PIN_LOCATION = SampleLocation (0 , 0 )
117+
116118 def __init__ (self , prefix : str , name : str = "" ) -> None :
117119 with self .add_children_as_readables (StandardReadableFormat .HINTED_SIGNAL ):
118120 self .barcode = epics_signal_r (str , prefix + "BARCODE" )
@@ -153,10 +155,19 @@ def __init__(self, prefix: str, name: str = "") -> None:
153155 )
154156 super ().__init__ (name = name )
155157
156- async def beamline_status_or_error (self , expected_state : BeamlineStatus ):
158+ async def beamline_status_or_error (
159+ self ,
160+ expected_state : BeamlineStatus ,
161+ sample_location : SampleLocation | None = None ,
162+ ):
157163 """This co-routine will finish when either the beamline reaches the specified
158164 state or the robot gives an error (whichever happens first). In the case where
159165 there is an error a RobotLoadError error is raised.
166+
167+ Args:
168+ expected_state (BeamlineStatus): The beamline state to wait for
169+ sample_location (SampleLocation): The loaded puck and pin to wait for, or None to not
170+ wait.
160171 """
161172
162173 async def raise_if_prog_error ():
@@ -182,14 +193,32 @@ async def raise_if_ctl_error():
182193 async def wait_for_expected_state ():
183194 await wait_for_value (self .beamline_disabled , expected_state .value , None )
184195
196+ async def wait_for_puck (_sample_location : SampleLocation ):
197+ await wait_for_value (self .current_puck , _sample_location .puck , None )
198+
199+ async def wait_for_pin (_sample_location : SampleLocation ):
200+ await wait_for_value (self .current_pin , _sample_location .pin , None )
201+
202+ check_for_prog_error_task = Task (raise_if_prog_error ())
203+ check_for_ctl_error_task = Task (raise_if_ctl_error ())
185204 tasks = [
186- (Task (raise_if_prog_error ())),
187- (Task (raise_if_ctl_error ())),
188205 (Task (wait_for_expected_state ())),
189206 ]
207+ if sample_location :
208+ tasks += [
209+ Task (wait_for_puck (sample_location )),
210+ Task (wait_for_pin (sample_location )),
211+ ]
212+ check_for_completion_conditions = asyncio .create_task (
213+ asyncio .wait (tasks , return_when = ALL_COMPLETED )
214+ )
190215 try :
191216 finished , unfinished = await asyncio .wait (
192- tasks ,
217+ [
218+ check_for_prog_error_task ,
219+ check_for_ctl_error_task ,
220+ check_for_completion_conditions ,
221+ ],
193222 return_when = FIRST_COMPLETED ,
194223 )
195224 for task in unfinished :
@@ -201,7 +230,8 @@ async def wait_for_expected_state():
201230 # in the current task, when it propagates to here we should cancel all pending tasks before bubbling up
202231 for task in tasks :
203232 task .cancel ()
204-
233+ check_for_prog_error_task .cancel ()
234+ check_for_ctl_error_task .cancel ()
205235 raise
206236
207237 async def _check_errors_and_clear_if_retryable (self ):
@@ -243,15 +273,17 @@ async def _load_pin_and_puck(self, sample_location: SampleLocation):
243273 set_and_wait_for_value (self .next_pin , sample_location .pin ),
244274 )
245275 await self .load .trigger ()
246- await self ._wait_for_beamline_enabled_after_load_or_unload ()
276+ await self ._wait_for_beamline_enabled_after_load_or_unload (sample_location )
247277
248- async def _wait_for_beamline_enabled_after_load_or_unload (self ):
278+ async def _wait_for_beamline_enabled_after_load_or_unload (
279+ self , sample_location : SampleLocation
280+ ):
249281 if await self .beamline_disabled .get_value () == BeamlineStatus .ENABLED .value :
250282 LOGGER .info (WAIT_FOR_BEAMLINE_DISABLE_MSG )
251283 await self .beamline_status_or_error (BeamlineStatus .DISABLED )
252284
253285 LOGGER .info (WAIT_FOR_BEAMLINE_ENABLE_MSG )
254- await self .beamline_status_or_error (BeamlineStatus .ENABLED )
286+ await self .beamline_status_or_error (BeamlineStatus .ENABLED , sample_location )
255287
256288 @AsyncStatus .wrap
257289 async def set (self , value : SampleLocation ):
@@ -275,7 +307,9 @@ async def set(self, value: SampleLocation):
275307 else :
276308 await self .unload .trigger (timeout = self .LOAD_TIMEOUT )
277309 await wait_for (
278- self ._wait_for_beamline_enabled_after_load_or_unload (),
310+ self ._wait_for_beamline_enabled_after_load_or_unload (
311+ self .NO_PIN_LOCATION
312+ ),
279313 timeout = self .LOAD_TIMEOUT + self .NOT_BUSY_TIMEOUT ,
280314 )
281315 except TimeoutError as e :
0 commit comments