11
11
Example usage:
12
12
13
13
Buck2 (internal):
14
- buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_comms --
14
+ buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_comms -- \
15
+ a2a_single --name=a2a_sync_base-$(hg whereami | cut -c 1-10)
15
16
16
17
OSS (external):
17
- python -m torchrec.distributed.benchmark.benchmark_comms
18
+ python -m torchrec.distributed.benchmark.benchmark_comms \
19
+ a2a_single --name=a2a_sync_base-$(git rev-parse --short HEAD || echo $USER)
18
20
21
+ see README.md for more details
19
22
"""
20
23
21
24
from dataclasses import dataclass
22
25
from typing import Any , Callable , Dict , List , Optional
23
26
24
27
import torch
25
28
import torch .distributed as dist
29
+ import torch .nn .functional as F
26
30
27
31
from torch .autograd .profiler import record_function
28
32
@@ -47,10 +51,34 @@ class AllToAllSingleRunConfig(BenchFuncConfig):
47
51
profile_dir : str = "."
48
52
num_benchmarks : int = 1
49
53
num_profiles : int = 2
50
- num_mul : int = 10
54
+ num_mul : int = 5
51
55
num_concat : int = 100
52
56
53
57
58
def _compute(
    dim: int,
    num_mul: int,
    num_concat: int,
    ctx: MultiProcessContext,
    x: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Produce a rank-tagged tensor via a chain of synthetic GPU work.

    Runs ``num_mul`` rounds of matmul + row-normalize (scaled by 10 to keep
    values from vanishing), squashes through sigmoid into (0, 1), adds
    ``ctx.rank`` so every element's integer part encodes the producing rank,
    then replicates the (1, dim, dim) slab ``num_concat`` times.

    Args:
        dim: side length of the square working tensor.
        num_mul: number of matmul/normalize rounds.
        num_concat: how many copies of the slab to stack along dim 0.
        ctx: process context supplying the device and this process's rank.
        x: optional starting tensor; a fresh random one centered at 0 is
            drawn on ``ctx.device`` when omitted.

    Returns:
        A (num_concat, dim, dim) tensor whose values lie in
        (ctx.rank, ctx.rank + 1).
    """
    out = torch.rand(dim, dim, device=ctx.device) - 0.5 if x is None else x
    for _ in range(num_mul):
        out = F.normalize(out @ out) * 10
    tagged = torch.sigmoid(out).reshape(1, dim, dim) + ctx.rank
    return torch.concat([tagged] * num_concat)
+
73
def _validate(x: torch.Tensor, ctx: MultiProcessContext) -> List[torch.Tensor]:
    """Check that each world-size chunk of ``x`` carries its sender's rank tag.

    Flattens ``x``, truncates values to their integer part (the rank tag
    written by ``_compute``), splits the result into ``ctx.world_size`` equal
    chunks, and verifies chunk ``i`` is uniformly ``i``.

    Args:
        x: tensor received from an all-to-all exchange of rank-tagged data.
        ctx: process context supplying the world size.

    Returns:
        One 0-dim boolean tensor per rank; element ``i`` is True iff chunk
        ``i`` consists entirely of the value ``i``.
    """
    flat = x.to(torch.int).reshape(-1)
    chunk = flat.numel() // ctx.world_size
    results = []
    for rank in range(ctx.world_size):
        segment = flat[rank * chunk : (rank + 1) * chunk]
        results.append(torch.all(segment == rank))
    return results
