Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions torchrec/distributed/benchmark/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,7 @@ def _run_benchmark_core(
pre_gpu_load: int = 0,
export_stacks: bool = False,
reset_accumulated_memory_stats: bool = False,
all_rank_traces: bool = False,
) -> BenchmarkResult:
"""Internal helper that contains the core benchmarking logic shared by
``benchmark`` and ``benchmark_func``. All heavy–lifting (timing, memory
Expand Down Expand Up @@ -721,9 +722,10 @@ def _run_benchmark_core(
def _trace_handler(prof: torch.profiler.profile) -> None:
total_avg = prof.profiler.total_average()
logger.info(f" TOTAL_AVERAGE:\n{name}\n{total_avg}")
if rank > 0:
if not all_rank_traces and rank > 0:
# only save trace for rank 0 when all_rank_traces is disabled
return
trace_file = f"{output_dir}/trace-{name}.json"
trace_file = f"{output_dir}/trace-{name}-rank{rank}.json"
logger.info(f" PROFILE[{name}].chrome_trace:{trace_file}")
prof.export_chrome_trace(trace_file)
if export_stacks:
Expand Down Expand Up @@ -828,6 +830,7 @@ class BenchFuncConfig:
device_type: str = "cuda"
pre_gpu_load: int = 0
export_stacks: bool = False
all_rank_traces: bool = False

# pyre-ignore [2]
def benchmark_func_kwargs(self, **kwargs_to_override) -> Dict[str, Any]:
Expand All @@ -840,6 +843,7 @@ def benchmark_func_kwargs(self, **kwargs_to_override) -> Dict[str, Any]:
"device_type": self.device_type,
"pre_gpu_load": self.pre_gpu_load,
"export_stacks": self.export_stacks,
"all_rank_traces": self.all_rank_traces,
} | kwargs_to_override


Expand All @@ -857,6 +861,7 @@ def benchmark_func(
device_type: str = "cuda",
pre_gpu_load: int = 0,
export_stacks: bool = False,
all_rank_traces: bool = False,
) -> BenchmarkResult:
"""
Args:
Expand All @@ -879,6 +884,7 @@ def benchmark_func(
pre_gpu_load: Number of dummy matmul operations to run before the first
measured iteration (helps simulating a loaded allocator).
export_stacks: Whether to export flamegraph-compatible stack files.
all_rank_traces: Whether to export traces from all ranks.
"""
if benchmark_func_kwargs is None:
benchmark_func_kwargs = {}
Expand All @@ -905,4 +911,5 @@ def _profile_iter_fn(prof: torch.profiler.profile) -> None:
pre_gpu_load=pre_gpu_load,
export_stacks=export_stacks,
reset_accumulated_memory_stats=True,
all_rank_traces=all_rank_traces,
)
137 changes: 109 additions & 28 deletions torchrec/distributed/benchmark/benchmark_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,22 @@
Example usage:

Buck2 (internal):
buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_comms --
buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_comms -- \
a2a_single --name=a2a_sync_base-$(hg whereami | cut -c 1-10)

OSS (external):
python -m torchrec.distributed.benchmark.benchmark_comms
python -m torchrec.distributed.benchmark.benchmark_comms \
a2a_single --name=a2a_sync_base-$(git rev-parse --short HEAD || echo $USER)

see README.md for more details
"""

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F

from torch.autograd.profiler import record_function

Expand All @@ -47,54 +51,129 @@ class AllToAllSingleRunConfig(BenchFuncConfig):
profile_dir: str = "."
num_benchmarks: int = 1
num_profiles: int = 2
num_mul: int = 10
num_mul: int = 5
num_concat: int = 100


def _compute(
    dim: int,
    num_mul: int,
    num_concat: int,
    ctx: MultiProcessContext,
    x: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Dummy compute kernel used to put load on the device between comms.

    Args:
        dim: side length of the square working matrix.
        num_mul: number of matmul + normalize rounds to run.
        num_concat: how many copies of the result are concatenated together.
        ctx: per-rank context; only ``ctx.device`` and ``ctx.rank`` are read.
        x: optional starting matrix. When omitted, a random ``dim x dim``
            matrix centered around zero is created on ``ctx.device``.

    Returns:
        ``num_concat`` copies of the ``(1, dim, dim)`` result concatenated
        into one tensor. Every element is ``sigmoid(.) + ctx.rank``.
    """
    work = torch.rand(dim, dim, device=ctx.device) - 0.5 if x is None else x

    # Re-normalizing after every matmul keeps the values bounded no matter
    # how many rounds are requested.
    for _ in range(num_mul):
        product = work @ work
        work = F.normalize(product) * 10

    # Offsetting the sigmoid output by the rank tags every element with the
    # rank that produced it.
    tagged = torch.sigmoid(work).reshape(1, dim, dim) + ctx.rank
    copies = [tagged for _ in range(num_concat)]
    return torch.concat(copies)


def _validate(x: torch.Tensor, ctx: MultiProcessContext) -> torch.Tensor:
    """
    Validate the result of the comms without leaving the device.

    ``x`` is converted to ``torch.int`` and viewed as ``ctx.world_size``
    equally sized chunks; chunk ``i`` passes when every element in it
    equals ``i``.

    Args:
        x: tensor received from the collective. Its number of elements must
            be divisible by ``ctx.world_size``.
        ctx: per-rank context; only ``ctx.world_size`` is read.

    Returns:
        A 0-dim boolean tensor on the same device as ``x`` that is ``True``
        only when every chunk passes. No host synchronization is triggered.
    """
    mixed_ranks = x.to(torch.int).reshape(ctx.world_size, -1)
    # A column vector [0, 1, ..., world_size - 1] broadcasts across each row,
    # so one comparison checks all chunks at once instead of issuing a
    # separate reduction and indexed write per rank.
    expected = torch.arange(
        ctx.world_size, dtype=torch.int, device=mixed_ranks.device
    ).unsqueeze(1)
    return torch.all(mixed_ranks == expected)


# all_to_all_single with sync and single stream
def a2a_sync_base(
    _batch_inputs: List[Dict[str, Any]],
    dim: int,
    num_mul: int,
    num_concat: int,
    ctx: MultiProcessContext,
) -> None:
    """
    Baseline benchmark: blocking ``all_to_all_single`` between dummy computes.

    Each stage is wrapped in its own ``record_function`` region so it shows
    up as a named span in the profiler trace.

    Args:
        _batch_inputs: unused; present to match the benchmarked-function
            signature.
        dim: side length of the square matrices used by ``_compute``.
        num_mul: number of matmul rounds per ``_compute`` call.
        num_concat: number of copies concatenated by ``_compute``.
        ctx: per-rank context providing the device, rank and process group.

    Raises:
        AssertionError: if the received data fails the ``_validate`` check.
    """
    with record_function("## pre-comms compute ##"):
        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)

    with record_function("## all_to_all_single ##"):
        post_comms = torch.empty_like(pre_comms)
        # async_op is left at its default, so there is no work handle to keep
        dist.all_to_all_single(output=post_comms, input=pre_comms, group=ctx.pg)

    with record_function("## comms validation ##"):
        # this non-blocking copy to CPU will trigger a device-to-host data transfer
        # however, since it's from the device side, CPU doesn't know if it's finished
        # so we'll need a cuda event to mark if it's done from the device side
        # the trace looks very interesting without cuda.event in this case
        # all cpu-side operations are non-blocking, and finished before the comms
        # and hence failed the validation assertion
        checks = _validate(post_comms, ctx).to(torch.device("cpu"), non_blocking=True)
        ev_d2h = torch.cuda.Event()
        ev_d2h.record()

    with record_function("## irrelevant compute ##"):
        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)

    with record_function("## post-comms compute ##"):
        post_comms = _compute(
            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
        )

    with record_function("## assert ##"):
        # explained above, this event.synchronize() is needed to make sure the
        # device-to-host data transfer is done before the assertion
        ev_d2h.synchronize()
        assert checks, "all_to_all_single validation failed"


# all_to_all_single with async_op=True, waited on right before its result is used
def a2a_async_base(
    _batch_inputs: List[Dict[str, Any]],
    dim: int,
    num_mul: int,
    num_concat: int,
    ctx: MultiProcessContext,
) -> None:
    """
    Benchmark: ``all_to_all_single`` launched with ``async_op=True``.

    The collective is started, unrelated compute is queued, and the work
    handle is waited on only right before the received data is consumed.
    A validation taken before the wait is embedded in the profiler span
    name for inspection; the validation taken after the wait is asserted.

    Args:
        _batch_inputs: unused; present to match the benchmarked-function
            signature.
        dim: side length of the square matrices used by ``_compute``.
        num_mul: number of matmul rounds per ``_compute`` call.
        num_concat: number of copies concatenated by ``_compute``.
        ctx: per-rank context providing the device, rank and process group.

    Raises:
        AssertionError: if the data validated after ``req.wait()`` fails the
            ``_validate`` check.
    """
    with record_function("## pre-comms compute ##"):
        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)

    with record_function("## all_to_all_single ##"):
        # use zeros instead of empty to make sure no previous data used
        post_comms = torch.zeros_like(pre_comms)
        req = dist.all_to_all_single(
            output=post_comms,
            input=pre_comms,
            group=ctx.pg,
            async_op=True,
        )

    with record_function("## comms validation ##"):
        # pre-check is performed before the comms is done
        pre_checks = _validate(post_comms, ctx).to("cpu", non_blocking=True)
        # need this cuda.event to record the device-to-host data transfer
        ev_d2h = torch.cuda.Event()
        ev_d2h.record()

    with record_function("## irrelevant compute ##"):
        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)

    ev_d2h.synchronize()  # make sure the pre_checks is available from cpu side
    with record_function(f"## post-comms compute: pre-check-{pre_checks}##"):
        # assertion fails without wait(), this wait() makes the main cuda stream wait
        # for the comms to finish, so the post-comms compute will be blocked until
        # the comms is done
        req.wait()
        checks = _validate(post_comms, ctx).to("cpu", non_blocking=True)
        ev_d2h.record()  # record the device-to-host data transfer
        post_comms = _compute(
            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
        )

    with record_function("## assert ##"):
        # again, make sure the device-to-host data transfer is done before the
        # assertion; `checks` is a 0-dim tensor, so it is asserted directly
        ev_d2h.synchronize()
        assert checks, "all_to_all_single validation failed"


# single-rank runner
Expand All @@ -114,6 +193,8 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)

if arg.name.startswith("a2a_sync_base"):
func = a2a_sync_base
elif arg.name.startswith("a2a_async_base"):
func = a2a_async_base
else:
func = a2a_sync_base

Expand All @@ -128,7 +209,7 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
},
func_to_benchmark=func,
rank=rank,
**arg.benchmark_func_kwargs()
**arg.benchmark_func_kwargs(),
)

if rank == 0:
Expand Down
Loading