Skip to content

Commit a229ddb

Browse files
committed
feat: handle individual world configuration
Since having a single config task for the entire configuration could not handle partial world updates, we implemented per-world configuration tasks. The config manager takes care of scheduling and canceling these tasks based on the diff between the old and new config. The config manager also orchestrates the processing of configs according to what it needs to cancel. Updated the code with proper error handlers for canceled tasks so that cancellation propagates correctly to the parent task. With this, each world's configuration can be tracked individually, making it easier to implement retry-on-failure in the future.
1 parent aa96e2d commit a229ddb

File tree

4 files changed

+159
-86
lines changed

4 files changed

+159
-86
lines changed

infscale/configs/job.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,16 @@ def get_worlds_to_configure(
157157
"""Compare two specs and return new and updated worlds."""
158158
helper = ServeConfigHelper()
159159

160-
curr_worlds = helper._get_worlds(curr_spec)
161160
new_worlds = helper._get_worlds(new_spec)
161+
new_world_names = set(new_worlds.keys())
162+
163+
# if current spec is not available,
164+
# return worlds from the new spec.
165+
if curr_spec is None:
166+
return new_world_names
162167

168+
curr_worlds = helper._get_worlds(curr_spec)
163169
curr_world_names = set(curr_worlds.keys())
164-
new_world_names = set(new_worlds.keys())
165170

166171
deploy_worlds = new_world_names - curr_world_names
167172

infscale/execution/config_manager.py

Lines changed: 91 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -30,88 +30,111 @@ class ConfigManager:
3030
def __init__(self):
3131
"""Initialize config manager instance."""
3232
self._loop = asyncio.get_event_loop()
33-
self._task: asyncio.Task | None = None
34-
self._event = asyncio.Event()
35-
self._spec: ServeConfig = None
36-
self._event.set()
37-
self._curr_worlds_to_configure: set[str] = set()
38-
self._cancel_cur_cfg = False
39-
self._world_infos: dict[str, WorldInfo] = {}
40-
41-
def handle_new_spec(self, spec: ServeConfig) -> None:
33+
# semaphore event for back-to-back configs
34+
self._config_event = asyncio.Event()
35+
self._config_event.set()
36+
self._world_tasks: dict[str, asyncio.Task] = {}
37+
self._curr_spec: ServeConfig = None
38+
self._curr_world_infos: dict[str, WorldInfo] = {}
39+
self._new_world_infos: dict[str, WorldInfo] = {}
40+
self.worlds_to_cancel = set()
41+
42+
async def handle_new_spec(self, spec: ServeConfig) -> None:
4243
"""Handle new spec."""
43-
self._cancel_cur_cfg = self._should_cancel_current(spec)
44-
self._spec = spec
44+
new_worlds_to_configure = ServeConfig.get_worlds_to_configure(
45+
self._curr_spec, spec
46+
)
4547

46-
def _should_cancel_current(self, spec: ServeConfig) -> bool:
47-
"""Decide if current configuration should be cancelled."""
48-
if self._spec is None:
49-
return False
48+
# on the first run, both new and cur will be empty sets
49+
new = self._new_world_infos.keys()
50+
cur = self._curr_world_infos.keys()
51+
curr_worlds_to_configure = new - cur
5052

51-
new_worlds_to_configure = ServeConfig.get_worlds_to_configure(self._spec, spec)
53+
self.worlds_to_cancel = new_worlds_to_configure & curr_worlds_to_configure
5254

53-
# cancel if the new config affects worlds currently being configured
54-
# TODO: if there's a overlap between new worlds and curr worlds we cancel
55-
# current configuration. This needs to be fixed, to cancel only the worlds that
56-
# are affected (eg new_worlds & curr_worlds)
57-
return not new_worlds_to_configure.isdisjoint(self._curr_worlds_to_configure)
55+
if len(self.worlds_to_cancel):
56+
await self._cancel_world_configuration(self.worlds_to_cancel)
5857

59-
def set_worlds_to_configure(self, world_names: set[str]) -> None:
60-
"""Set the world names currently being configured."""
61-
self._curr_worlds_to_configure = world_names
58+
# wait for current configuration to finish
59+
await self._config_event.wait()
6260

63-
def set_world_infos(self, worlds: list[WorldInfo]) -> None:
64-
"""Set new world infos."""
65-
for world_info in worlds:
66-
self._world_infos[world_info.name] = world_info
61+
# executed after each configuration
62+
self._new_world_infos = self._build_world_infos(spec)
63+
self._curr_spec = spec
64+
self.worlds_to_cancel = set()
6765

68-
def get_world_infos(self) -> dict[str, WorldInfo]:
69-
"Get world infos."
70-
return self._world_infos
66+
# block handling new spec after doing cleanup for the current one
67+
self._config_event.clear()
68+
69+
def unblock_next_config(self) -> None:
70+
"""Set task event and unblock next config process."""
71+
self._config_event.set()
72+
73+
def update_world_infos(self, worlds_names: set[str]) -> None:
74+
"""Update world infos."""
75+
for world_name in worlds_names:
76+
world_info = self._new_world_infos[world_name]
77+
self._curr_world_infos[world_info.name] = world_info
78+
79+
def get_curr_world_infos(self) -> dict[str, WorldInfo]:
80+
"Get current world infos."
81+
return self._curr_world_infos
7182

7283
def is_first_run(self) -> bool:
7384
"Return boolean if is first run or not."
74-
return not self._world_infos
85+
return not self._curr_world_infos
7586

7687
def remove_world_info(self, world_name: str) -> None:
7788
"""Remove world info by name."""
78-
del self._world_infos[world_name]
89+
del self._curr_world_infos[world_name]
7990

80-
def get_worlds_to_add_and_remove(self) -> tuple[list[WorldInfo], list[WorldInfo]]:
91+
def get_worlds_to_add_and_remove(self) -> tuple[set[str], set[str]]:
8192
"""Return a list of world infos to add and to remove."""
82-
new_world_infos = self._build_world_infos()
83-
84-
new = new_world_infos.keys()
85-
cur = self._world_infos.keys()
93+
new = self._new_world_infos.keys()
94+
cur = self._curr_world_infos.keys()
8695

87-
worlds_to_add = [new_world_infos[name] for name in new - cur]
88-
worlds_to_remove = [self._world_infos[name] for name in cur - new]
96+
worlds_to_add = new - cur
97+
worlds_to_remove = cur - new
8998

9099
return worlds_to_add, worlds_to_remove
91100

92-
async def schedule(self, coro_factory: Callable[[], Awaitable[None]]):
93-
"""Cancel any in-progress configure and schedule a new one."""
94-
# wait for current to finish if we do not want to cancel
95-
if not self._cancel_cur_cfg:
96-
await self._event.wait()
97-
98-
# cancel current if running
99-
if self._task and not self._task.done():
100-
self._task.cancel()
101-
try:
102-
await self._task
103-
except asyncio.CancelledError:
104-
pass
105-
106-
# block again for new run
107-
self._event.clear()
108-
self._task = self._loop.create_task(self._run(coro_factory))
109-
110-
def _build_world_infos(self) -> dict[str, WorldInfo]:
101+
def get_new_world_info(self, world_name: str) -> dict[str, WorldInfo]:
102+
"""Return new world info based on world name."""
103+
return self._new_world_infos[world_name]
104+
105+
def get_worlds_to_add(self, world_names: set[str]) -> list[WorldInfo]:
106+
"""Return a list of world infos to add."""
107+
return [self._new_world_infos[world_name] for world_name in world_names]
108+
109+
def get_worlds_to_remove(self, world_names: set[str]) -> list[WorldInfo]:
110+
"""Return a list of world infos to remove."""
111+
return [self._curr_world_infos[world_name] for world_name in world_names]
112+
113+
async def _cancel_world_configuration(self, world_names: set[str]):
114+
"""Cancel only worlds that are impacted by new spec."""
115+
coroutines = [self._cancel_world(w) for w in world_names]
116+
await asyncio.gather(*coroutines, return_exceptions=True)
117+
118+
def schedule_world_cfg(
119+
self, world_info: WorldInfo, coro_factory: Callable[[], Awaitable[None]]
120+
):
121+
"""Schedule configuration for a single world."""
122+
task = self._loop.create_task(self._run_world(world_info, coro_factory))
123+
self._world_tasks[world_info.name] = task
124+
return task
125+
126+
async def _cancel_world(self, world_name: str):
127+
"""Cancel an in-progress world config task."""
128+
task = self._world_tasks.pop(world_name, None)
129+
if task and not task.done():
130+
task.cancel()
131+
raise asyncio.CancelledError
132+
133+
def _build_world_infos(self, spec: ServeConfig) -> dict[str, WorldInfo]:
111134
world_infos: dict[str, WorldInfo] = {}
112135

113-
my_id = self._spec.stage.id
114-
for k, v in self._spec.flow_graph.items():
136+
my_id = spec.stage.id
137+
for k, v in spec.flow_graph.items():
115138
for cfg_world_info in v:
116139
# NOTE: no. of peers is always 1 for now
117140
assert len(cfg_world_info.peers) == 1
@@ -161,13 +184,13 @@ def _build_world_infos(self) -> dict[str, WorldInfo]:
161184

162185
return world_infos
163186

164-
async def _run(self, coro_factory: Callable[[], Awaitable[None]]):
165-
"""Run coroutine factory."""
187+
async def _run_world(
188+
self, world_info: WorldInfo, coro_factory: Callable[[], Awaitable[None]]
189+
):
190+
"""Run and cleanup world configuration."""
166191
try:
167-
await coro_factory()
192+
await coro_factory(world_info)
168193
except asyncio.CancelledError:
169-
pass
194+
raise
170195
finally:
171-
# reset class attributes and events
172-
self._event.set()
173-
self._curr_worlds_to_configure = set()
196+
self._world_tasks.pop(world_info.name, None)

infscale/execution/control.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,31 @@ async def setup(self) -> None:
162162
if self.rank == 0:
163163
self._server_task = asyncio.create_task(self._setup_server(setup_done))
164164
else:
165-
_ = asyncio.create_task(self._setup_client(setup_done))
165+
client_task = asyncio.create_task(self._setup_client(setup_done))
166166

167167
# wait until setting up either server or client is done
168-
await setup_done.wait()
168+
try:
169+
await setup_done.wait()
170+
except asyncio.CancelledError as e:
171+
# since both _setup_server and _setup_client are spawned as separate tasks
172+
# and the setup itself is a task, we need to handle parent task cancellation
173+
# on the awaited line, since cancellation only propagates through awaited calls
174+
# here, await setup_done.wait() is the propagation point from parent task to child tasks
175+
# so we need to cancel child tasks whenever CancelledError is received
176+
if self._server_task and not self._server_task.done():
177+
self._server_task.cancel()
178+
try:
179+
await self._server_task
180+
except asyncio.CancelledError:
181+
pass
182+
183+
if client_task and not client_task.done():
184+
client_task.cancel()
185+
try:
186+
await client_task
187+
except asyncio.CancelledError:
188+
pass
189+
raise
169190

170191
def cleanup(self) -> None:
171192
if self._server_task is not None:

infscale/execution/pipeline.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ async def _configure_multiworld(self, world_info: WorldInfo) -> None:
104104
port=port,
105105
device=self.device,
106106
)
107+
except asyncio.CancelledError:
108+
logger.warning(f"multiworld configuration cancelled for {world_info.name}")
107109
except Exception as e:
108110
logger.error(f"failed to initialize a multiworld {name}: {e}")
109111
condition = self._status != WorkerStatus.UPDATING
@@ -117,7 +119,7 @@ def _set_worker_status(self, status: WorkerStatus) -> None:
117119
"""Set worker status in pipeline and channel."""
118120
self._status = status
119121