54
82
# all_to_all_single with sync and single stream
def a2a_sync_base(
    batch_inputs: List[Dict[str, Any]],
    dim: int,
    num_mul: int,
    num_concat: int,
    ctx: MultiProcessContext,
) -> None:
    """Benchmark one iteration of a blocking all_to_all_single exchange.

    Builds a rank-tagged payload, exchanges it across all ranks on the
    default stream, and validates that chunk ``i`` of the received tensor
    carries rank ``i``'s tag.

    Args:
        batch_inputs: supplied by the benchmark harness; unused here.
        dim: side length of the square tensors exchanged.
        num_mul: rounds of matmul/normalize synthetic compute.
        num_concat: replication factor for the exchanged tensor.
        ctx: process context providing device, rank, world_size, and pg.
    """
    with record_function("## pre-comms compute ##"):
        # Payload whose elements' integer parts equal ctx.rank.
        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)

    with record_function("## all_to_all_single ##"):
        post_comms = torch.empty_like(pre_comms)
        # NOTE(review): async_op is not passed, so this is the blocking
        # variant; `req` is never used below (kept for symmetry with the
        # async benchmark).
        req = dist.all_to_all_single(output=post_comms, input=pre_comms, group=ctx.pg)

    with record_function("## comms validation ##"):
        # One boolean tensor per rank; chunk i must be uniformly i.
        checks = _validate(post_comms, ctx)

    with record_function("## irrelevant compute ##"):
        # Work that does not depend on the comms result.
        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)

    with record_function("## post-comms compute ##"):
        # Feed the first received slice back through the compute chain.
        post_comms = _compute(
            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
        )

    with record_function("## assert ##"):
        # all() forces a device-to-host sync on the validation results.
        assert all(checks)
110
+
111
+
112
# all_to_all_single with async and single stream
def a2a_async_base(
    batch_inputs: List[Dict[str, Any]],
    dim: int,
    num_mul: int,
    num_concat: int,
    ctx: MultiProcessContext,
) -> None:
    """Benchmark one iteration of a non-blocking all_to_all_single exchange.

    Same workload as ``a2a_sync_base``, but the collective is issued with
    ``async_op=True`` so the "irrelevant compute" section can overlap with
    the transfer; ``req.wait()`` is called only before the compute that
    consumes the received data.

    Args:
        batch_inputs: supplied by the benchmark harness; unused here.
        dim: side length of the square tensors exchanged.
        num_mul: rounds of matmul/normalize synthetic compute.
        num_concat: replication factor for the exchanged tensor.
        ctx: process context providing device, rank, world_size, and pg.
    """
    with record_function("## pre-comms compute ##"):
        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)

    with record_function("## all_to_all_single ##"):
        # use zeros instead of empty to make sure no previous data used
        post_comms = torch.zeros_like(pre_comms)
        req = dist.all_to_all_single(
            output=post_comms,
            input=pre_comms,
            group=ctx.pg,
            async_op=True,
        )

    with record_function("## comms validation ##"):
        # pre-check is performed before comms' done
        # all() will trigger a device-to-host sync to get the result
        # of course you can also make it async by wrapping with Awaitable
        pre_checks = all(_validate(post_comms, ctx))

    with record_function("## irrelevant compute ##"):
        # Independent work that can overlap with the in-flight collective.
        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)

    with record_function(f"## post-comms compute: pre-check-{pre_checks} ##"):
        # assertion fails without wait(), this wait() makes the main cuda stream wait
        # for the comms to finish, so the post-comms compute will be blocked until
        # the comms is done
        req.wait()
        checks = _validate(post_comms, ctx)
        post_comms = _compute(
            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
        )

    with record_function("## assert ##"):
        assert all(checks)
@@ -114,6 +170,8 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
114
170
115
171
if arg .name .startswith ("a2a_sync_base" ):
116
172
func = a2a_sync_base
173
+ elif arg .name .startswith ("a2a_async_base" ):
174
+ func = a2a_async_base
117
175
else :
118
176
func = a2a_sync_base
119
177
@@ -128,7 +186,7 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
128
186
},
129
187
func_to_benchmark = func ,
130
188
rank = rank ,
131
- ** arg .benchmark_func_kwargs ()
189
+ ** arg .benchmark_func_kwargs (),
132
190
)
133
191
134
192
if rank == 0 :
0 commit comments