Skip to content

Commit b5f45eb

Browse files
committed
Cleanup process group pipe shutdown
1 parent 88cba6a commit b5f45eb

File tree

1 file changed

+56
-20
lines changed

1 file changed

+56
-20
lines changed

torchft/process_group.py

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
runtime users need to take care to not assume a static rank or world size.
1717
"""
1818

19+
import atexit
1920
import logging
2021
import threading
2122
from contextlib import contextmanager, nullcontext
@@ -75,7 +76,7 @@
7576
logger: logging.Logger = logging.getLogger(__name__)
7677

7778
# TODO: use non strings which are cheaper
78-
_QUEUE_CLOSE = "queue_close"
79+
_PIPE_CLOSE = "pipe_close"
7980
_FUTURE_RESULT = "fut_result"
8081
_FUTURE_EXCEPTION = "fut_exception"
8182

@@ -940,36 +941,67 @@ def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
940941

941942
self._timeout: float = timeout
942943

944+
# Register the shutdown method to be called at exit
945+
atexit.register(self.shutdown)
946+
943947
def shutdown(self) -> None:
944948
"""
945949
Shutdown the process group. This will kill the underlying process and
946-
close all queues.
950+
close all pipes.
947951
948952
This is a no-op if the process group is already shutdown.
949953
950954
ProcessGroup can be reconfigured after shutdown.
951955
"""
952-
956+
# Close the future pipe first
957+
if self._future_pipe is not None:
958+
# close future thread
959+
self._future_pipe.send((-1, _PIPE_CLOSE, None, None))
960+
assert self._future_pipe is not None
961+
self._future_pipe.close()
962+
self._future_pipe = None
963+
# Join the future thread after closing its pipe
964+
if self._future_thread is not None:
965+
self._future_thread.join(timeout=10.0)
966+
assert self._future_thread is not None
967+
if self._future_thread.is_alive():
968+
raise RuntimeError("Future thread did not exit")
969+
self._future_thread = None
970+
# Close the request pipe to signal the worker process to exit
953971
if self._pipe is not None:
972+
self._pipe.send((_PIPE_CLOSE,))
973+
assert self._pipe is not None
954974
self._pipe.close()
955-
956-
future_pipe = self._future_pipe
957-
if future_pipe is not None:
958-
# wait for the future thread to exit and then close the queue
959-
future_pipe.close()
960-
961-
future_thread = self._future_thread
962-
assert future_thread is not None
963-
964-
future_thread.join(timeout=10.0)
965-
if future_thread.is_alive():
966-
raise RuntimeError("future thread did not exit")
967-
968-
# Kill after closing queues to avoid log spam.
975+
self._pipe = None
976+
# Terminate the worker process after closing its pipe
969977
if self._p is not None:
970-
self._p.kill()
978+
self._p.join(timeout=10.0)
979+
assert self._p is not None
980+
if self._p.is_alive():
981+
raise RuntimeError("Worker process did not exit")
982+
self._p = None
971983

972984
def configure(self, store_addr: str, rank: int, world_size: int) -> None:
985+
"""
986+
Structure
987+
+-------------------+
988+
| |
989+
| Main Process | (updates futures)
990+
| | <---------------
991+
+-------------------+ |
992+
| Pipe 1 |
993+
v |
994+
+-------------------+ +-------------------+
995+
| | | |
996+
| Worker Process | -> | Future Thread |
997+
| | Pipe 2 | |
998+
+-------------------+ +-------------------+
999+
1000+
Main Process: Maintains self._futures
1001+
Worker Process: Handles tasks, communicates with Future Thread.
1002+
Future Thread: Manages asynchronous tasks, updates self._futures.
1003+
"""
1004+
9731005
self._world_size = world_size
9741006

9751007
self.shutdown()
@@ -990,7 +1022,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
9901022
rank,
9911023
world_size,
9921024
req_remote,
993-
future_remote,
1025+
future_local,
9941026
curr_device,
9951027
),
9961028
daemon=True,
@@ -1003,7 +1035,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
10031035
self._futures = {}
10041036
self._future_thread = threading.Thread(
10051037
target=self._future_handler,
1006-
args=(future_local,),
1038+
args=(_MonitoredPipe(future_remote),),
10071039
daemon=True,
10081040
)
10091041
self._future_thread.start()
@@ -1049,6 +1081,8 @@ def _worker(
10491081
while True:
10501082
op = cast(list[object], req_pipe.recv())
10511083
cmd = op[0]
1084+
if cmd == _PIPE_CLOSE:
1085+
break
10521086
if cmd == "func":
10531087
op_id: int
10541088
op_id, func_name, args, kwargs, stream_device, stream_id, event = (
@@ -1172,6 +1206,8 @@ def _future_handler(self, future_pipe: _MonitoredPipe) -> None:
11721206
op_id, mode, data, event = cast(
11731207
Tuple[int, str, object, Optional[torch.cuda.Event]], cmd
11741208
)
1209+
if mode == _PIPE_CLOSE:
1210+
break
11751211
with self._futures_lock:
11761212
fut = self._futures[op_id]
11771213
del self._futures[op_id]

0 commit comments

Comments
 (0)