120-
world_infos = self.config_manager.get_world_infos()
122+
world_infos = self.config_manager.get_curr_world_infos()
121123

122124
for world_info in world_infos.values():
123125
world_info.channel.set_worker_status(status)
@@ -130,13 +132,16 @@ def _set_n_send_worker_status(self, status: WorkerStatus) -> None:
130132
self.wcomm.send(msg)
131133

132134
async def _configure_control_channel(self, world_info: WorldInfo) -> None:
133-
await world_info.channel.setup()
135+
try:
136+
await world_info.channel.setup()
134137

135-
await world_info.channel.wait_readiness()
138+
await world_info.channel.wait_readiness()
139+
except asyncio.CancelledError:
140+
logger.warning(f"channel configuration cancelled for {world_info}")
136141

137142
async def _cleanup_recovered_worlds(self) -> None:
138143
"""Clean up world infos for recovered worlds."""
139-
world_infos = self.config_manager.get_world_infos()
144+
world_infos = self.config_manager.get_curr_world_infos()
140145

141146
# if I'm the recovered worker, return
142147
if len(world_infos) == 0:
@@ -169,14 +174,18 @@ async def _configure(self) -> None:
169174
if not is_first_run:
170175
self._set_worker_status(WorkerStatus.UPDATING)
171176

