16
16
runtime users need to take care to not assume a static rank or world size.
17
17
"""
18
18
19
+ import atexit
19
20
import logging
20
21
import threading
21
22
from contextlib import contextmanager , nullcontext
75
76
logger : logging .Logger = logging .getLogger (__name__ )
76
77
77
78
# TODO: use non strings which are cheaper
78
- _QUEUE_CLOSE = "queue_close "
79
+ _PIPE_CLOSE = "pipe_close "
79
80
_FUTURE_RESULT = "fut_result"
80
81
_FUTURE_EXCEPTION = "fut_exception"
81
82
@@ -940,36 +941,67 @@ def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
940
941
941
942
self ._timeout : float = timeout
942
943
944
+ # Register the shutdown method to be called at exit
945
+ atexit .register (self .shutdown )
946
+
943
947
def shutdown (self ) -> None :
944
948
"""
945
949
Shutdown the process group. This will kill the underlying process and
946
- close all queues .
950
+ close all pipes .
947
951
948
952
This is a no-op if the process group is already shutdown.
949
953
950
954
ProcessGroup can be reconfigured after shutdown.
951
955
"""
952
-
956
+ # Close the future pipe first
957
+ if self ._future_pipe is not None :
958
+ # close future thread
959
+ self ._future_pipe .send ((- 1 , _PIPE_CLOSE , None , None ))
960
+ assert self ._future_pipe is not None
961
+ self ._future_pipe .close ()
962
+ self ._future_pipe = None
963
+ # Join the future thread after closing its pipe
964
+ if self ._future_thread is not None :
965
+ self ._future_thread .join (timeout = 10.0 )
966
+ assert self ._future_thread is not None
967
+ if self ._future_thread .is_alive ():
968
+ raise RuntimeError ("Future thread did not exit" )
969
+ self ._future_thread = None
970
+ # Close the request pipe to signal the worker process to exit
953
971
if self ._pipe is not None :
972
+ self ._pipe .send ((_PIPE_CLOSE ,))
973
+ assert self ._pipe is not None
954
974
self ._pipe .close ()
955
-
956
- future_pipe = self ._future_pipe
957
- if future_pipe is not None :
958
- # wait for the future thread to exit and then close the queue
959
- future_pipe .close ()
960
-
961
- future_thread = self ._future_thread
962
- assert future_thread is not None
963
-
964
- future_thread .join (timeout = 10.0 )
965
- if future_thread .is_alive ():
966
- raise RuntimeError ("future thread did not exit" )
967
-
968
- # Kill after closing queues to avoid log spam.
975
+ self ._pipe = None
976
+ # Terminate the worker process after closing its pipe
969
977
if self ._p is not None :
970
- self ._p .kill ()
978
+ self ._p .join (timeout = 10.0 )
979
+ assert self ._p is not None
980
+ if self ._p .is_alive ():
981
+ raise RuntimeError ("Worker process did not exit" )
982
+ self ._p = None
971
983
972
984
def configure (self , store_addr : str , rank : int , world_size : int ) -> None :
985
+ """
986
+ Structure
987
+ +-------------------+
988
+ | |
989
+ | Main Process | (updates futures)
990
+ | | <---------------
991
+ +-------------------+ |
992
+ | Pipe 1 |
993
+ v |
994
+ +-------------------+ +-------------------+
995
+ | | | |
996
+ | Worker Process | -> | Future Thread |
997
+ | | Pipe 2 | |
998
+ +-------------------+ +-------------------+
999
+
1000
+ Main Process: Maintains self._futures
1001
+ Worker Process: Handles tasks, communicates with Future Thread.
1002
+ Future Thread: Manages asynchronous tasks, updates self._futures.
1003
+ """
1004
+
973
1005
self ._world_size = world_size
974
1006
975
1007
self .shutdown ()
@@ -990,7 +1022,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
990
1022
rank ,
991
1023
world_size ,
992
1024
req_remote ,
993
- future_remote ,
1025
+ future_local ,
994
1026
curr_device ,
995
1027
),
996
1028
daemon = True ,
@@ -1003,7 +1035,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
1003
1035
self ._futures = {}
1004
1036
self ._future_thread = threading .Thread (
1005
1037
target = self ._future_handler ,
1006
- args = (future_local ,),
1038
+ args = (_MonitoredPipe ( future_remote ) ,),
1007
1039
daemon = True ,
1008
1040
)
1009
1041
self ._future_thread .start ()
@@ -1049,6 +1081,8 @@ def _worker(
1049
1081
while True :
1050
1082
op = cast (list [object ], req_pipe .recv ())
1051
1083
cmd = op [0 ]
1084
+ if cmd == _PIPE_CLOSE :
1085
+ break
1052
1086
if cmd == "func" :
1053
1087
op_id : int
1054
1088
op_id , func_name , args , kwargs , stream_device , stream_id , event = (
@@ -1172,6 +1206,8 @@ def _future_handler(self, future_pipe: _MonitoredPipe) -> None:
1172
1206
op_id , mode , data , event = cast (
1173
1207
Tuple [int , str , object , Optional [torch .cuda .Event ]], cmd
1174
1208
)
1209
+ if mode == _PIPE_CLOSE :
1210
+ break
1175
1211
with self ._futures_lock :
1176
1212
fut = self ._futures [op_id ]
1177
1213
del self ._futures [op_id ]
0 commit comments