@@ -30,88 +30,111 @@ class ConfigManager:
3030 def __init__ (self ):
3131 """Initialize config manager instance."""
3232 self ._loop = asyncio .get_event_loop ()
33- self ._task : asyncio .Task | None = None
34- self ._event = asyncio .Event ()
35- self ._spec : ServeConfig = None
36- self ._event .set ()
37- self ._curr_worlds_to_configure : set [str ] = set ()
38- self ._cancel_cur_cfg = False
39- self ._world_infos : dict [str , WorldInfo ] = {}
40-
41- def handle_new_spec (self , spec : ServeConfig ) -> None :
33+ # semaphore event for back-to-back configs
34+ self ._config_event = asyncio .Event ()
35+ self ._config_event .set ()
36+ self ._world_tasks : dict [str , asyncio .Task ] = {}
37+ self ._curr_spec : ServeConfig = None
38+ self ._curr_world_infos : dict [str , WorldInfo ] = {}
39+ self ._new_world_infos : dict [str , WorldInfo ] = {}
40+ self .worlds_to_cancel = set ()
41+
async def handle_new_spec(self, spec: ServeConfig) -> None:
    """Handle new spec.

    Cancels the in-flight world configurations that the new spec also
    touches, waits for the current configuration to finish, then stores
    the world infos built from ``spec`` and clears the config event so
    the next spec blocks until ``unblock_next_config`` is called.
    """
    affected = ServeConfig.get_worlds_to_configure(self._curr_spec, spec)

    # Worlds present in the new infos but not yet in the current ones.
    # On the first run both mappings are empty, so this is an empty set.
    in_progress = self._new_world_infos.keys() - self._curr_world_infos.keys()

    self.worlds_to_cancel = affected & in_progress
    if self.worlds_to_cancel:
        await self._cancel_world_configuration(self.worlds_to_cancel)

    # wait for current configuration to finish
    await self._config_event.wait()

    # bookkeeping for the spec that is about to be configured
    self._new_world_infos = self._build_world_infos(spec)
    self._curr_spec = spec
    self.worlds_to_cancel = set()

    # block handling of further specs until this one is done
    self._config_event.clear()
68+
def unblock_next_config(self) -> None:
    """Set the config event so a waiting ``handle_new_spec`` can proceed."""
    gate = self._config_event
    gate.set()
72+
def update_world_infos(self, worlds_names: set[str]) -> None:
    """Copy the named worlds from the new world infos into the current ones.

    Each info is looked up by name in the new world infos and stored in
    the current world infos under its own ``name`` attribute. Raises
    ``KeyError`` if a name is not among the new world infos.
    """
    staged = self._new_world_infos
    applied = self._curr_world_infos
    for info in map(staged.__getitem__, worlds_names):
        applied[info.name] = info
78+
def get_curr_world_infos(self) -> dict[str, WorldInfo]:
    """Return the mapping of current world infos (the live dict, not a copy)."""
    current = self._curr_world_infos
    return current
7182
def is_first_run(self) -> bool:
    """Return True while no current world infos have been recorded."""
    return len(self._curr_world_infos) == 0
7586
def remove_world_info(self, world_name: str) -> None:
    """Drop the named world from the current world infos.

    Raises ``KeyError`` if the world is not present.
    """
    self._curr_world_infos.pop(world_name)
7990
def get_worlds_to_add_and_remove(self) -> tuple[set[str], set[str]]:
    """Return the names of worlds to add and the names of worlds to remove.

    Worlds to add are in the new world infos but not in the current ones;
    worlds to remove are in the current world infos but not in the new ones.
    """
    staged = set(self._new_world_infos)
    active = set(self._curr_world_infos)
    return staged - active, active - staged
91100
92- async def schedule (self , coro_factory : Callable [[], Awaitable [None ]]):
93- """Cancel any in-progress configure and schedule a new one."""
94- # wait for current to finish if we do not want to cancel
95- if not self ._cancel_cur_cfg :
96- await self ._event .wait ()
97-
98- # cancel current if running
99- if self ._task and not self ._task .done ():
100- self ._task .cancel ()
101- try :
102- await self ._task
103- except asyncio .CancelledError :
104- pass
105-
106- # block again for new run
107- self ._event .clear ()
108- self ._task = self ._loop .create_task (self ._run (coro_factory ))
109-
110- def _build_world_infos (self ) -> dict [str , WorldInfo ]:
def get_new_world_info(self, world_name: str) -> WorldInfo:
    """Return the new world info registered under ``world_name``.

    Raises ``KeyError`` if no new world info exists for that name.
    """
    return self._new_world_infos[world_name]
104+
def get_worlds_to_add(self, world_names: set[str]) -> list[WorldInfo]:
    """Look up the given names in the new world infos and return them as a list.

    Raises ``KeyError`` if a name has no new world info.
    """
    staged = self._new_world_infos
    return [staged[name] for name in world_names]
108+
def get_worlds_to_remove(self, world_names: set[str]) -> list[WorldInfo]:
    """Look up the given names in the current world infos and return them as a list.

    Raises ``KeyError`` if a name has no current world info.
    """
    active = self._curr_world_infos
    return [active[name] for name in world_names]
112+
113+ async def _cancel_world_configuration (self , world_names : set [str ]):
114+ """Cancel only worlds that are impacted by new spec."""
115+ coroutines = [self ._cancel_world (w ) for w in world_names ]
116+ await asyncio .gather (* coroutines , return_exceptions = True )
117+
def schedule_world_cfg(
    self,
    world_info: WorldInfo,
    coro_factory: Callable[[WorldInfo], Awaitable[None]],
) -> asyncio.Task:
    """Schedule configuration for a single world.

    Creates a task on the manager's loop that runs ``_run_world`` for the
    world, records it under the world's name, and returns it.
    ``coro_factory`` is forwarded to ``_run_world``, which calls it with
    ``world_info``.
    """
    task = self._loop.create_task(self._run_world(world_info, coro_factory))
    # Any task previously recorded under this name is replaced.
    self._world_tasks[world_info.name] = task
    return task
125+
126+ async def _cancel_world (self , world_name : str ):
127+ """Cancel an in-progress world config task."""
128+ task = self ._world_tasks .pop (world_name , None )
129+ if task and not task .done ():
130+ task .cancel ()
131+ raise asyncio .CancelledError
132+
133+ def _build_world_infos (self , spec : ServeConfig ) -> dict [str , WorldInfo ]:
111134 world_infos : dict [str , WorldInfo ] = {}
112135
113- my_id = self . _spec .stage .id
114- for k , v in self . _spec .flow_graph .items ():
136+ my_id = spec .stage .id
137+ for k , v in spec .flow_graph .items ():
115138 for cfg_world_info in v :
116139 # NOTE: no. of peers is always 1 for now
117140 assert len (cfg_world_info .peers ) == 1
@@ -161,13 +184,13 @@ def _build_world_infos(self) -> dict[str, WorldInfo]:
161184
162185 return world_infos
163186
164- async def _run (self , coro_factory : Callable [[], Awaitable [None ]]):
165- """Run coroutine factory."""
187+ async def _run_world (
188+ self , world_info : WorldInfo , coro_factory : Callable [[], Awaitable [None ]]
189+ ):
190+ """Run and cleanup world configuration."""
166191 try :
167- await coro_factory ()
192+ await coro_factory (world_info )
168193 except asyncio .CancelledError :
169- pass
194+ raise
170195 finally :
171- # reset class attributes and events
172- self ._event .set ()
173- self ._curr_worlds_to_configure = set ()
196+ self ._world_tasks .pop (world_info .name , None )
0 commit comments