 from torch._C._distributed_c10d import _register_work, _SymmetricMemory
 from torch.distributed._symmetric_memory import get_symm_mem_workspace, rendezvous

+# Import GroupName for type checking
+GroupName = c10d.GroupName
+
 _is_test_mode: bool = False
 _mocked_group_names: set[str] | None = None
 _backend_streams: dict[int, torch.cuda.Stream] = {}
@@ -35,7 +38,7 @@ def _pipelined_multi_all_gather_and_consume(
     shard: list[torch.Tensor],
     shard_consumer: Callable[[list[torch.Tensor], int], None],
     ag_out: list[torch.Tensor],
-    group_name: str,
+    group_name: GroupName,
     ag_out_needed: bool = True,
 ) -> None:
     """
@@ -181,7 +184,7 @@ def _pipelined_all_gather_and_consume(
     shard: torch.Tensor,
     shard_consumer: Callable[[torch.Tensor, int], None],
     ag_out: torch.Tensor,
-    group_name: str,
+    group_name: GroupName,
     ag_out_needed: bool = True,
 ) -> None:
     """
@@ -209,7 +212,7 @@ def adapter(shard: list[torch.Tensor], rank: int) -> None:
 def _pipelined_produce_and_all2all(
     chunk_producer: Callable[[int, torch.Tensor], None],
     output: torch.Tensor,
-    group_name: str,
+    group_name: GroupName,
     out_chunk_dim=0,
 ) -> None:
     """
@@ -400,7 +403,7 @@ def _fused_all_gather_matmul_impl(
     kwargs_list: list[dict[str, Any]],
     out_dtypes: list[torch.dtype | None],
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
     return_A: bool,
 ) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
     if A_shard.dim() < 2:
@@ -534,7 +537,7 @@ def _pipelined_all_gather_and_consume_last_dim(
     shard: torch.Tensor,
     shard_consumer: Callable[[torch.Tensor, int], None],
     ag_out: torch.Tensor,
-    group_name: str,
+    group_name: GroupName,
     ag_out_needed: bool = True,
 ) -> None:
     p2p_workspace_size_req = 0
@@ -607,7 +610,7 @@ def _fused_all_gather_matmul_last_gather_dim_impl(
     kwargs_list: list[dict[str, Any]],
     out_dtypes: list[torch.dtype | None],
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
     return_A: bool,
 ) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
     assert gather_dim == A_shard.ndim - 1
@@ -674,7 +677,7 @@ def _fused_all_gather_matmul_fallback(
     A_shard: torch.Tensor,
     Bs: list[torch.Tensor],
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
     *,
     return_A: bool = True,
 ) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
@@ -705,7 +708,7 @@ def _fused_all_gather_matmul(
     A_shard: torch.Tensor,
     Bs: list[torch.Tensor],
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
     *,
     return_A: bool = True,
 ) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
@@ -756,7 +759,7 @@ def _should_use_fused_all_gather_matmul_native(
     A_shard: torch.Tensor,
     Bs: list[torch.Tensor],
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
 ) -> bool:
     group = c10d._resolve_process_group(group_name)
     local_M = math.prod(A_shard.shape[:-1])
@@ -778,7 +781,7 @@ def _should_use_fused_all_gather_matmul_native(
 def _fused_all_gather_matmul_native(
     A_shard: torch.Tensor,
     B: torch.Tensor,
-    group_name: str,
+    group_name: GroupName,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     symm_mem = rendezvous(A_shard, group_name)
     if symm_mem is None:
@@ -832,7 +835,7 @@ def _fused_all_gather_matmul_native(
 def _should_use_multimem_all_gather_matmul(
     A_shard: torch.Tensor,
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
     return_A: bool,
 ) -> bool:
     group = c10d._resolve_process_group(group_name)
@@ -858,7 +861,7 @@ def _should_use_multimem_all_gather_matmul(
 def _multimem_all_gather_matmul(
     A_shard: torch.Tensor,
     Bs: list[torch.Tensor],
-    group_name: str,
+    group_name: GroupName,
 ) -> list[torch.Tensor]:
     group = c10d._resolve_process_group(group_name)
     A_shape = torch.Size((A_shard.shape[0] * group.size(), *A_shard.shape[1:]))
@@ -877,7 +880,7 @@ def _fused_all_gather_scaled_matmul_fallback(
     A_scale: torch.Tensor,
     B_scales: list[torch.Tensor],
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
     biases: list[torch.Tensor | None],
     result_scales: list[torch.Tensor | None],
     out_dtypes: list[torch.dtype | None],
@@ -951,7 +954,7 @@ def _fused_all_gather_scaled_matmul(
     A_scale: torch.Tensor,
     B_scales: list[torch.Tensor],
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
     biases: list[torch.Tensor | None],
     result_scales: list[torch.Tensor | None],
     out_dtypes: list[torch.dtype | None],
@@ -1057,7 +1060,7 @@ def _fused_matmul_reduce_scatter(
     B: torch.Tensor,
     reduce_op: str,
     scatter_dim: int,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     """
     Perform the following logic with micro-pipelined computation and
@@ -1093,7 +1096,7 @@ def _fused_matmul_reduce_scatter_fallback(
     B: torch.Tensor,
     reduce_op: str,
     scatter_dim: int,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     res = funcol.reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name)
     res = funcol.wait_tensor(res)
@@ -1108,7 +1111,7 @@ def _fused_matmul_reduce_scatter_impl(
     out_dtype: torch.dtype | None,
     reduce_op: str,
     scatter_dim: int,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     if A.dim() < 2:
         raise ValueError("A_shard must be a matrix")
@@ -1194,7 +1197,7 @@ def _fused_scaled_matmul_reduce_scatter(
     reduce_op: str,
     orig_scatter_dim: int,
     scatter_dim_after_maybe_reshape: int,
-    group_name: str,
+    group_name: GroupName,
     output_shape: list[int],
     bias: torch.Tensor | None = None,
     result_scale: torch.Tensor | None = None,
@@ -1248,7 +1251,7 @@ def _fused_scaled_matmul_reduce_scatter_fallback(
     reduce_op: str,
     orig_scatter_dim: int,
     scatter_dim_after_maybe_reshape: int,
-    group_name: str,
+    group_name: GroupName,
     output_shape: list[int],
     bias: torch.Tensor | None = None,
     result_scale: torch.Tensor | None = None,
@@ -1300,7 +1303,7 @@ def _fused_scaled_matmul_reduce_scatter_impl(
     reduce_op: str,
     orig_scatter_dim: int,
     scatter_dim_after_maybe_reshape: int,
-    group_name: str,
+    group_name: GroupName,
     output_shape: list[int],
 ) -> torch.Tensor:
     if A.dim() < 2:
@@ -1524,7 +1527,7 @@ def wait(self, timeout: timedelta = timedelta(seconds=0)) -> bool:
 @torch.library.impl(lib, "_low_contention_all_gather", "Meta")
 def _low_contention_all_gather_meta(
     tensor: torch.Tensor,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     group_size = c10d._get_group_size_by_name(group_name)
     return tensor.new_empty(tensor.shape[0] * group_size, *tensor.shape[1:])
@@ -1533,7 +1536,7 @@ def _low_contention_all_gather_meta(
 @torch.library.impl(lib, "_low_contention_all_gather", "CUDA")
 def _low_contention_all_gather(
     tensor: torch.Tensor,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     """
     Performs all-gather with symmetric memory in a low-contention fashion.
@@ -1582,7 +1585,7 @@ def _low_contention_all_gather(
 def _low_contention_reduce_scatter_meta(
     tensor: torch.Tensor,
     reduce_op: str,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     group_size = c10d._get_group_size_by_name(group_name)
     return tensor.unflatten(0, (group_size, -1)).mean(dim=0)
@@ -1665,7 +1668,7 @@ def _low_contention_reduce_scatter_with_workspace(
 def _low_contention_reduce_scatter(
     tensor: torch.Tensor,
     reduce_op: str,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     """
     Performs reduce-scatter with symmetric memory in a low-contention fashion.
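
Every hunk above makes the same mechanical change: the `group_name` parameter of the symmetric-memory ops is re-annotated from a plain `str` to the `GroupName` alias introduced in the first hunk. The sketch below illustrates the pattern; it is not the actual PyTorch source. It assumes, as the alias line `GroupName = c10d.GroupName` suggests, that `c10d` exports a string-compatible `GroupName` type, so the change is annotation-only and call sites that pass plain strings keep working at runtime. `_example_op` is a hypothetical function invented for illustration.

# Illustrative sketch only (hypothetical op), assuming GroupName is a
# str-compatible alias, so the diff changes static typing, not behavior.
import torch
import torch.distributed.distributed_c10d as c10d

GroupName = c10d.GroupName  # the same alias the commit introduces

def _example_op(tensor: torch.Tensor, group_name: GroupName) -> torch.Tensor:
    # Runtime resolution is unchanged: the name still looks up a process
    # group exactly as it did with the old `group_name: str` signature.
    group = c10d._resolve_process_group(group_name)
    # Allocate an all-gather-shaped output, as the meta kernels above do.
    return tensor.new_empty(tensor.shape[0] * group.size(), *tensor.shape[1:])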