From 0bf03887c5e35105de43d62f42145f048e36fa00 Mon Sep 17 00:00:00 2001
From: Rares Gaia
Date: Thu, 6 Nov 2025 14:57:44 +0200
Subject: [PATCH] feat: handle individual world configuration

A single task for the entire configuration could not handle partial
world updates, so configuration is now run as one task per world. The
config manager schedules and cancels these tasks based on the diff
between the old and the new config, and orchestrates config processing
around the worlds it needs to cancel.

Error handling for task cancellation was updated so that cancellation
propagates properly to the parent task.

With this, each world's configuration can be tracked individually,
making it easier to implement retry on failure in the future.
---
 infscale/configs/job.py              |  9 ++-
 infscale/execution/config_manager.py | 99 +++++++++++++++-------------
 infscale/execution/control.py        | 25 ++++++-
 infscale/execution/pipeline.py       | 44 ++++++++++---
 infscale/execution/router.py         |  4 ++
 5 files changed, 122 insertions(+), 59 deletions(-)

diff --git a/infscale/configs/job.py b/infscale/configs/job.py
index d7d033e8..bb0cc8dd 100644
--- a/infscale/configs/job.py
+++ b/infscale/configs/job.py
@@ -157,11 +157,16 @@ def get_worlds_to_configure(
         """Compare two specs and return new and updated worlds."""
         helper = ServeConfigHelper()
 
-        curr_worlds = helper._get_worlds(curr_spec)
         new_worlds = helper._get_worlds(new_spec)
+        new_world_names = set(new_worlds.keys())
+
+        # if the current spec is not available,
+        # return all worlds from the new spec.
+        if curr_spec is None:
+            return new_world_names
 
+        curr_worlds = helper._get_worlds(curr_spec)
         curr_world_names = set(curr_worlds.keys())
-        new_world_names = set(new_worlds.keys())
 
         deploy_worlds = new_world_names - curr_world_names
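Note on the job.py change above: with the early return, a first-time config (no
current spec) deploys every world in the new spec, and only later configs pay
for the diff. The diff itself is plain set arithmetic; below is a standalone
sketch of the same logic, with plain dicts standing in for specs and the
per-world comparison assumed (the real function derives worlds via
ServeConfigHelper):

    # illustrative only: dict specs and the equality check are stand-ins
    def worlds_to_configure(curr: dict | None, new: dict) -> set[str]:
        new_names = set(new)
        if curr is None:
            return new_names  # first config: deploy everything
        deploy = new_names - set(curr)  # brand-new worlds
        updated = {n for n in new_names & set(curr) if new[n] != curr[n]}
        return deploy | updated

    assert worlds_to_configure(None, {"w0": 1, "w1": 2}) == {"w0", "w1"}
    assert worlds_to_configure({"w0": 1, "w1": 2}, {"w1": 3, "w2": 4}) == {"w1", "w2"}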
diff --git a/infscale/execution/config_manager.py b/infscale/execution/config_manager.py
index db6e805b..320e7757 100644
--- a/infscale/execution/config_manager.py
+++ b/infscale/execution/config_manager.py
@@ -30,38 +30,43 @@ class ConfigManager:
     def __init__(self):
         """Initialize config manager instance."""
         self._loop = asyncio.get_event_loop()
-        self._task: asyncio.Task | None = None
-        self._event = asyncio.Event()
+        # gating event that serializes back-to-back configs
+        self._config_event = asyncio.Event()
+        self._config_event.set()
+        self._world_tasks: dict[str, asyncio.Task] = {}
         self._spec: ServeConfig = None
-        self._event.set()
         self._curr_worlds_to_configure: set[str] = set()
-        self._cancel_cur_cfg = False
         self._world_infos: dict[str, WorldInfo] = {}
+        self.worlds_to_cancel: set[str] = set()
 
-    def handle_new_spec(self, spec: ServeConfig) -> None:
+    async def handle_new_spec(self, spec: ServeConfig) -> None:
         """Handle new spec."""
-        self._cancel_cur_cfg = self._should_cancel_current(spec)
+        new_worlds_to_configure = ServeConfig.get_worlds_to_configure(self._spec, spec)
+        self.worlds_to_cancel = new_worlds_to_configure & self._curr_worlds_to_configure
         self._spec = spec
 
-    def _should_cancel_current(self, spec: ServeConfig) -> bool:
-        """Decide if current configuration should be cancelled."""
-        if self._spec is None:
-            return False
+        if self.worlds_to_cancel:
+            await self._cancel_world_configuration(self.worlds_to_cancel)
 
-        new_worlds_to_configure = ServeConfig.get_worlds_to_configure(self._spec, spec)
+        # wait for the current configuration to finish
+        await self._config_event.wait()
 
-        # cancel if the new config affects worlds currently being configured
-        # TODO: if there's a overlap between new worlds and curr worlds we cancel
-        # current configuration. This needs to be fixed, to cancel only the worlds that
-        # are affected (eg new_worlds & curr_worlds)
-        return not new_worlds_to_configure.isdisjoint(self._curr_worlds_to_configure)
+        # block the next config: a spec that arrives while this one is
+        # being processed must wait until the current run is configured
+        self._config_event.clear()
 
-    def set_worlds_to_configure(self, world_names: set[str]) -> None:
-        """Set the world names currently being configured."""
-        self._curr_worlds_to_configure = world_names
+    def reset_state(self) -> None:
+        """Reset any state that is kept between configs."""
+        self._curr_worlds_to_configure = set()
+        self.worlds_to_cancel = set()
 
-    def set_world_infos(self, worlds: list[WorldInfo]) -> None:
-        """Set new world infos."""
+    def unblock_next_config(self) -> None:
+        """Set the config event to unblock the next config process."""
+        self._config_event.set()
+
+    def update_world_infos(self, worlds: list[WorldInfo]) -> None:
+        """Update world infos."""
         for world_info in worlds:
             self._world_infos[world_info.name] = world_info
 
@@ -87,25 +92,29 @@ def get_worlds_to_add_and_remove(self) -> tuple[list[WorldInfo], list[WorldInfo]]:
         worlds_to_add = [new_world_infos[name] for name in new - cur]
         worlds_to_remove = [new_world_infos[name] for name in cur - new]
 
-        return worlds_to_add, worlds_to_remove
-
-    async def schedule(self, coro_factory: Callable[[], Awaitable[None]]):
-        """Cancel any in-progress configure and schedule a new one."""
-        # wait for current to finish if we do not want to cancel
-        if not self._cancel_cur_cfg:
-            await self._event.wait()
+        self._curr_worlds_to_configure = new - cur
 
-        # cancel current if running
-        if self._task and not self._task.done():
-            self._task.cancel()
-            try:
-                await self._task
-            except asyncio.CancelledError:
-                pass
+        return worlds_to_add, worlds_to_remove
 
-        # block again for new run
-        self._event.clear()
-        self._task = self._loop.create_task(self._run(coro_factory))
+    async def _cancel_world_configuration(self, world_names: set[str]):
+        """Cancel only the worlds that are impacted by the new spec."""
+        coroutines = [self._cancel_world(w) for w in world_names]
+        await asyncio.gather(*coroutines, return_exceptions=True)
+
+    def schedule_world_cfg(
+        self, world_info: WorldInfo, coro_factory: Callable[[WorldInfo], Awaitable[None]]
+    ):
+        """Schedule configuration for a single world."""
+        task = self._loop.create_task(self._run_world(world_info, coro_factory))
+        self._world_tasks[world_info.name] = task
+        return task
+
+    async def _cancel_world(self, world_name: str):
+        """Cancel an in-progress world config task and wait for it to exit."""
+        task = self._world_tasks.pop(world_name, None)
+        if task and not task.done():
+            task.cancel()
+            try:
+                await task
+            except asyncio.CancelledError:
+                pass
 
     def _build_world_infos(self) -> dict[str, WorldInfo]:
         world_infos: dict[str, WorldInfo] = {}
@@ -161,13 +170,13 @@ def _build_world_infos(self) -> dict[str, WorldInfo]:
 
         return world_infos
 
-    async def _run(self, coro_factory: Callable[[], Awaitable[None]]):
-        """Run coroutine factory."""
+    async def _run_world(
+        self, world_info: WorldInfo, coro_factory: Callable[[WorldInfo], Awaitable[None]]
+    ):
+        """Run and clean up a world configuration."""
         try:
-            await coro_factory()
+            await coro_factory(world_info)
         except asyncio.CancelledError:
-            pass
+            raise
         finally:
-            # reset class attributes and events
-            self._event.set()
-            self._curr_worlds_to_configure = set()
+            self._world_tasks.pop(world_info.name, None)
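The heart of the config_manager.py change is the name-to-task map: schedule
wraps each world's work so its entry is cleaned up on exit, and cancel pops the
entry, cancels the task, and waits for it to unwind. A minimal, self-contained
sketch of that pattern (class and names are illustrative, not repo code):

    import asyncio

    class WorldTasks:
        """Toy version of ConfigManager's per-world task bookkeeping."""

        def __init__(self) -> None:
            self._tasks: dict[str, asyncio.Task] = {}

        def schedule(self, name: str, coro) -> asyncio.Task:
            # wrap the work so the map entry is dropped on finish or cancel
            task = asyncio.create_task(self._run(name, coro))
            self._tasks[name] = task
            return task

        async def _run(self, name: str, coro) -> None:
            try:
                await coro
            finally:
                self._tasks.pop(name, None)

        async def cancel(self, name: str) -> None:
            task = self._tasks.pop(name, None)
            if task and not task.done():
                task.cancel()
                try:
                    await task  # wait until the task has fully unwound
                except asyncio.CancelledError:
                    pass

    async def main() -> None:
        worlds = WorldTasks()
        worlds.schedule("w0", asyncio.sleep(10))  # stand-in for world config
        await asyncio.sleep(0)  # let the task start running
        await worlds.cancel("w0")  # returns once w0 is really gone

    asyncio.run(main())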
diff --git a/infscale/execution/control.py b/infscale/execution/control.py
index 1c769f1d..938520b1 100644
--- a/infscale/execution/control.py
+++ b/infscale/execution/control.py
@@ -162,10 +162,31 @@ async def setup(self) -> None:
+        client_task: asyncio.Task | None = None
         if self.rank == 0:
             self._server_task = asyncio.create_task(self._setup_server(setup_done))
         else:
-            _ = asyncio.create_task(self._setup_client(setup_done))
+            client_task = asyncio.create_task(self._setup_client(setup_done))
 
         # wait until setting up either server or client is done
-        await setup_done.wait()
+        try:
+            await setup_done.wait()
+        except asyncio.CancelledError:
+            # _setup_server and _setup_client run as separate tasks, and setup()
+            # itself runs as a task. Cancellation only propagates through awaited
+            # calls, so `await setup_done.wait()` is the point where a parent
+            # cancellation reaches this coroutine; cancel the child tasks here
+            # before re-raising.
+            if self._server_task and not self._server_task.done():
+                self._server_task.cancel()
+                try:
+                    await self._server_task
+                except asyncio.CancelledError:
+                    pass
+
+            if client_task and not client_task.done():
+                client_task.cancel()
+                try:
+                    await client_task
+                except asyncio.CancelledError:
+                    pass
+            raise
 
     def cleanup(self) -> None:
         if self._server_task is not None:
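The try/except added to setup() encodes a general asyncio rule: cancelling a
task delivers CancelledError at the task's current await point, and tasks it
spawned are not cancelled automatically. A runnable sketch of exactly that
shape (generic names, not the control-channel code):

    import asyncio

    async def child(done: asyncio.Event) -> None:
        await asyncio.sleep(10)  # stand-in for server/client setup
        done.set()

    async def parent() -> None:
        done = asyncio.Event()
        child_task = asyncio.create_task(child(done))
        try:
            await done.wait()  # a cancel of parent() is delivered here
        except asyncio.CancelledError:
            child_task.cancel()  # children must be cancelled explicitly
            try:
                await child_task
            except asyncio.CancelledError:
                pass
            raise  # keep propagating to our own parent

    async def main() -> None:
        task = asyncio.create_task(parent())
        await asyncio.sleep(0.1)
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            print("parent cancelled, child cleaned up first")

    asyncio.run(main())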
diff --git a/infscale/execution/pipeline.py b/infscale/execution/pipeline.py
index 66663ebf..33c88960 100644
--- a/infscale/execution/pipeline.py
+++ b/infscale/execution/pipeline.py
@@ -104,6 +104,8 @@ async def _configure_multiworld(self, world_info: WorldInfo) -> None:
                 port=port,
                 device=self.device,
             )
+        except asyncio.CancelledError:
+            logger.warning(f"multiworld configuration cancelled for {world_info.name}")
+            raise
         except Exception as e:
             logger.error(f"failed to initialize a multiworld {name}: {e}")
             condition = self._status != WorkerStatus.UPDATING
@@ -130,12 +132,20 @@ def _set_n_send_worker_status(self, status: WorkerStatus) -> None:
         self.wcomm.send(msg)
 
     async def _configure_control_channel(self, world_info: WorldInfo) -> None:
-        await world_info.channel.setup()
+        try:
+            await world_info.channel.setup()
 
-        await world_info.channel.wait_readiness()
+            await world_info.channel.wait_readiness()
+        except asyncio.CancelledError:
+            logger.warning(f"channel configuration cancelled for {world_info.name}")
+            raise
 
     def _reset_multiworld(self, world_info: WorldInfo) -> None:
-        self.world_manager.remove_world(world_info.multiworld_name)
+        try:
+            self.world_manager.remove_world(world_info.multiworld_name)
+        except ValueError as e:
+            logger.warning(f"failed to reset {world_info.multiworld_name}: {e}")
+            return
+
         logger.info(f"remove world {world_info.multiworld_name} from multiworld")
 
     def _reset_control_channel(self, world_info: WorldInfo) -> None:
@@ -186,7 +196,12 @@ async def _configure(self) -> None:
         tasks = []
         # 1. set up control channel
         for world_info in worlds_to_add:
-            task = self._configure_control_channel(world_info)
+            if world_info.name in self.config_manager.worlds_to_cancel:
+                continue
+
+            task = self.config_manager.schedule_world_cfg(
+                world_info, self._configure_control_channel
+            )
             tasks.append(task)
 
         # TODO: this doesn't handle partial success
@@ -196,7 +211,12 @@ async def _configure(self) -> None:
         tasks = []
         # 2. set up multiworld
         for world_info in worlds_to_add:
-            task = self._configure_multiworld(world_info)
+            if world_info.name in self.config_manager.worlds_to_cancel:
+                continue
+
+            task = self.config_manager.schedule_world_cfg(
+                world_info, self._configure_multiworld
+            )
             tasks.append(task)
 
         # TODO: this doesn't handle partial success
@@ -204,7 +224,7 @@ async def _configure(self) -> None:
         await asyncio.gather(*tasks)
 
         # update world_info for added worlds
-        self.config_manager.set_world_infos(worlds_to_add)
+        self.config_manager.update_world_infos(worlds_to_add)
 
         # configure router with worlds to add and remove
         await self.router.configure(
@@ -212,6 +232,7 @@ async def _configure(self) -> None:
             self.device,
             worlds_to_add,
             worlds_to_remove,
+            self.config_manager.worlds_to_cancel,
         )
 
         # handle unnecessary world
@@ -226,6 +247,9 @@ async def _configure(self) -> None:
         worker_status = WorkerStatus.RUNNING if is_first_run else WorkerStatus.UPDATED
 
+        # config is done; reset state and unblock the next config
+        self.config_manager.reset_state()
+        self.config_manager.unblock_next_config()
         self._set_n_send_worker_status(worker_status)
 
         self.cfg_event.set()
@@ -425,16 +449,16 @@ async def _handle_config(self, spec: ServeConfig) -> None:
         if spec is None:
             return
 
-        self.config_manager.handle_new_spec(spec)
-
         self._configure_variables(spec)
 
         self._inspector.configure(self.spec)
 
         self._initialize_once()
 
-        # (re)configure the pipeline
-        await self.config_manager.schedule(self._configure)
+        await self.config_manager.handle_new_spec(spec)
+
+        # run _configure as a separate task so that receiving the next
+        # config is not blocked while the current configuration finishes
+        self._configure_task = asyncio.create_task(self._configure())
 
     def _configure_variables(self, spec: ServeConfig) -> None:
         """Set variables that need to be updated."""
diff --git a/infscale/execution/router.py b/infscale/execution/router.py
index 5c03a3ef..7c8740ae 100644
--- a/infscale/execution/router.py
+++ b/infscale/execution/router.py
@@ -114,6 +114,7 @@ async def configure(
         device=torch.device("cpu"),
         worlds_to_add: list[WorldInfo] = [],
         worlds_to_remove: list[WorldInfo] = [],
+        worlds_to_cancel: set[str] = set(),
     ) -> None:
         """(Re)configure router."""
         self._is_server = spec.is_server
@@ -131,6 +132,9 @@ async def configure(
         self._fwder.set_stickiness(sticky)
 
         for world_info in worlds_to_add:
+            if world_info.name in worlds_to_cancel:
+                continue
+
             cancellable = asyncio.Event()
             if world_info.me == 0:  # I am a receiver from other
                 task = asyncio.create_task(self._recv(world_info, cancellable))
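Taken together, _handle_config and _configure form an event-gated pipeline: new
specs can arrive at any time, but _config_event ensures only one configuration
runs at once. A compressed sketch of the gate (wait/clear and set are split
across handle_new_spec and unblock_next_config in the real code, and specs
arrive there one at a time, so the wait/clear pair is not racy):

    import asyncio

    gate = asyncio.Event()
    gate.set()  # the first config may proceed immediately

    async def handle(spec: str) -> None:
        await gate.wait()  # wait for the previous config to finish
        gate.clear()  # block any spec that arrives after us
        try:
            print(f"configuring {spec}")
            await asyncio.sleep(0.1)  # stand-in for per-world config tasks
        finally:
            gate.set()  # unblock the next queued config

    async def main() -> None:
        await asyncio.gather(handle("spec-1"), handle("spec-2"))

    asyncio.run(main())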