diff --git a/adaptive/runner.py b/adaptive/runner.py index 9a8ff0fd1..9368e9dd6 100644 --- a/adaptive/runner.py +++ b/adaptive/runner.py @@ -259,6 +259,11 @@ def _process_futures(self, done_futs): try: y = fut.result() t = time.time() - fut.start_time # total execution time + except asyncio.CancelledError: + # Cleanup + self._to_retry.pop(pid, None) + self._tracebacks.pop(pid, None) + self._id_to_point.pop(pid, None) except Exception as e: self._tracebacks[pid] = traceback.format_exc() self._to_retry[pid] = self._to_retry.get(pid, 0) + 1 @@ -649,6 +654,7 @@ def __init__( self.task = self.ioloop.create_task(self._run()) self.saving_task = None + self.callbacks = [] if in_ipynb() and not self.ioloop.is_running(): warnings.warn( "The runner has been scheduled, but the asyncio " @@ -753,6 +759,51 @@ def elapsed_time(self): end_time = time.time() return end_time - self.start_time + def cancel_point( + self, future: asyncio.Future | None = None, point: Any | None = None + ): + """Cancel a future or point that is currently being evaluated. + + Either the ``future`` or the ``point`` must be provided. + + Parameters + ---------- + future : asyncio.Future + The future that is currently being evaluated. + point + The point that should be cancelled. + """ + if point is None and future is None: + raise ValueError("Either point or future must be given") + if future is None: + future = next(fut for fut, p in self.pending_points if p == point) + future.cancel() + + def start_periodic_callback( + self, + method: Callable[[AsyncRunner], None], + interval: int = 30, + ): + """Start a periodic callback that calls the given method on the runner. + + Parameters + ---------- + method : callable + The method to call periodically. + interval : int + The interval in seconds between the calls. + """ + + async def _callback(): + while self.status() == "running": + method(self) + await asyncio.sleep(interval) + method(self) # one last time + + task = self.ioloop.create_task(_callback()) + self.callbacks.append(task) + return task + def start_periodic_saving( self, save_kwargs: dict[str, Any] | None = None, @@ -781,6 +832,8 @@ def start_periodic_saving( ... save_kwargs=dict(fname='data/test.pickle'), ... interval=600) """ + if self.saving_task is not None: + raise RuntimeError("Already saving.") def default_save(learner): learner.save(**save_kwargs) @@ -788,15 +841,11 @@ def default_save(learner): if method is None: method = default_save if save_kwargs is None: - raise ValueError("Must provide `save_kwargs` if method=None.") - - async def _saver(): - while self.status() == "running": - method(self.learner) - await asyncio.sleep(interval) - method(self.learner) # one last time + raise ValueError("Must provide `save_kwargs` if `method=None`.") - self.saving_task = self.ioloop.create_task(_saver()) + self.saving_task = self.start_periodic_callback( + lambda r: method(r.learner), interval=interval + ) return self.saving_task