172-
worlds_to_add, worlds_to_remove = (
177+
world_names_to_add, world_names_to_remove = (
173178
self.config_manager.get_worlds_to_add_and_remove()
174179
)
175180

176181
tasks = []
177182
# 1. set up control channel
178-
for world_info in worlds_to_add:
179-
task = self._configure_control_channel(world_info)
183+
for world_name in world_names_to_add - self.config_manager.worlds_to_cancel:
184+
world_info = self.config_manager.get_new_world_info(world_name)
185+
186+
task = self.config_manager.schedule_world_cfg(
187+
world_info, self._configure_control_channel
188+
)
180189
tasks.append(task)
181190

182191
# TODO: this doesn't handle partial success
@@ -185,16 +194,29 @@ async def _configure(self) -> None:
185194

186195
tasks = []
187196
# 2. set up multiworld
188-
for world_info in worlds_to_add:
189-
task = self._configure_multiworld(world_info)
197+
for world_name in world_names_to_add - self.config_manager.worlds_to_cancel:
198+
world_info = self.config_manager.get_new_world_info(world_name)
199+
task = self.config_manager.schedule_world_cfg(
200+
world_info, self._configure_multiworld
201+
)
190202
tasks.append(task)
191203

192204
# TODO: this doesn't handle partial success
193205
# a mechanism to handle a failure is left as a todo
194206
await asyncio.gather(*tasks)
195207

196208
# update world_info for added worlds
197-
self.config_manager.set_world_infos(worlds_to_add)
209+
self.config_manager.update_world_infos(
210+
world_names_to_add - self.config_manager.worlds_to_cancel
211+
)
212+
213+
worlds_to_add = self.config_manager.get_worlds_to_add(
214+
world_names_to_add - self.config_manager.worlds_to_cancel
215+
)
216+
217+
worlds_to_remove = self.config_manager.get_worlds_to_remove(
218+
world_names_to_remove
219+
)
198220

199221
# configure router with worlds to add and remove
200222
await self.router.configure(
@@ -217,6 +239,7 @@ async def _configure(self) -> None:
217239

218240
worker_status = WorkerStatus.RUNNING if is_first_run else WorkerStatus.UPDATED
219241

242+
self.config_manager.unblock_next_config()
220243
self._set_n_send_worker_status(worker_status)
221244

222245
self.cfg_event.set()
@@ -415,16 +438,17 @@ async def _handle_config(self, spec: ServeConfig) -> None:
415438
if spec is None:
416439
return
417440

418-
self.config_manager.handle_new_spec(spec)
419-
420441
self._configure_variables(spec)
421442

422443
self._inspector.configure(self.spec)
423444

424445
self._initialize_once()
425446

426-
# (re)configure the pipeline
427-
await self.config_manager.schedule(self._configure)
447+
await self.config_manager.handle_new_spec(spec)
448+
449+
# run configure as a separate task since we need to unblock receiving
450+
# a new config to be processed when current configuration is finished
451+
_ = asyncio.create_task(self._configure())
428452

429453
def _configure_variables(self, spec: ServeConfig) -> None:
430454
"""Set variables that need to be updated."""

0 commit comments

Comments (0)