Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions infscale/configs/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,16 @@ def get_worlds_to_configure(
"""Compare two specs and return new and updated worlds."""
helper = ServeConfigHelper()

curr_worlds = helper._get_worlds(curr_spec)
new_worlds = helper._get_worlds(new_spec)
new_world_names = set(new_worlds.keys())

# if current spec is not available,
# return worlds from the new spec.
if curr_spec is None:
return new_world_names

curr_worlds = helper._get_worlds(curr_spec)
curr_world_names = set(curr_worlds.keys())
new_world_names = set(new_worlds.keys())

deploy_worlds = new_world_names - curr_world_names

Expand Down
99 changes: 54 additions & 45 deletions infscale/execution/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,38 +30,43 @@ class ConfigManager:
def __init__(self):
"""Initialize config manager instance."""
self._loop = asyncio.get_event_loop()
self._task: asyncio.Task | None = None
self._event = asyncio.Event()
# semaphore event for back-to-back configs
self._config_event = asyncio.Event()
self._config_event.set()
self._world_tasks: dict[str, asyncio.Task] = {}
self._spec: ServeConfig = None
self._event.set()
self._curr_worlds_to_configure: set[str] = set()
self._cancel_cur_cfg = False
self._world_infos: dict[str, WorldInfo] = {}
self.worlds_to_cancel = set()

def handle_new_spec(self, spec: ServeConfig) -> None:
async def handle_new_spec(self, spec: ServeConfig) -> None:
"""Handle new spec."""
self._cancel_cur_cfg = self._should_cancel_current(spec)
new_worlds_to_configure = ServeConfig.get_worlds_to_configure(self._spec, spec)
self.worlds_to_cancel = new_worlds_to_configure & self._curr_worlds_to_configure
self._spec = spec

def _should_cancel_current(self, spec: ServeConfig) -> bool:
"""Decide if current configuration should be cancelled."""
if self._spec is None:
return False
if len(self.worlds_to_cancel):
await self._cancel_world_configuration(self.worlds_to_cancel)

new_worlds_to_configure = ServeConfig.get_worlds_to_configure(self._spec, spec)
# wait for current configuration to finish
await self._config_event.wait()

# block further configuration until the current config is processed:
# a new spec that arrives in the meantime must wait while
# the current one is still being configured.
self._config_event.clear()

# cancel if the new config affects worlds currently being configured
# TODO: if there's an overlap between new worlds and curr worlds, we cancel the
# current configuration. This needs to be fixed to cancel only the worlds that
# are affected (e.g., new_worlds & curr_worlds)
return not new_worlds_to_configure.isdisjoint(self._curr_worlds_to_configure)
def reset_state(self) -> None:
"""Reset any state that is kept between configs."""
self._curr_worlds_to_configure = set()
self.worlds_to_cancel = set()

def set_worlds_to_configure(self, world_names: set[str]) -> None:
"""Set the world names currently being configured."""
self._curr_worlds_to_configure = world_names
def unblock_next_config(self) -> None:
"""Set task event and unblock next config process."""
self._config_event.set()

def set_world_infos(self, worlds: list[WorldInfo]) -> None:
"""Set new world infos."""
def update_world_infos(self, worlds: list[WorldInfo]) -> None:
"""Update world infos."""
for world_info in worlds:
self._world_infos[world_info.name] = world_info

Expand All @@ -87,25 +92,29 @@ def get_worlds_to_add_and_remove(self) -> tuple[list[WorldInfo], list[WorldInfo]
worlds_to_add = [new_world_infos[name] for name in new - cur]
worlds_to_remove = [new_world_infos[name] for name in cur - new]

return worlds_to_add, worlds_to_remove

async def schedule(self, coro_factory: Callable[[], Awaitable[None]]):
"""Cancel any in-progress configure and schedule a new one."""
# wait for current to finish if we do not want to cancel
if not self._cancel_cur_cfg:
await self._event.wait()
self._curr_worlds_to_configure = new - cur

# cancel current if running
if self._task and not self._task.done():
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
return worlds_to_add, worlds_to_remove

# block again for new run
self._event.clear()
self._task = self._loop.create_task(self._run(coro_factory))
async def _cancel_world_configuration(self, world_names: set[str]):
"""Cancel only worlds that are impacted by new spec."""
coroutines = [self._cancel_world(w) for w in world_names]
await asyncio.gather(*coroutines, return_exceptions=True)

def schedule_world_cfg(
self, world_info: WorldInfo, coro_factory: Callable[[], Awaitable[None]]
):
"""Schedule configuration for a single world."""
task = self._loop.create_task(self._run_world(world_info, coro_factory))
self._world_tasks[world_info.name] = task
return task

async def _cancel_world(self, world_name: str):
"""Cancel an in-progress world config task."""
task = self._world_tasks.pop(world_name, None)
if task and not task.done():
task.cancel()
raise asyncio.CancelledError

def _build_world_infos(self) -> dict[str, WorldInfo]:
world_infos: dict[str, WorldInfo] = {}
Expand Down Expand Up @@ -161,13 +170,13 @@ def _build_world_infos(self) -> dict[str, WorldInfo]:

return world_infos

async def _run(self, coro_factory: Callable[[], Awaitable[None]]):
"""Run coroutine factory."""
async def _run_world(
self, world_info: WorldInfo, coro_factory: Callable[[], Awaitable[None]]
):
"""Run and cleanup world configuration."""
try:
await coro_factory()
await coro_factory(world_info)
except asyncio.CancelledError:
pass
raise
finally:
# reset class attributes and events
self._event.set()
self._curr_worlds_to_configure = set()
self._world_tasks.pop(world_info.name, None)
25 changes: 23 additions & 2 deletions infscale/execution/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,31 @@ async def setup(self) -> None:
if self.rank == 0:
self._server_task = asyncio.create_task(self._setup_server(setup_done))
else:
_ = asyncio.create_task(self._setup_client(setup_done))
client_task = asyncio.create_task(self._setup_client(setup_done))

# wait until setting up either server or client is done
await setup_done.wait()
try:
await setup_done.wait()
except asyncio.CancelledError as e:
# both _setup_server and _setup_client are spawned as separate tasks, and
# the setup itself is a task, so parent task cancellation has to be handled
# on the awaited line: cancellation only propagates through awaited calls.
# here, await setup_done.wait() is the propagation point from parent task to child tasks,
# so we need to cancel the child tasks whenever CancelledError is received.
if self._server_task and not self._server_task.done():
self._server_task.cancel()
try:
await self._server_task
except asyncio.CancelledError:
pass

if client_task and not client_task.done():
client_task.cancel()
try:
await client_task
except asyncio.CancelledError:
pass
raise

def cleanup(self) -> None:
if self._server_task is not None:
Expand Down
44 changes: 34 additions & 10 deletions infscale/execution/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ async def _configure_multiworld(self, world_info: WorldInfo) -> None:
port=port,
device=self.device,
)
except asyncio.CancelledError:
logger.warning(f"multiworld configuration cancelled for {world_info.name}")
except Exception as e:
logger.error(f"failed to initialize a multiworld {name}: {e}")
condition = self._status != WorkerStatus.UPDATING
Expand All @@ -130,12 +132,20 @@ def _set_n_send_worker_status(self, status: WorkerStatus) -> None:
self.wcomm.send(msg)

async def _configure_control_channel(self, world_info: WorldInfo) -> None:
await world_info.channel.setup()
try:
await world_info.channel.setup()

await world_info.channel.wait_readiness()
await world_info.channel.wait_readiness()
except asyncio.CancelledError:
logger.warning(f"channel configuration cancelled for {world_info}")

def _reset_multiworld(self, world_info: WorldInfo) -> None:
self.world_manager.remove_world(world_info.multiworld_name)
try:
self.world_manager.remove_world(world_info.multiworld_name)
except ValueError as e:
logger.warning(f"failed to reset {world_info.multiworld_name}: {e}")
return

logger.info(f"remove world {world_info.multiworld_name} from multiworld")

def _reset_control_channel(self, world_info: WorldInfo) -> None:
Expand Down Expand Up @@ -186,7 +196,12 @@ async def _configure(self) -> None:
tasks = []
# 1. set up control channel
for world_info in worlds_to_add:
task = self._configure_control_channel(world_info)
if world_info.name in self.config_manager.worlds_to_cancel:
continue

task = self.config_manager.schedule_world_cfg(
world_info, self._configure_control_channel
)
tasks.append(task)

# TODO: this doesn't handle partial success
Expand All @@ -196,22 +211,28 @@ async def _configure(self) -> None:
tasks = []
# 2. set up multiworld
for world_info in worlds_to_add:
task = self._configure_multiworld(world_info)
if world_info.name in self.config_manager.worlds_to_cancel:
continue

task = self.config_manager.schedule_world_cfg(
world_info, self._configure_multiworld
)
tasks.append(task)

# TODO: this doesn't handle partial success
# a mechanism to handle a failure is left as a todo
await asyncio.gather(*tasks)

# update world_info for added worlds
self.config_manager.set_world_infos(worlds_to_add)
self.config_manager.update_world_infos(worlds_to_add)

# configure router with worlds to add and remove
await self.router.configure(
self.spec,
self.device,
worlds_to_add,
worlds_to_remove,
self.config_manager.worlds_to_cancel,
)

# handle unnecessary world
Expand All @@ -226,6 +247,9 @@ async def _configure(self) -> None:

worker_status = WorkerStatus.RUNNING if is_first_run else WorkerStatus.UPDATED

# config is done, do cleanup in config runner
self.config_manager.reset_state()
self.config_manager.unblock_next_config()
self._set_n_send_worker_status(worker_status)

self.cfg_event.set()
Expand Down Expand Up @@ -425,16 +449,16 @@ async def _handle_config(self, spec: ServeConfig) -> None:
if spec is None:
return

self.config_manager.handle_new_spec(spec)

self._configure_variables(spec)

self._inspector.configure(self.spec)

self._initialize_once()

# (re)configure the pipeline
await self.config_manager.schedule(self._configure)
await self.config_manager.handle_new_spec(spec)
# run configure as a separate task so that receiving a new config is not blocked;
# the new config is processed once the current configuration has finished
self._configure_task = asyncio.create_task(self._configure())

def _configure_variables(self, spec: ServeConfig) -> None:
"""Set variables that need to be updated."""
Expand Down
4 changes: 4 additions & 0 deletions infscale/execution/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ async def configure(
device=torch.device("cpu"),
worlds_to_add: list[WorldInfo] = [],
worlds_to_remove: list[WorldInfo] = [],
worlds_to_cancel: set[str] = set(),
) -> None:
"""(Re)configure router."""
self._is_server = spec.is_server
Expand All @@ -131,6 +132,9 @@ async def configure(
self._fwder.set_stickiness(sticky)

for world_info in worlds_to_add:
if world_info.name in worlds_to_cancel:
continue

cancellable = asyncio.Event()
if world_info.me == 0: # I am a receiver from other
task = asyncio.create_task(self._recv(world_info, cancellable))
Expand Down