Commit b991118

all_to_all_single with async_op (#3436)

TroyGarden authored and facebook-github-bot committed
Summary:

# context
* Add a benchmark for `all_to_all_single` with the `async_op` option.
* The comms run on a separate CUDA stream (the comms stream), so they are non-blocking for the operations that follow on the main CUDA stream.
* As a consequence, the comms results (the pre-allocated output) are not valid until the comms are done (the pre-check fails).
* When there is a data dependency on the comms output, users need to call `req.wait()` explicitly, so that the main CUDA stream waits on the comms stream.

NOTE: the `req.wait()` call is non-blocking on the CPU side.

# ref
* [torch.distributed](https://docs.pytorch.org/tutorials/beginner/dist_overview.html)
* [sync and async comms](https://docs.pytorch.org/docs/stable/distributed.html#collective-functions)
* [CUDA semantics](https://docs.pytorch.org/docs/stable/notes/cuda.html)

Differential Revision: D83924526
1 parent 02b8ff7 commit b991118
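To illustrate the pattern the summary describes, here is a minimal sketch of `all_to_all_single` with `async_op=True`. It assumes an already-initialized process group on a CUDA backend; `async_a2a_example` and its parameters are illustrative names, not part of this commit:

```python
import torch
import torch.distributed as dist


def async_a2a_example(pg: dist.ProcessGroup, dim: int = 1024) -> torch.Tensor:
    inp = torch.rand(dim, dim, device="cuda")
    # pre-allocated output; its contents are not valid until req.wait()
    out = torch.empty_like(inp)
    req = dist.all_to_all_single(output=out, input=inp, group=pg, async_op=True)
    # independent work on the main stream can overlap with the collective,
    # which runs on a separate comms stream
    unrelated = inp @ inp
    # wait() makes the main stream wait on the comms stream;
    # it does not block the CPU
    req.wait()
    return out + unrelated  # safe: the data dependency is ordered after wait()
```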

File tree

2 files changed: +88 −28 lines


torchrec/distributed/benchmark/base.py

Lines changed: 9 additions & 2 deletions
```diff
@@ -601,6 +601,7 @@ def _run_benchmark_core(
     pre_gpu_load: int = 0,
     export_stacks: bool = False,
     reset_accumulated_memory_stats: bool = False,
+    all_rank_traces: bool = False,
 ) -> BenchmarkResult:
     """Internal helper that contains the core benchmarking logic shared by
     ``benchmark`` and ``benchmark_func``. All heavy–lifting (timing, memory
@@ -721,9 +722,10 @@ def _run_benchmark_core(
     def _trace_handler(prof: torch.profiler.profile) -> None:
         total_avg = prof.profiler.total_average()
         logger.info(f" TOTAL_AVERAGE:\n{name}\n{total_avg}")
-        if rank > 0:
+        if not all_rank_traces and rank > 0:
+            # only save trace for rank 0 when all_rank_traces is disabled
             return
-        trace_file = f"{output_dir}/trace-{name}.json"
+        trace_file = f"{output_dir}/trace-{name}-rank{rank}.json"
         logger.info(f" PROFILE[{name}].chrome_trace:{trace_file}")
         prof.export_chrome_trace(trace_file)
         if export_stacks:
@@ -828,6 +830,7 @@ class BenchFuncConfig:
     device_type: str = "cuda"
     pre_gpu_load: int = 0
     export_stacks: bool = False
+    all_rank_traces: bool = False
 
     # pyre-ignore [2]
     def benchmark_func_kwargs(self, **kwargs_to_override) -> Dict[str, Any]:
@@ -840,6 +843,7 @@ def benchmark_func_kwargs(self, **kwargs_to_override) -> Dict[str, Any]:
             "device_type": self.device_type,
             "pre_gpu_load": self.pre_gpu_load,
             "export_stacks": self.export_stacks,
+            "all_rank_traces": self.all_rank_traces,
         } | kwargs_to_override
 
 
@@ -857,6 +861,7 @@ def benchmark_func(
     device_type: str = "cuda",
     pre_gpu_load: int = 0,
     export_stacks: bool = False,
+    all_rank_traces: bool = False,
 ) -> BenchmarkResult:
     """
     Args:
@@ -879,6 +884,7 @@ def benchmark_func(
         pre_gpu_load: Number of dummy matmul operations to run before the first
             measured iteration (helps simulating a loaded allocator).
         export_stacks: Whether to export flamegraph-compatible stack files.
+        all_rank_traces: Whether to export traces from all ranks.
     """
     if benchmark_func_kwargs is None:
         benchmark_func_kwargs = {}
@@ -905,4 +911,5 @@ def _profile_iter_fn(prof: torch.profiler.profile) -> None:
         pre_gpu_load=pre_gpu_load,
         export_stacks=export_stacks,
         reset_accumulated_memory_stats=True,
+        all_rank_traces=all_rank_traces,
     )
```
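As a quick usage note for the new flag, here is a hedged sketch, assuming `BenchFuncConfig` is a dataclass constructible with keyword arguments:

```python
from torchrec.distributed.benchmark.base import BenchFuncConfig

# with all_rank_traces=True, every rank exports trace-<name>-rank<rank>.json
# instead of only rank 0
config = BenchFuncConfig(device_type="cuda", all_rank_traces=True)
kwargs = config.benchmark_func_kwargs()
assert kwargs["all_rank_traces"] is True
```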

torchrec/distributed/benchmark/benchmark_comms.py

Lines changed: 79 additions & 26 deletions
```diff
@@ -11,18 +11,22 @@
 Example usage:
 
 Buck2 (internal):
-    buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_comms --
+    buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_comms -- \
+        a2a_single --name=a2a_sync_base-$(hg whereami | cut -c 1-10)
 
 OSS (external):
-    python -m torchrec.distributed.benchmark.benchmark_comms
+    python -m torchrec.distributed.benchmark.benchmark_comms \
+        a2a_single --name=a2a_sync_base-$(git rev-parse --short HEAD || echo $USER)
 
+see README.md for more details
 """
 
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
 import torch.distributed as dist
+import torch.nn.functional as F
 
 from torch.autograd.profiler import record_function
 
@@ -47,10 +51,34 @@ class AllToAllSingleRunConfig(BenchFuncConfig):
     profile_dir: str = "."
     num_benchmarks: int = 1
     num_profiles: int = 2
-    num_mul: int = 10
+    num_mul: int = 5
     num_concat: int = 100
 
 
+def _compute(
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+    x: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    if x is None:
+        x = torch.rand(dim, dim, device=ctx.device) - 0.5
+    for _ in range(num_mul):
+        x = F.normalize(x @ x) * 10
+    x = torch.sigmoid(x).reshape(1, dim, dim) + ctx.rank
+    return torch.concat([x] * num_concat)
+
+
+def _validate(x: torch.Tensor, ctx: MultiProcessContext) -> List[torch.Tensor]:
+    mixed_ranks = x.to(torch.int).reshape(-1)
+    N = mixed_ranks.numel() // ctx.world_size
+    checks = [
+        torch.all(mixed_ranks[i * N : (i + 1) * N] == i) for i in range(ctx.world_size)
+    ]
+    return checks
+
+
 # all_to_all_single with sync and single stream
 def a2a_sync_base(
     batch_inputs: List[Dict[str, Any]],
@@ -60,38 +88,61 @@ def a2a_sync_base(
     ctx: MultiProcessContext,
 ) -> None:
     with record_function("## pre-comms compute ##"):
-        pre_comms = torch.rand(dim, dim, device=ctx.device) - 0.5
-        for _ in range(num_mul):
-            pre_comms = pre_comms @ pre_comms
-            pre_comms = torch.sigmoid(pre_comms - torch.mean(pre_comms))
-        pre_comms = torch.sigmoid(pre_comms).reshape(1, dim, dim) + ctx.rank
-        pre_comms = torch.concat([pre_comms] * num_concat)
+        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
 
     with record_function("## all_to_all_single ##"):
         post_comms = torch.empty_like(pre_comms)
         req = dist.all_to_all_single(output=post_comms, input=pre_comms, group=ctx.pg)
 
     with record_function("## comms validation ##"):
-        mixed_ranks = post_comms.to(torch.int).reshape(-1)
-        N = mixed_ranks.numel() // ctx.world_size
-        checks = [
-            torch.all(mixed_ranks[i * N : (i + 1) * N] == i)
-            for i in range(ctx.world_size)
-        ]
+        checks = _validate(post_comms, ctx)
 
     with record_function("## irrelevant compute ##"):
-        pre_comms = torch.rand(dim, dim, device=ctx.device) - 0.5
-        for _ in range(num_mul):
-            pre_comms = pre_comms @ pre_comms
-            pre_comms = torch.sigmoid(pre_comms - torch.mean(pre_comms))
-        pre_comms = torch.sigmoid(pre_comms) + ctx.rank
+        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
 
     with record_function("## post-comms compute ##"):
-        post_comms = post_comms[0]
-        for _ in range(num_mul):
-            post_comms = post_comms @ post_comms
-            post_comms = torch.sigmoid(pre_comms - torch.mean(post_comms))
-        post_comms = torch.sigmoid(post_comms) + ctx.rank
+        post_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
+        )
+
+    with record_function("## assert ##"):
+        assert all(checks)
+
+
+# all_to_all_single with async_op and a separate comms stream
+def a2a_async_base(
+    batch_inputs: List[Dict[str, Any]],
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+) -> None:
+    with record_function("## pre-comms compute ##"):
+        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
+
+    with record_function("## all_to_all_single ##"):
+        # use zeros instead of empty to make sure no stale data is reused
+        post_comms = torch.zeros_like(pre_comms)
+        req = dist.all_to_all_single(
+            output=post_comms,
+            input=pre_comms,
+            group=ctx.pg,
+            async_op=True,
+        )
+
+    with record_function("## comms validation ##"):
+        # pre-check is performed before the comms are done
+        pre_checks = _validate(post_comms, ctx)
+
+    with record_function("## irrelevant compute ##"):
+        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
+
+    with record_function(f"## post-comms compute: pre-check-{all(pre_checks)}##"):
+        req.wait()  # assertion fails without wait()
+        checks = _validate(post_comms, ctx)
+        post_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
+        )
 
     with record_function("## assert ##"):
         assert all(checks)
@@ -114,6 +165,8 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
 
     if arg.name.startswith("a2a_sync_base"):
         func = a2a_sync_base
+    elif arg.name.startswith("a2a_async_base"):
+        func = a2a_async_base
     else:
         func = a2a_sync_base
 
@@ -128,7 +181,7 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
         },
         func_to_benchmark=func,
         rank=rank,
-        **arg.benchmark_func_kwargs()
+        **arg.benchmark_func_kwargs(),
     )
 
     if rank == 0:
```
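To make the validation concrete, the following single-process sketch (hypothetical sizes, no process group) emulates the invariant `_validate` asserts once the collective has completed: each rank sends `sigmoid(...) + rank`, so the values in the i-th received chunk lie in `(i, i + 1)` and int-casting recovers the sender's rank:

```python
import torch

world_size, dim, num_concat = 2, 4, 8
chunk = num_concat // world_size

# what rank r would contribute: values in (r, r + 1)
inputs = [
    torch.sigmoid(torch.rand(num_concat, dim, dim)) + r for r in range(world_size)
]
# all_to_all_single with equal splits delivers one chunk per peer, ordered by rank
received = torch.concat([inputs[r][:chunk] for r in range(world_size)])

mixed_ranks = received.to(torch.int).reshape(-1)
N = mixed_ranks.numel() // world_size
assert all(
    torch.all(mixed_ranks[i * N : (i + 1) * N] == i) for i in range(world_size)
)
```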
