Skip to content

Commit 51d3327

Browse files
committed
refactor: pipeline world infos
Refactored pipeline code and moved world-info-related logic into config_runner. With this change, config_runner is responsible for computing the worlds to add and to remove, and it will be used when we refactor the code to handle per-world configuration.
1 parent 0112e8a commit 51d3327

File tree

3 files changed

+198
-165
lines changed

3 files changed

+198
-165
lines changed
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# Copyright 2025 Cisco Systems, Inc. and its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# SPDX-License-Identifier: Apache-2.0
16+
17+
"""config_manager.py."""
18+
19+
import asyncio
20+
from typing import Awaitable, Callable
21+
22+
from infscale.configs.job import ServeConfig
23+
from infscale.execution.control import Channel as CtrlCh
24+
from infscale.execution.world import WorldInfo
25+
26+
27+
class ConfigManager:
    """Track the active serve spec, its world infos and the configure task.

    The manager keeps the most recent ``ServeConfig``, the ``WorldInfo``
    objects currently registered (keyed by world name), and at most one
    asyncio task running a configure coroutine. It also computes which
    worlds must be added or removed when the spec changes.
    """

    def __init__(self):
        """Initialize config manager instance."""
        self._loop = asyncio.get_event_loop()
        self._task: asyncio.Task | None = None

        # The event is set whenever no configure task is in flight; it is
        # cleared by schedule() and set again when _run() finishes.
        self._event = asyncio.Event()
        self._event.set()

        # No spec is known until handle_new_spec() is called.
        self._spec: ServeConfig | None = None
        self._curr_worlds_to_configure: set[str] = set()
        self._cancel_cur_cfg = False
        self._world_infos: dict[str, WorldInfo] = {}

    def handle_new_spec(self, spec: ServeConfig) -> None:
        """Store ``spec`` and record whether it should cancel the current run.

        The cancel decision is computed against the previously stored spec,
        so it must be evaluated before ``self._spec`` is overwritten.
        """
        self._cancel_cur_cfg = self._should_cancel_current(spec)
        self._spec = spec

    def _should_cancel_current(self, spec: ServeConfig) -> bool:
        """Decide if current configuration should be cancelled.

        Returns ``False`` when no spec has been stored yet; otherwise returns
        ``True`` when the worlds affected by ``spec`` overlap with the worlds
        currently being configured.
        """
        if self._spec is None:
            return False

        new_worlds_to_configure = ServeConfig.get_worlds_to_configure(self._spec, spec)

        # cancel if the new config affects worlds currently being configured
        # TODO: if there's an overlap between new worlds and curr worlds we
        # cancel the whole current configuration. This needs to be fixed to
        # cancel only the worlds that are affected
        # (e.g. new_worlds & curr_worlds)
        return not new_worlds_to_configure.isdisjoint(self._curr_worlds_to_configure)

    def set_worlds_to_configure(self, world_names: set[str]) -> None:
        """Set the world names currently being configured."""
        self._curr_worlds_to_configure = world_names

    def set_world_infos(self, worlds: list[WorldInfo]) -> None:
        """Register ``worlds`` by name, replacing entries with the same name."""
        for world_info in worlds:
            self._world_infos[world_info.name] = world_info

    def get_world_infos(self) -> dict[str, WorldInfo]:
        """Return the mapping of world name to registered world info.

        The internal dictionary itself is returned, not a copy.
        """
        return self._world_infos

    def is_first_run(self) -> bool:
        """Return ``True`` when no world info has been registered yet."""
        return not self._world_infos

    def remove_world_info(self, world_name: str) -> None:
        """Remove world info by name.

        Raises:
            KeyError: if ``world_name`` is not registered.
        """
        del self._world_infos[world_name]

    # NOTE: "words" in the two method names below is a typo for "worlds"; the
    # names are kept unchanged so existing callers keep working.
    def get_words_to_add(self) -> list[WorldInfo]:
        """Return a list of world infos to add.

        These are the worlds built from the current spec whose names are not
        registered yet. The order of the returned list is unspecified.
        """
        new_world_infos = self._build_world_infos()
        new = new_world_infos.keys()
        cur = self._world_infos.keys()

        return [new_world_infos[name] for name in new - cur]

    def get_words_to_remove(self) -> list[WorldInfo]:
        """Return a list of world infos to remove.

        These are the registered worlds whose names no longer appear among
        the worlds built from the current spec. The order of the returned
        list is unspecified.
        """
        new = self._build_world_infos().keys()
        cur = self._world_infos.keys()

        # Names in ``cur - new`` only exist in the registered mapping, so the
        # objects must be taken from there; indexing the freshly built
        # mapping with them would raise KeyError.
        return [self._world_infos[name] for name in cur - new]

    async def schedule(self, coro_factory: Callable[[], Awaitable[None]]):
        """Cancel any in-progress configure and schedule a new one.

        When the latest spec does not require cancelling, this first waits
        for the in-flight run (if any) to finish. ``coro_factory`` is called
        inside the new task to produce the coroutine to await.
        """
        # wait for current to finish if we do not want to cancel
        if not self._cancel_cur_cfg:
            await self._event.wait()

        # cancel current if running
        if self._task and not self._task.done():
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass

        # block again for new run
        self._event.clear()
        self._task = self._loop.create_task(self._run(coro_factory))

    def _build_world_infos(self) -> dict[str, WorldInfo]:
        """Build world infos for this stage from the current spec.

        Walks ``self._spec.flow_graph`` and creates one ``WorldInfo`` (with
        its control channel) for every world in which this stage is either
        the graph key (rank 0) or one of the peers (rank >= 1). Worlds that
        do not involve this stage are skipped. A spec must have been stored
        via ``handle_new_spec`` before calling this.

        Returns:
            Mapping of world name to the newly built ``WorldInfo``.
        """
        world_infos: dict[str, WorldInfo] = {}

        my_id = self._spec.stage.id
        for stage_id, cfg_world_infos in self._spec.flow_graph.items():
            for cfg_world_info in cfg_world_infos:
                # NOTE: no. of peers is always 1 for now
                assert len(cfg_world_info.peers) == 1

                if my_id == stage_id:
                    my_rank = 0
                    other_rank = 1
                    other_id = cfg_world_info.peers[0]
                elif my_id in cfg_world_info.peers:
                    # NOTE: this is always 1 for now
                    my_rank = cfg_world_info.peers.index(my_id) + 1
                    other_rank = 0
                    other_id = stage_id
                else:
                    # this world does not involve the current stage
                    continue

                name = cfg_world_info.name
                addr = cfg_world_info.addr
                conflict_count = cfg_world_info.conflict_count

                world_size = len(cfg_world_info.peers) + 1
                ctrl_ch = CtrlCh(my_rank, world_size, addr, cfg_world_info.ctrl_port)

                world_infos[name] = WorldInfo(
                    name=name,
                    size=world_size,
                    addr=addr,
                    port=cfg_world_info.data_port,
                    backend=cfg_world_info.backend,
                    channel=ctrl_ch,
                    my_id=my_id,
                    me=my_rank,
                    other_id=other_id,
                    other=other_rank,
                    recover=cfg_world_info.recover,
                    conflict_count=conflict_count,
                    multiworld_name=f"{name}-{conflict_count}",
                )

        return world_infos

    async def _run(self, coro_factory: Callable[[], Awaitable[None]]):
        """Run coroutine factory.

        Cancellation of the awaited coroutine is swallowed here. Whatever the
        outcome, the idle event is set and the set of worlds being configured
        is cleared.
        """
        try:
            await coro_factory()
        except asyncio.CancelledError:
            pass
        finally:
            # reset class attributes and events
            self._event.set()
            self._curr_worlds_to_configure = set()

infscale/execution/config_runner.py

Lines changed: 0 additions & 87 deletions
This file was deleted.

0 commit comments

Comments
 (0)