11import asyncio
2- from asyncio import FIRST_COMPLETED , CancelledError , Task , wait_for
2+ from asyncio import ALL_COMPLETED , FIRST_COMPLETED , CancelledError , Task , wait_for
33from dataclasses import dataclass
44from enum import IntEnum
55
@@ -92,6 +92,8 @@ class BartRobot(StandardReadable, Movable[SampleLocation]):
9292 CRYO_MODE_WARM = 0.0
9393 CRYO_MODE_CRYO = 1.0
9494
95+ NO_PIN_LOCATION = SampleLocation (0 , 0 )
96+
9597 def __init__ (self , prefix : str , name : str = "" ) -> None :
9698 with self .add_children_as_readables (StandardReadableFormat .HINTED_SIGNAL ):
9799 self .barcode = epics_signal_r (str , prefix + "BARCODE" )
@@ -132,10 +134,17 @@ def __init__(self, prefix: str, name: str = "") -> None:
132134 )
133135 super ().__init__ (name = name )
134136
135- async def beamline_status_or_error (self , expected_state : BeamlineStatus ):
137+ async def beamline_status_or_error (
138+ self , expected_state : BeamlineStatus , sample_location : SampleLocation = None
139+ ):
136140 """This co-routine will finish when either the beamline reaches the specified
137141 state or the robot gives an error (whichever happens first). In the case where
138142 there is an error a RobotLoadError error is raised.
143+
144+ Args:
145+ expected_state (BeamlineStatus): The beamline state to wait for
146+ sample_location (SampleLocation): The loaded puck and pin to wait for, or None to not
147+ wait.
139148 """
140149
141150 async def raise_if_error ():
@@ -149,13 +158,24 @@ async def raise_if_error():
149158 async def wait_for_expected_state ():
150159 await wait_for_value (self .beamline_disabled , expected_state .value , None )
151160
161+ async def wait_for_puck ():
162+ await wait_for_value (self .current_puck , sample_location .puck , None )
163+
164+ async def wait_for_pin ():
165+ await wait_for_value (self .current_pin , sample_location .pin , None )
166+
167+ check_for_error_task = Task (raise_if_error ())
152168 tasks = [
153- (Task (raise_if_error ())),
154169 (Task (wait_for_expected_state ())),
155170 ]
171+ if sample_location :
172+ tasks += [Task (wait_for_puck ()), Task (wait_for_pin ())]
173+ check_for_completion_conditions = asyncio .create_task (
174+ asyncio .wait (tasks , return_when = ALL_COMPLETED )
175+ )
156176 try :
157177 finished , unfinished = await asyncio .wait (
158- tasks ,
178+ [ check_for_error_task , check_for_completion_conditions ] ,
159179 return_when = FIRST_COMPLETED ,
160180 )
161181 for task in unfinished :
@@ -167,7 +187,7 @@ async def wait_for_expected_state():
167187 # in the current task, when it propagates to here we should cancel all pending tasks before bubbling up
168188 for task in tasks :
169189 task .cancel ()
170-
190+ check_for_error_task . cancel ()
171191 raise
172192
173193 async def _load_pin_and_puck (self , sample_location : SampleLocation ):
@@ -187,15 +207,17 @@ async def _load_pin_and_puck(self, sample_location: SampleLocation):
187207 set_and_wait_for_value (self .next_pin , sample_location .pin ),
188208 )
189209 await self .load .trigger ()
190- await self ._wait_for_beamline_enabled_after_load_or_unload ()
210+ await self ._wait_for_beamline_enabled_after_load_or_unload (sample_location )
191211
192- async def _wait_for_beamline_enabled_after_load_or_unload (self ):
212+ async def _wait_for_beamline_enabled_after_load_or_unload (
213+ self , sample_location : SampleLocation = None
214+ ):
193215 if await self .beamline_disabled .get_value () == BeamlineStatus .ENABLED .value :
194216 LOGGER .info (WAIT_FOR_BEAMLINE_DISABLE_MSG )
195217 await self .beamline_status_or_error (BeamlineStatus .DISABLED )
196218
197219 LOGGER .info (WAIT_FOR_BEAMLINE_ENABLE_MSG )
198- await self .beamline_status_or_error (BeamlineStatus .ENABLED )
220+ await self .beamline_status_or_error (BeamlineStatus .ENABLED , sample_location )
199221
200222 @AsyncStatus .wrap
201223 async def set (self , value : SampleLocation ):
@@ -218,7 +240,9 @@ async def set(self, value: SampleLocation):
218240 else :
219241 await self .unload .trigger (timeout = self .LOAD_TIMEOUT )
220242 await wait_for (
221- self ._wait_for_beamline_enabled_after_load_or_unload (),
243+ self ._wait_for_beamline_enabled_after_load_or_unload (
244+ self .NO_PIN_LOCATION
245+ ),
222246 timeout = self .LOAD_TIMEOUT + self .NOT_BUSY_TIMEOUT ,
223247 )
224248 except TimeoutError as e :
0 commit comments