1111Example usage:
1212
1313Buck2 (internal):
14- buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_comms --
14+ buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_comms -- \
15+ a2a_single --name=a2a_sync_base-$(hg whereami | cut -c 1-10)
1516
1617OSS (external):
17- python -m torchrec.distributed.benchmark.benchmark_comms
18+ python -m torchrec.distributed.benchmark.benchmark_comms \
19+ a2a_single --name=a2a_sync_base-$(git rev-parse --short HEAD || echo $USER)
1820
21+ see README.md for more details
1922"""
2023
2124from dataclasses import dataclass
2225from typing import Any , Callable , Dict , List , Optional
2326
2427import torch
2528import torch .distributed as dist
29+ import torch .nn .functional as F
2630
2731from torch .autograd .profiler import record_function
2832
@@ -43,10 +47,34 @@ class A2A_Single_Run(BenchmarkFunc):
4347 profile_dir : str = "."
4448 num_benchmarks : int = 1
4549 num_profiles : int = 2
46- num_mul : int = 10
50+ num_mul : int = 5
4751 num_concat : int = 100
4852
4953
54+ def _compute (
55+ dim : int ,
56+ num_mul : int ,
57+ num_concat : int ,
58+ ctx : MultiProcessContext ,
59+ x : Optional [torch .Tensor ] = None ,
60+ ) -> torch .Tensor :
61+ if x is None :
62+ x = torch .rand (dim , dim , device = ctx .device ) - 0.5
63+ for _ in range (num_mul ):
64+ x = F .normalize (x @ x ) * 10
65+ x = torch .sigmoid (x ).reshape (1 , dim , dim ) + ctx .rank
66+ return torch .concat ([x ] * num_concat )
67+
68+
69+ def _validate (x : torch .Tensor , ctx : MultiProcessContext ) -> List [torch .Tensor ]:
70+ mixed_ranks = x .to (torch .int ).reshape (- 1 )
71+ N = mixed_ranks .numel () // ctx .world_size
72+ checks = [
73+ torch .all (mixed_ranks [i * N : (i + 1 ) * N ] == i ) for i in range (ctx .world_size )
74+ ]
75+ return checks
76+
77+
5078# all_to_all_single with sync and single stream
5179def a2a_sync_base (
5280 batch_inputs : List [Dict [str , Any ]],
@@ -56,38 +84,61 @@ def a2a_sync_base(
5684 ctx : MultiProcessContext ,
5785) -> None :
5886 with record_function ("## pre-comms compute ##" ):
59- pre_comms = torch .rand (dim , dim , device = ctx .device ) - 0.5
60- for _ in range (num_mul ):
61- pre_comms = pre_comms @ pre_comms
62- pre_comms = torch .sigmoid (pre_comms - torch .mean (pre_comms ))
63- pre_comms = torch .sigmoid (pre_comms ).reshape (1 , dim , dim ) + ctx .rank
64- pre_comms = torch .concat ([pre_comms ] * num_concat )
87+ pre_comms = _compute (dim = dim , num_mul = num_mul , num_concat = num_concat , ctx = ctx )
6588
6689 with record_function ("## all_to_all_single ##" ):
6790 post_comms = torch .empty_like (pre_comms )
6891 req = dist .all_to_all_single (output = post_comms , input = pre_comms , group = ctx .pg )
6992
7093 with record_function ("## comms validation ##" ):
71- mixed_ranks = post_comms .to (torch .int ).reshape (- 1 )
72- N = mixed_ranks .numel () // ctx .world_size
73- checks = [
74- torch .all (mixed_ranks [i * N : (i + 1 ) * N ] == i )
75- for i in range (ctx .world_size )
76- ]
94+ checks = _validate (post_comms , ctx )
7795
7896 with record_function ("## irrelevant compute ##" ):
79- pre_comms = torch .rand (dim , dim , device = ctx .device ) - 0.5
80- for _ in range (num_mul ):
81- pre_comms = pre_comms @ pre_comms
82- pre_comms = torch .sigmoid (pre_comms - torch .mean (pre_comms ))
83- pre_comms = torch .sigmoid (pre_comms ) + ctx .rank
97+ pre_comms = _compute (dim = dim , num_mul = num_mul , num_concat = num_concat , ctx = ctx )
8498
8599 with record_function ("## post-comms compute ##" ):
86- post_comms = post_comms [0 ]
87- for _ in range (num_mul ):
88- post_comms = post_comms @ post_comms
89- post_comms = torch .sigmoid (pre_comms - torch .mean (post_comms ))
90- post_comms = torch .sigmoid (post_comms ) + ctx .rank
100+ post_comms = _compute (
101+ dim = dim , num_mul = num_mul , num_concat = num_concat , ctx = ctx , x = post_comms [0 ]
102+ )
103+
104+ with record_function ("## assert ##" ):
105+ assert all (checks )
106+
107+
108+ # all_to_all_single with sync and single stream
109+ def a2a_async_base (
110+ batch_inputs : List [Dict [str , Any ]],
111+ dim : int ,
112+ num_mul : int ,
113+ num_concat : int ,
114+ ctx : MultiProcessContext ,
115+ ) -> None :
116+ with record_function ("## pre-comms compute ##" ):
117+ pre_comms = _compute (dim = dim , num_mul = num_mul , num_concat = num_concat , ctx = ctx )
118+
119+ with record_function ("## all_to_all_single ##" ):
120+ # use zeros instead of empty to make sure no previous data used
121+ post_comms = torch .zeros_like (pre_comms )
122+ req = dist .all_to_all_single (
123+ output = post_comms ,
124+ input = pre_comms ,
125+ group = ctx .pg ,
126+ async_op = True ,
127+ )
128+
129+ with record_function ("## comms validation ##" ):
130+ # pre-check is performed before comms' done
131+ pre_checks = _validate (post_comms , ctx )
132+
133+ with record_function ("## irrelevant compute ##" ):
134+ pre_comms = _compute (dim = dim , num_mul = num_mul , num_concat = num_concat , ctx = ctx )
135+
136+ with record_function (f"## post-comms compute: pre-check-{ all (pre_checks )} ##" ):
137+ req .wait () # assertion fails without wait()
138+ checks = _validate (post_comms , ctx )
139+ post_comms = _compute (
140+ dim = dim , num_mul = num_mul , num_concat = num_concat , ctx = ctx , x = post_comms [0 ]
141+ )
91142
92143 with record_function ("## assert ##" ):
93144 assert all (checks )
@@ -113,6 +164,8 @@ def a2a_single_runner(rank: int, world_size: int, arg: A2A_Single_Run) -> None:
113164
114165 if arg .name .startswith ("a2a_sync_base" ):
115166 func = a2a_sync_base
167+ elif arg .name .startswith ("a2a_async_base" ):
168+ func = a2a_async_base
116169 else :
117170 func = a2a_sync_base
118171
@@ -127,7 +180,7 @@ def a2a_single_runner(rank: int, world_size: int, arg: A2A_Single_Run) -> None:
127180 },
128181 func_to_benchmark = func ,
129182 rank = rank ,
130- ** arg .benchmark_func_kwargs ()
183+ ** arg .benchmark_func_kwargs (),
131184 )
132185
133186 if rank == 0 :
0 commit comments