Commit 967605e

TroyGarden authored and facebook-github-bot committed
all_to_all_single with async_op (#3436)
Summary:

# context

* add a benchmark for `all_to_all_single` with the `async_op` option.
* the comms run on a different cuda stream (the comms stream), so they are non-blocking for subsequent operations (on the main cuda stream).
* of course, the comms results (the pre-allocated output) are not valid until the comms are done (the pre-check fails).
* when there is a data dependency on the comms output, users need to call `req.wait()` explicitly so that the main cuda stream waits on the comms stream.

NOTE: the `req.wait()` call is non-blocking on the CPU side.

# ref

* [torch.distributed](https://docs.pytorch.org/tutorials/beginner/dist_overview.html)
* [sync and async comms](https://docs.pytorch.org/docs/stable/distributed.html#collective-functions)
* [CUDA semantics](https://docs.pytorch.org/docs/stable/notes/cuda.html)

# optimization

* the data validation is done on the device side, but very often the validation result is needed on the cpu side.
* in a previous example (a2a-sync), the assertion needs `checks` on the cpu side, so it blocks cpu execution (see below). {F1982527184}
* actually, `checks` is available on the gpu side as soon as the "comms validation" is done, so a device-to-host data transfer can be initiated then. however, this device-to-host transfer blocks the following cpu execution (i.e., the "irrelevant compute"). {F1982527242}
* we tried `non_blocking=True` in the `.to()` copy, and the results are very interesting: all the cpu executions finish very early. this is because the cpu has no idea whether the non-blocking device-to-host transfer is complete; it just goes ahead executing, which of course causes an assertion error (changed to `print(checks)` to bypass the assertion in order to generate the trace). {F1982527281}
* the proper way is to use a cuda.Event, as in this diff, so that the assert only waits until the data (`checks`) becomes available on the cpu side. {F1982527334}
* as for the async all_to_all_single case, the 2nd batch (CPU) starts right after the 1st assertion, which is done right after the device-to-host data transfer. {F1982527765}

# results

* [a2a sync trace](https://drive.google.com/file/d/1xI-qlI7V6zbmcbuQGyZgqFC3ZMY1scro/view?usp=sharing)
* [a2a async trace](https://drive.google.com/file/d/1eboZ01quSZjKBtBcdu4vXX9zyokzohqi/view?usp=sharing)

Reviewed By: spmex

Differential Revision: D83924526
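The optimization above boils down to pairing a `non_blocking=True` device-to-host copy with a `torch.cuda.Event`. A minimal standalone sketch of that pattern (tensor shape and contents are illustrative, not taken from the diff):

```python
import torch

# stand-in for the GPU-side validation result (`checks` in the diff)
checks_gpu = torch.ones(8, dtype=torch.bool, device="cuda")

# enqueue the device-to-host copy without blocking the CPU
checks_cpu = checks_gpu.to("cpu", non_blocking=True)

# record an event on the current stream right after the copy, so the CPU
# can later find out when the transfer has actually landed
ev_d2h = torch.cuda.Event()
ev_d2h.record()

# ... unrelated CPU work and additional GPU kernels can proceed here ...

# block the CPU only for as long as the copy is still in flight
ev_d2h.synchronize()
assert torch.all(checks_cpu)
```

Without the `ev_d2h.synchronize()`, the CPU may read `checks_cpu` before the copy completes, which is exactly the failure mode shown in the third trace above.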
1 parent 8cd65b1 commit 967605e

File tree

2 files changed: +118 -30


torchrec/distributed/benchmark/base.py

Lines changed: 9 additions & 2 deletions
```diff
@@ -601,6 +601,7 @@ def _run_benchmark_core(
     pre_gpu_load: int = 0,
     export_stacks: bool = False,
     reset_accumulated_memory_stats: bool = False,
+    all_rank_traces: bool = False,
 ) -> BenchmarkResult:
     """Internal helper that contains the core benchmarking logic shared by
     ``benchmark`` and ``benchmark_func``. All heavy–lifting (timing, memory
@@ -721,9 +722,10 @@ def _run_benchmark_core(
     def _trace_handler(prof: torch.profiler.profile) -> None:
         total_avg = prof.profiler.total_average()
         logger.info(f" TOTAL_AVERAGE:\n{name}\n{total_avg}")
-        if rank > 0:
+        if not all_rank_traces and rank > 0:
+            # only save trace for rank 0 when all_rank_traces is disabled
             return
-        trace_file = f"{output_dir}/trace-{name}.json"
+        trace_file = f"{output_dir}/trace-{name}-rank{rank}.json"
         logger.info(f" PROFILE[{name}].chrome_trace:{trace_file}")
         prof.export_chrome_trace(trace_file)
         if export_stacks:
@@ -828,6 +830,7 @@ class BenchFuncConfig:
     device_type: str = "cuda"
     pre_gpu_load: int = 0
     export_stacks: bool = False
+    all_rank_traces: bool = False
 
     # pyre-ignore [2]
     def benchmark_func_kwargs(self, **kwargs_to_override) -> Dict[str, Any]:
@@ -840,6 +843,7 @@ def benchmark_func_kwargs(self, **kwargs_to_override) -> Dict[str, Any]:
             "device_type": self.device_type,
             "pre_gpu_load": self.pre_gpu_load,
             "export_stacks": self.export_stacks,
+            "all_rank_traces": self.all_rank_traces,
         } | kwargs_to_override
 
 
@@ -857,6 +861,7 @@ def benchmark_func(
     device_type: str = "cuda",
     pre_gpu_load: int = 0,
     export_stacks: bool = False,
+    all_rank_traces: bool = False,
 ) -> BenchmarkResult:
     """
     Args:
@@ -879,6 +884,7 @@
         pre_gpu_load: Number of dummy matmul operations to run before the first
             measured iteration (helps simulating a loaded allocator).
        export_stacks: Whether to export flamegraph-compatible stack files.
+        all_rank_traces: Whether to export traces from all ranks.
     """
     if benchmark_func_kwargs is None:
         benchmark_func_kwargs = {}
@@ -905,4 +911,5 @@ def _profile_iter_fn(prof: torch.profiler.profile) -> None:
         pre_gpu_load=pre_gpu_load,
         export_stacks=export_stacks,
         reset_accumulated_memory_stats=True,
+        all_rank_traces=all_rank_traces,
     )
```
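A hypothetical caller-side sketch of the new flag (the field and method names come from this diff; the override value is made up, and this assumes the remaining `BenchFuncConfig` fields all have defaults):

```python
from torchrec.distributed.benchmark.base import BenchFuncConfig

# opt into per-rank chrome traces: _trace_handler now writes
# trace-<name>-rank<rank>.json for every rank instead of rank 0 only
cfg = BenchFuncConfig(all_rank_traces=True)

# benchmark_func_kwargs merges the config fields with per-call overrides,
# so the flag flows through benchmark_func into _run_benchmark_core
kwargs = cfg.benchmark_func_kwargs(pre_gpu_load=2)
assert kwargs["all_rank_traces"] is True
```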

torchrec/distributed/benchmark/benchmark_comms.py

Lines changed: 109 additions & 28 deletions
```diff
@@ -11,18 +11,22 @@
 Example usage:
 
 Buck2 (internal):
-    buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_comms --
+    buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_comms -- \
+        a2a_single --name=a2a_sync_base-$(hg whereami | cut -c 1-10)
 
 OSS (external):
-    python -m torchrec.distributed.benchmark.benchmark_comms
+    python -m torchrec.distributed.benchmark.benchmark_comms \
+        a2a_single --name=a2a_sync_base-$(git rev-parse --short HEAD || echo $USER)
 
+see README.md for more details
 """
 
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
 import torch.distributed as dist
+import torch.nn.functional as F
 
 from torch.autograd.profiler import record_function
 
@@ -47,54 +51,129 @@ class AllToAllSingleRunConfig(BenchFuncConfig):
     profile_dir: str = "."
     num_benchmarks: int = 1
     num_profiles: int = 2
-    num_mul: int = 10
+    num_mul: int = 5
     num_concat: int = 100
 
 
+def _compute(
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+    x: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    a dummy compute function to simulate GPU load; all operations stay on
+    the GPU side, so nothing here blocks CPU execution
+    """
+    if x is None:
+        x = torch.rand(dim, dim, device=ctx.device) - 0.5
+    for _ in range(num_mul):
+        x = F.normalize(x @ x) * 10
+    x = torch.sigmoid(x).reshape(1, dim, dim) + ctx.rank
+    return torch.concat([x] * num_concat)
+
+
+def _validate(x: torch.Tensor, ctx: MultiProcessContext) -> torch.Tensor:
+    """
+    validate the correctness of the comms result; the validation is done on
+    the GPU and returns a GPU tensor holding a single boolean value, so it
+    is non-blocking on the CPU
+    """
+    mixed_ranks = x.to(torch.int).reshape(ctx.world_size, -1)
+    checks = torch.empty(ctx.world_size, dtype=torch.bool, device=ctx.device)
+    for i in range(ctx.world_size):
+        checks[i] = torch.all(mixed_ranks[i, :] == i)
+    return torch.all(checks)
+
+
 # all_to_all_single with sync and single stream
 def a2a_sync_base(
-    batch_inputs: List[Dict[str, Any]],
+    _batch_inputs: List[Dict[str, Any]],
     dim: int,
     num_mul: int,
     num_concat: int,
     ctx: MultiProcessContext,
 ) -> None:
     with record_function("## pre-comms compute ##"):
-        pre_comms = torch.rand(dim, dim, device=ctx.device) - 0.5
-        for _ in range(num_mul):
-            pre_comms = pre_comms @ pre_comms
-            pre_comms = torch.sigmoid(pre_comms - torch.mean(pre_comms))
-        pre_comms = torch.sigmoid(pre_comms).reshape(1, dim, dim) + ctx.rank
-        pre_comms = torch.concat([pre_comms] * num_concat)
+        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
 
     with record_function("## all_to_all_single ##"):
         post_comms = torch.empty_like(pre_comms)
         req = dist.all_to_all_single(output=post_comms, input=pre_comms, group=ctx.pg)
 
     with record_function("## comms validation ##"):
-        mixed_ranks = post_comms.to(torch.int).reshape(-1)
-        N = mixed_ranks.numel() // ctx.world_size
-        checks = [
-            torch.all(mixed_ranks[i * N : (i + 1) * N] == i)
-            for i in range(ctx.world_size)
-        ]
+        # this non-blocking copy to CPU triggers a device-to-host data transfer;
+        # since it is driven from the device side, the CPU doesn't know when it
+        # has finished, so we need a cuda event to mark its completion.
+        # the trace looks very interesting without the cuda event in this case:
+        # all cpu-side operations are non-blocking and finish before the comms,
+        # and hence fail the validation assertion
+        checks = _validate(post_comms, ctx).to(torch.device("cpu"), non_blocking=True)
+        ev_d2h = torch.cuda.Event()
+        ev_d2h.record()
 
     with record_function("## irrelevant compute ##"):
-        pre_comms = torch.rand(dim, dim, device=ctx.device) - 0.5
-        for _ in range(num_mul):
-            pre_comms = pre_comms @ pre_comms
-            pre_comms = torch.sigmoid(pre_comms - torch.mean(pre_comms))
-        pre_comms = torch.sigmoid(pre_comms) + ctx.rank
+        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
 
     with record_function("## post-comms compute ##"):
-        post_comms = post_comms[0]
-        for _ in range(num_mul):
-            post_comms = post_comms @ post_comms
-            post_comms = torch.sigmoid(pre_comms - torch.mean(post_comms))
-        post_comms = torch.sigmoid(post_comms) + ctx.rank
+        post_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
+        )
+
+    with record_function("## assert ##"):
+        # explained above: this event.synchronize() is needed to make sure the
+        # device-to-host data transfer is done before the assertion
+        ev_d2h.synchronize()
+        assert checks
+
+
+# all_to_all_single with async_op and single stream
+def a2a_async_base(
+    _batch_inputs: List[Dict[str, Any]],
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+) -> None:
+    with record_function("## pre-comms compute ##"):
+        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
+
+    with record_function("## all_to_all_single ##"):
+        # use zeros instead of empty to make sure no stale data is reused
+        post_comms = torch.zeros_like(pre_comms)
+        req = dist.all_to_all_single(
+            output=post_comms,
+            input=pre_comms,
+            group=ctx.pg,
+            async_op=True,
+        )
+
+    with record_function("## comms validation ##"):
+        # the pre-check is performed before the comms is done
+        pre_checks = _validate(post_comms, ctx).to("cpu", non_blocking=True)
+        # need this cuda event to track the device-to-host data transfer
+        ev_d2h = torch.cuda.Event()
+        ev_d2h.record()
 
+    with record_function("## irrelevant compute ##"):
+        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
+
+    ev_d2h.synchronize()  # make sure pre_checks is available on the cpu side
+    with record_function(f"## post-comms compute: pre-check-{pre_checks} ##"):
+        # the assertion fails without wait(): this wait() makes the main cuda
+        # stream wait for the comms to finish, so the post-comms compute is
+        # blocked until the comms is done
+        req.wait()
+        checks = _validate(post_comms, ctx).to("cpu", non_blocking=True)
+        ev_d2h.record()  # record the device-to-host data transfer
+        post_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
+        )
 
     with record_function("## assert ##"):
-        assert all(checks)
+        # again, make sure the device-to-host data transfer is done before the assertion
+        ev_d2h.synchronize()
+        assert checks
 
 
 # single-rank runner
@@ -114,6 +193,8 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
 
     if arg.name.startswith("a2a_sync_base"):
         func = a2a_sync_base
+    elif arg.name.startswith("a2a_async_base"):
+        func = a2a_async_base
     else:
         func = a2a_sync_base
 
@@ -128,7 +209,7 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
         },
         func_to_benchmark=func,
         rank=rank,
-        **arg.benchmark_func_kwargs()
+        **arg.benchmark_func_kwargs(),
     )
 
     if rank == 0:
```
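For reference, a minimal sketch of the async pattern this benchmark exercises, stripped of the torchrec harness; it assumes an already-initialized NCCL process group (e.g. launched via `torchrun`), and the helper name is illustrative:

```python
import torch
import torch.distributed as dist


def async_a2a(inp: torch.Tensor) -> torch.Tensor:
    # zeros (not empty) so stale memory can't masquerade as valid comms output
    out = torch.zeros_like(inp)
    req = dist.all_to_all_single(out, inp, async_op=True)

    # kernels enqueued here run on the main stream and overlap the collective
    filler = torch.rand(1024, 1024, device=inp.device)
    filler = filler @ filler

    # non-blocking on the CPU: wait() only makes the main cuda stream wait on
    # the comms stream, so kernels enqueued afterwards see a completed `out`
    req.wait()
    return out
```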
