Commit 7444714

TroyGarden authored and facebook-github-bot committed
all_to_all_single with async_op (#3436)
Summary:

# context
* Add a benchmark for `all_to_all_single` with the `async_op` option.
* The comm uses a different CUDA stream (the comms stream), so it is non-blocking with respect to the subsequent operations on the main CUDA stream.
* Of course, the comm results (the pre-allocated output) are not valid until the comm is done (the pre-check fails).
* When there is a data dependency on the comm's output, users need to call `req.wait()` explicitly, so that the main CUDA stream waits on the comms stream.

NOTE: the `req.wait()` call is non-blocking on the CPU side.

# ref
* [torch.distributed](https://docs.pytorch.org/tutorials/beginner/dist_overview.html)
* [sync and async comms](https://docs.pytorch.org/docs/stable/distributed.html#collective-functions)
* [CUDA semantics](https://docs.pytorch.org/docs/stable/notes/cuda.html)

Differential Revision: D83924526
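For illustration, a minimal sketch of the pattern described above, assuming an already-initialized process group and one CUDA device per rank (the function and tensor names are hypothetical, not part of this diff):

```python
import torch
import torch.distributed as dist

def async_a2a_sketch(inp: torch.Tensor) -> torch.Tensor:
    # pre-allocated output; contents are undefined until the collective completes
    out = torch.empty_like(inp)
    # returns immediately: the collective runs on the comms stream
    req = dist.all_to_all_single(output=out, input=inp, async_op=True)
    # ... independent work on the main stream can overlap with the comm here ...
    # non-blocking on the CPU; makes the main CUDA stream wait on the comms stream
    req.wait()
    return out  # safe to consume after wait()
```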
1 parent c98fb11 commit 7444714

File tree

2 files changed (+87, -28 lines)

torchrec/distributed/benchmark/base.py

Lines changed: 8 additions & 2 deletions
@@ -601,6 +601,7 @@ def _run_benchmark_core(
     pre_gpu_load: int = 0,
     export_stacks: bool = False,
     reset_accumulated_memory_stats: bool = False,
+    all_rank_traces: bool = False,
 ) -> BenchmarkResult:
     """Internal helper that contains the core benchmarking logic shared by
     ``benchmark`` and ``benchmark_func``. All heavy–lifting (timing, memory
@@ -718,9 +719,10 @@ def _run_benchmark_core(
     def _trace_handler(prof: torch.profiler.profile) -> None:
         total_avg = prof.profiler.total_average()
         logger.info(f" TOTAL_AVERAGE:\n{name}\n{total_avg}")
-        if rank > 0:
+        if not all_rank_traces and rank > 0:
+            # only save trace for rank 0 when all_rank_traces is disabled
             return
-        trace_file = f"{output_dir}/trace-{name}.json"
+        trace_file = f"{output_dir}/trace-{name}-rank{rank}.json"
         logger.info(f" PROFILE[{name}].chrome_trace:{trace_file}")
         prof.export_chrome_trace(trace_file)
         if export_stacks:
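For context, `_trace_handler` serves as the `on_trace_ready` callback of `torch.profiler.profile`. A minimal sketch of how such a handler is wired up, assuming nothing beyond the public profiler API (the factory function and its arguments are illustrative, not the benchmark's actual plumbing):

```python
import torch

def make_trace_handler(name: str, output_dir: str, rank: int, all_rank_traces: bool):
    def _handler(prof: torch.profiler.profile) -> None:
        if not all_rank_traces and rank > 0:
            return  # by default only rank 0 exports a trace
        prof.export_chrome_trace(f"{output_dir}/trace-{name}-rank{rank}.json")
    return _handler

with torch.profiler.profile(
    on_trace_ready=make_trace_handler("demo", ".", rank=0, all_rank_traces=True)
):
    torch.matmul(torch.rand(64, 64), torch.rand(64, 64))  # profiled work
```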
@@ -825,6 +827,7 @@ class BenchmarkFunc:
     device_type: str = "cuda"
     pre_gpu_load: int = 0
     export_stacks: bool = False
+    all_rank_traces: bool = False

     # pyre-ignore [2]
     def benchmark_func_kwargs(self, **kwargs_to_override) -> Dict[str, Any]:
@@ -837,6 +840,7 @@ def benchmark_func_kwargs(self, **kwargs_to_override) -> Dict[str, Any]:
         "device_type": self.device_type,
         "pre_gpu_load": self.pre_gpu_load,
         "export_stacks": self.export_stacks,
+        "all_rank_traces": self.all_rank_traces,
     } | kwargs_to_override
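The trailing `| kwargs_to_override` is the Python 3.9+ dict union, where keys from the right operand win, so per-call overrides take precedence over the dataclass defaults. A tiny illustration (values made up):

```python
defaults = {"export_stacks": False, "all_rank_traces": False}
merged = defaults | {"all_rank_traces": True}  # right operand wins on conflicts
assert merged == {"export_stacks": False, "all_rank_traces": True}
```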
@@ -855,6 +859,7 @@ def benchmark_func(
     device_type: str = "cuda",
     pre_gpu_load: int = 0,
     export_stacks: bool = False,
+    all_rank_traces: bool = False,
 ) -> BenchmarkResult:
     """
     Args:
@@ -901,4 +906,5 @@ def _profile_iter_fn(prof: torch.profiler.profile) -> None:
         pre_gpu_load=pre_gpu_load,
         export_stacks=export_stacks,
         reset_accumulated_memory_stats=True,
+        all_rank_traces=all_rank_traces,
     )

torchrec/distributed/benchmark/benchmark_comms.py

Lines changed: 79 additions & 26 deletions
@@ -11,18 +11,22 @@
 Example usage:

 Buck2 (internal):
-    buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_comms --
+    buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_comms -- \
+        a2a_single --name=a2a_sync_base-$(hg whereami | cut -c 1-10)

 OSS (external):
-    python -m torchrec.distributed.benchmark.benchmark_comms
+    python -m torchrec.distributed.benchmark.benchmark_comms \
+        a2a_single --name=a2a_sync_base-$(git rev-parse --short HEAD || echo $USER)

+see README.md for more details
 """

 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional

 import torch
 import torch.distributed as dist
+import torch.nn.functional as F

 from torch.autograd.profiler import record_function

@@ -43,10 +47,34 @@ class A2A_Single_Run(BenchmarkFunc):
     profile_dir: str = "."
     num_benchmarks: int = 1
     num_profiles: int = 2
-    num_mul: int = 10
+    num_mul: int = 5
     num_concat: int = 100


+def _compute(
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+    x: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    if x is None:
+        x = torch.rand(dim, dim, device=ctx.device) - 0.5
+    for _ in range(num_mul):
+        x = F.normalize(x @ x) * 10
+    x = torch.sigmoid(x).reshape(1, dim, dim) + ctx.rank
+    return torch.concat([x] * num_concat)
+
+
+def _validate(x: torch.Tensor, ctx: MultiProcessContext) -> List[torch.Tensor]:
+    mixed_ranks = x.to(torch.int).reshape(-1)
+    N = mixed_ranks.numel() // ctx.world_size
+    checks = [
+        torch.all(mixed_ranks[i * N : (i + 1) * N] == i) for i in range(ctx.world_size)
+    ]
+    return checks
+
+
 # all_to_all_single with sync and single stream
 def a2a_sync_base(
     batch_inputs: List[Dict[str, Any]],
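A brief aside on why `_validate` works: `_compute` on rank r produces values in (r, r + 1), since a sigmoid output in (0, 1) is shifted by the rank. After the all-to-all, chunk i of the received buffer comes from rank i, so casting to int should recover exactly i. A self-contained sketch of that check with simulated data (world size and values are made up):

```python
import torch

world_size, chunk = 2, 4
# simulate what each rank contributes: sigmoid-like values in (0, 1) plus its rank
mixed = torch.cat([torch.full((chunk,), 0.3) + r for r in range(world_size)])
as_int = mixed.to(torch.int)  # truncates back to the source rank
assert all(
    torch.all(as_int[i * chunk : (i + 1) * chunk] == i) for i in range(world_size)
)
```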
@@ -56,38 +84,61 @@
     ctx: MultiProcessContext,
 ) -> None:
     with record_function("## pre-comms compute ##"):
-        pre_comms = torch.rand(dim, dim, device=ctx.device) - 0.5
-        for _ in range(num_mul):
-            pre_comms = pre_comms @ pre_comms
-            pre_comms = torch.sigmoid(pre_comms - torch.mean(pre_comms))
-        pre_comms = torch.sigmoid(pre_comms).reshape(1, dim, dim) + ctx.rank
-        pre_comms = torch.concat([pre_comms] * num_concat)
+        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)

     with record_function("## all_to_all_single ##"):
         post_comms = torch.empty_like(pre_comms)
         req = dist.all_to_all_single(output=post_comms, input=pre_comms, group=ctx.pg)

     with record_function("## comms validation ##"):
-        mixed_ranks = post_comms.to(torch.int).reshape(-1)
-        N = mixed_ranks.numel() // ctx.world_size
-        checks = [
-            torch.all(mixed_ranks[i * N : (i + 1) * N] == i)
-            for i in range(ctx.world_size)
-        ]
+        checks = _validate(post_comms, ctx)

     with record_function("## irrelevant compute ##"):
-        pre_comms = torch.rand(dim, dim, device=ctx.device) - 0.5
-        for _ in range(num_mul):
-            pre_comms = pre_comms @ pre_comms
-            pre_comms = torch.sigmoid(pre_comms - torch.mean(pre_comms))
-        pre_comms = torch.sigmoid(pre_comms) + ctx.rank
+        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)

     with record_function("## post-comms compute ##"):
-        post_comms = post_comms[0]
-        for _ in range(num_mul):
-            post_comms = post_comms @ post_comms
-            post_comms = torch.sigmoid(pre_comms - torch.mean(post_comms))
-        post_comms = torch.sigmoid(post_comms) + ctx.rank
+        post_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
+        )
+
+    with record_function("## assert ##"):
+        assert all(checks)
+
+
+# all_to_all_single with async_op: the comm overlaps compute on the main stream
+def a2a_async_base(
+    batch_inputs: List[Dict[str, Any]],
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+) -> None:
+    with record_function("## pre-comms compute ##"):
+        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
+
+    with record_function("## all_to_all_single ##"):
+        # use zeros instead of empty to make sure no previous data is reused
+        post_comms = torch.zeros_like(pre_comms)
+        req = dist.all_to_all_single(
+            output=post_comms,
+            input=pre_comms,
+            group=ctx.pg,
+            async_op=True,
+        )
+
+    with record_function("## comms validation ##"):
+        # the pre-check runs before the comm is done
+        pre_checks = _validate(post_comms, ctx)
+
+    with record_function("## irrelevant compute ##"):
+        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
+
+    with record_function(f"## post-comms compute: pre-check-{all(pre_checks)}##"):
+        req.wait()  # assertion fails without wait()
+        checks = _validate(post_comms, ctx)
+        post_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
+        )

     with record_function("## assert ##"):
         assert all(checks)
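One hedged note on the semantics exercised above: `req.wait()` enqueues a dependency so that the current CUDA stream waits for the collective; it does not block the host. If host-side completion status is needed, the returned `Work` handle also exposes `is_completed()`. A sketch under those assumptions (the function and argument names are illustrative, not part of this diff):

```python
import torch
import torch.distributed as dist

def consume_after_wait(inp: torch.Tensor, pg: dist.ProcessGroup) -> torch.Tensor:
    out = torch.zeros_like(inp)
    req = dist.all_to_all_single(output=out, input=inp, group=pg, async_op=True)
    # optional host-side, non-blocking completion poll
    if not req.is_completed():
        pass  # overlapped CPU-side work could go here
    req.wait()  # main CUDA stream now waits on the comms stream; CPU does not block
    return out
```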
@@ -113,6 +164,8 @@ def a2a_single_runner(rank: int, world_size: int, arg: A2A_Single_Run) -> None:

     if arg.name.startswith("a2a_sync_base"):
         func = a2a_sync_base
+    elif arg.name.startswith("a2a_async_base"):
+        func = a2a_async_base
     else:
         func = a2a_sync_base

@@ -127,7 +180,7 @@ def a2a_single_runner(rank: int, world_size: int, arg: A2A_Single_Run) -> None:
         },
         func_to_benchmark=func,
         rank=rank,
-        **arg.benchmark_func_kwargs()
+        **arg.benchmark_func_kwargs(),
     )

     if rank == 0:

0 commit comments

Comments
 (0)