diff --git a/torchrec/distributed/benchmark/base.py b/torchrec/distributed/benchmark/base.py
index c9eaca711..c95544a8d 100644
--- a/torchrec/distributed/benchmark/base.py
+++ b/torchrec/distributed/benchmark/base.py
@@ -601,6 +601,7 @@ def _run_benchmark_core(
     pre_gpu_load: int = 0,
     export_stacks: bool = False,
     reset_accumulated_memory_stats: bool = False,
+    all_rank_traces: bool = False,
 ) -> BenchmarkResult:
     """Internal helper that contains the core benchmarking logic shared by
     ``benchmark`` and ``benchmark_func``. All heavy–lifting (timing, memory
@@ -721,9 +722,10 @@ def _run_benchmark_core(
     def _trace_handler(prof: torch.profiler.profile) -> None:
         total_avg = prof.profiler.total_average()
         logger.info(f" TOTAL_AVERAGE:\n{name}\n{total_avg}")
-        if rank > 0:
+        if not all_rank_traces and rank > 0:
+            # only save trace for rank 0 when all_rank_traces is disabled
             return
-        trace_file = f"{output_dir}/trace-{name}.json"
+        trace_file = f"{output_dir}/trace-{name}-rank{rank}.json"
         logger.info(f" PROFILE[{name}].chrome_trace:{trace_file}")
         prof.export_chrome_trace(trace_file)
         if export_stacks:
@@ -828,6 +830,7 @@ class BenchFuncConfig:
     device_type: str = "cuda"
     pre_gpu_load: int = 0
     export_stacks: bool = False
+    all_rank_traces: bool = False

     # pyre-ignore [2]
     def benchmark_func_kwargs(self, **kwargs_to_override) -> Dict[str, Any]:
@@ -840,6 +843,7 @@ def benchmark_func_kwargs(self, **kwargs_to_override) -> Dict[str, Any]:
             "device_type": self.device_type,
             "pre_gpu_load": self.pre_gpu_load,
             "export_stacks": self.export_stacks,
+            "all_rank_traces": self.all_rank_traces,
         } | kwargs_to_override

@@ -857,6 +861,7 @@ def benchmark_func(
     device_type: str = "cuda",
     pre_gpu_load: int = 0,
     export_stacks: bool = False,
+    all_rank_traces: bool = False,
 ) -> BenchmarkResult:
     """
     Args:
@@ -879,6 +884,7 @@ def benchmark_func(
         pre_gpu_load: Number of dummy matmul operations to run before the first
            measured iteration (helps simulating a loaded allocator).
         export_stacks: Whether to export flamegraph-compatible stack files.
+        all_rank_traces: Whether to export traces from all ranks.
     """
     if benchmark_func_kwargs is None:
         benchmark_func_kwargs = {}
@@ -905,4 +911,5 @@ def _profile_iter_fn(prof: torch.profiler.profile) -> None:
         pre_gpu_load=pre_gpu_load,
         export_stacks=export_stacks,
         reset_accumulated_memory_stats=True,
+        all_rank_traces=all_rank_traces,
     )
diff --git a/torchrec/distributed/benchmark/benchmark_comms.py b/torchrec/distributed/benchmark/benchmark_comms.py
index 66725396d..b1ba76a0f 100644
--- a/torchrec/distributed/benchmark/benchmark_comms.py
+++ b/torchrec/distributed/benchmark/benchmark_comms.py
@@ -11,11 +11,14 @@
 Example usage:

 Buck2 (internal):
-    buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_comms --
+    buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_comms -- \
+        a2a_single --name=a2a_sync_base-$(hg whereami | cut -c 1-10)

 OSS (external):
-    python -m torchrec.distributed.benchmark.benchmark_comms
+    python -m torchrec.distributed.benchmark.benchmark_comms \
+        a2a_single --name=a2a_sync_base-$(git rev-parse --short HEAD || echo $USER)
+see README.md for more details
 """

 from dataclasses import dataclass

@@ -23,6 +26,7 @@
 import torch
 import torch.distributed as dist
+import torch.nn.functional as F

 from torch.autograd.profiler import record_function

@@ -47,54 +51,129 @@ class AllToAllSingleRunConfig(BenchFuncConfig):
     profile_dir: str = "."
     num_benchmarks: int = 1
     num_profiles: int = 2
-    num_mul: int = 10
+    num_mul: int = 5
     num_concat: int = 100


+def _compute(
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+    x: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    A dummy compute function that simulates GPU load; all operations stay on
+    the GPU side, so nothing needs to block the CPU.
+    """
+    if x is None:
+        x = torch.rand(dim, dim, device=ctx.device) - 0.5
+    for _ in range(num_mul):
+        x = F.normalize(x @ x) * 10
+    x = torch.sigmoid(x).reshape(1, dim, dim) + ctx.rank
+    return torch.concat([x] * num_concat)
+
+
+def _validate(x: torch.Tensor, ctx: MultiProcessContext) -> torch.Tensor:
+    """
+    Validate the correctness of the comms result. The validation runs on the GPU
+    and returns a GPU tensor holding a single boolean value, non-blocking on CPU.
+    """
+    mixed_ranks = x.to(torch.int).reshape(ctx.world_size, -1)
+    checks = torch.empty(ctx.world_size, dtype=torch.bool, device=ctx.device)
+    for i in range(ctx.world_size):
+        checks[i] = torch.all(mixed_ranks[i, :] == i)
+    return torch.all(checks)
+
+
 # all_to_all_single with sync and single stream
 def a2a_sync_base(
-    batch_inputs: List[Dict[str, Any]],
+    _batch_inputs: List[Dict[str, Any]],
     dim: int,
     num_mul: int,
     num_concat: int,
     ctx: MultiProcessContext,
 ) -> None:
     with record_function("## pre-comms compute ##"):
-        pre_comms = torch.rand(dim, dim, device=ctx.device) - 0.5
-        for _ in range(num_mul):
-            pre_comms = pre_comms @ pre_comms
-            pre_comms = torch.sigmoid(pre_comms - torch.mean(pre_comms))
-        pre_comms = torch.sigmoid(pre_comms).reshape(1, dim, dim) + ctx.rank
-        pre_comms = torch.concat([pre_comms] * num_concat)
+        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)

     with record_function("## all_to_all_single ##"):
         post_comms = torch.empty_like(pre_comms)
         req = dist.all_to_all_single(output=post_comms, input=pre_comms, group=ctx.pg)

     with record_function("## comms validation ##"):
-        mixed_ranks = post_comms.to(torch.int).reshape(-1)
-        N = mixed_ranks.numel() // ctx.world_size
-        checks = [
-            torch.all(mixed_ranks[i * N : (i + 1) * N] == i)
-            for i in range(ctx.world_size)
-        ]
+        # this non-blocking copy to CPU triggers a device-to-host data transfer;
+        # since it is issued from the device side, the CPU doesn't know when it
+        # has finished, so we need a cuda event to mark completion on the device.
+        # the trace looks very interesting without the cuda event in this case:
+        # all cpu-side operations are non-blocking and finish before the comms,
+        # and hence the validation assertion fails
+        checks = _validate(post_comms, ctx).to(torch.device("cpu"), non_blocking=True)
+        ev_d2h = torch.cuda.Event()
+        ev_d2h.record()

     with record_function("## irrelevant compute ##"):
-        pre_comms = torch.rand(dim, dim, device=ctx.device) - 0.5
-        for _ in range(num_mul):
-            pre_comms = pre_comms @ pre_comms
-            pre_comms = torch.sigmoid(pre_comms - torch.mean(pre_comms))
-        pre_comms = torch.sigmoid(pre_comms) + ctx.rank
+        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)

     with record_function("## post-comms compute ##"):
-        post_comms = post_comms[0]
-        for _ in range(num_mul):
-            post_comms = post_comms @ post_comms
-            post_comms = torch.sigmoid(pre_comms - torch.mean(post_comms))
-        post_comms = torch.sigmoid(post_comms) + ctx.rank
+        post_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
+        )
+
+    with record_function("## assert ##"):
+        # explained above, this event.synchronize() is needed to make sure the
+        # device-to-host data transfer is done before the assertion
+        ev_d2h.synchronize()
+        assert checks
+
+
+# all_to_all_single with async and single stream
+def a2a_async_base(
+    _batch_inputs: List[Dict[str, Any]],
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+) -> None:
+    with record_function("## pre-comms compute ##"):
+        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
+
+    with record_function("## all_to_all_single ##"):
+        # use zeros instead of empty to make sure no stale data is used
+        post_comms = torch.zeros_like(pre_comms)
+        req = dist.all_to_all_single(
+            output=post_comms,
+            input=pre_comms,
+            group=ctx.pg,
+            async_op=True,
+        )
+
+    with record_function("## comms validation ##"):
+        # pre-check is performed before the comms are done
+        pre_checks = _validate(post_comms, ctx).to("cpu", non_blocking=True)
+        # need this cuda event to record the device-to-host data transfer
+        ev_d2h = torch.cuda.Event()
+        ev_d2h.record()
+
+    with record_function("## irrelevant compute ##"):
+        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
+
+    ev_d2h.synchronize()  # make sure pre_checks is available on the cpu side
+    with record_function(f"## post-comms compute: pre-check-{pre_checks}##"):
+        # the assertion fails without wait(); wait() makes the main cuda stream
+        # wait for the comms to finish, so the post-comms compute is blocked
+        # until the comms are done
+        req.wait()
+        checks = _validate(post_comms, ctx).to("cpu", non_blocking=True)
+        ev_d2h.record()  # record the device-to-host data transfer
+        post_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
+        )

     with record_function("## assert ##"):
-        assert all(checks)
+        # again, make sure the device-to-host data transfer is done before the assertion
+        ev_d2h.synchronize()
+        assert checks


 # single-rank runner
@@ -114,6 +193,8 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)

     if arg.name.startswith("a2a_sync_base"):
         func = a2a_sync_base
+    elif arg.name.startswith("a2a_async_base"):
+        func = a2a_async_base
     else:
         func = a2a_sync_base

@@ -128,7 +209,7 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
         },
         func_to_benchmark=func,
         rank=rank,
-        **arg.benchmark_func_kwargs()
+        **arg.benchmark_func_kwargs(),
     )

     if rank == 0:
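Note (illustrative, not part of the patch): the base.py change amounts to an on_trace_ready callback that writes one chrome trace per rank. A minimal standalone sketch of the same pattern using the public torch.profiler API; make_trace_handler and its arguments are hypothetical names, not torchrec functions:

import torch

def make_trace_handler(name: str, output_dir: str, rank: int, all_rank_traces: bool):
    # returns an on_trace_ready callback; when all_rank_traces is False only
    # rank 0 writes a trace, otherwise every rank writes its own file
    def handler(prof: torch.profiler.profile) -> None:
        if not all_rank_traces and rank > 0:
            return
        prof.export_chrome_trace(f"{output_dir}/trace-{name}-rank{rank}.json")
    return handler

The returned handler can be passed as the on_trace_ready argument of torch.profiler.profile, which is how the benchmark attaches its own trace handler.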
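Note (illustrative, not part of the patch): the comments in a2a_sync_base describe a general pattern — run validation on the GPU, copy the single boolean to the host with non_blocking=True, and record a cuda event so the host only reads the value after the device-to-host transfer is done. A minimal single-GPU sketch of that pattern, assuming a CUDA device is available; tensor names are illustrative:

import torch

device = torch.device("cuda")
x = torch.full((4, 1024), 3, device=device)

# validation happens entirely on the GPU and produces a single boolean tensor
ok_gpu = torch.all(x == 3)

# non_blocking=True enqueues a device-to-host copy; the CPU does not know when
# the copy finishes, so record an event right after enqueueing it
ok_cpu = ok_gpu.to("cpu", non_blocking=True)
d2h_done = torch.cuda.Event()
d2h_done.record()

# unrelated GPU work can be enqueued here without waiting for the copy
y = torch.rand(1024, 1024, device=device) @ torch.rand(1024, 1024, device=device)

# block the host only when the validation result is actually needed
d2h_done.synchronize()
assert bool(ok_cpu)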
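Note (illustrative, not part of the patch): a2a_async_base relies on two separate synchronization points — work.wait() to make the current cuda stream wait for the collective, and the cuda event above to guard the host-side read of the validation flag. A minimal sketch of the wait()-based half, written as a hypothetical helper (exchange_and_overlap is not a torchrec function); it assumes torch.distributed is already initialized with a CUDA-capable backend such as NCCL:

from typing import Optional

import torch
import torch.distributed as dist

def exchange_and_overlap(
    payload: torch.Tensor, pg: Optional[dist.ProcessGroup] = None
) -> torch.Tensor:
    # zeros (not empty) so stale memory cannot accidentally look like valid output
    out = torch.zeros_like(payload)
    work = dist.all_to_all_single(output=out, input=payload, group=pg, async_op=True)

    # independent GPU work enqueued here overlaps with the collective
    _ = torch.relu(payload)

    # wait() makes the current cuda stream wait for the collective, so any op
    # enqueued after this point sees the exchanged data
    work.wait()
    return out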