47 changes: 45 additions & 2 deletions torchrec/distributed/benchmark/benchmark_comms.py
@@ -22,7 +22,7 @@
"""

from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict, List, Optional

import torch
import torch.distributed as dist
@@ -39,6 +39,7 @@
    MultiProcessContext,
    run_multi_process_func,
)
from torchrec.distributed.types import DeviceToHostTensorAwaitable

_cc = cmd_conf()

@@ -253,6 +254,46 @@ def a2a_async_twice(
        assert checks1 and checks2


# all_to_all_single with wait() on the main stream; the validation result is
# wrapped in a LazyAwaitable so the device-to-host sync is deferred to the assert
def lazyawaitable(
    _batch_inputs: List[Dict[str, Any]],
    dim: int,
    num_mul: int,
    num_concat: int,
    ctx: MultiProcessContext,
) -> None:
    with record_function("## pre-comms compute ##"):
        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)

    with record_function("## all_to_all_single ##"):
        # use zeros instead of empty to make sure no stale data is read
        post_comms = torch.zeros_like(pre_comms)
        req = dist.all_to_all_single(
            output=post_comms,
            input=pre_comms,
            group=ctx.pg,
            async_op=True,
        )

    with record_function("## irrelevant compute ##"):
        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)

    with record_function("## comms check ##"):
        # the assertion fails without wait(): wait() makes the main CUDA stream
        # wait for the comms to finish, so the post-comms compute is blocked
        # until the communication is done
        req.wait()
        check_awaitable = DeviceToHostTensorAwaitable(_validate(post_comms, ctx))

    with record_function("## post-comms compute ##"):
        post_comms = _compute(
            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
        )

    with record_function("## assert ##"):
        assert check_awaitable.item()

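The point of this variant is when the host blocks: the non-blocking device-to-host copy inside DeviceToHostTensorAwaitable is kicked off right after req.wait(), but the host only synchronizes when check_awaitable.item() is first touched in the final assert, so the post-comms compute overlaps with the copy. A minimal standalone sketch of the same overlap pattern, assuming a CUDA device (gpu_flag and the surrounding names are illustrative, not from this PR):

import torch

assert torch.cuda.is_available()
gpu_flag = torch.ones(1, device="cuda")  # stand-in for _validate()'s device-side result

# kick off the device-to-host copy without blocking the host
cpu_flag = gpu_flag.to("cpu", non_blocking=True)
copied = torch.cuda.Event()
copied.record()  # marks the point after the copy on the current stream

# ... unrelated GPU/CPU work overlaps with the copy here ...

copied.synchronize()  # the host blocks only at this point
assert cpu_flag.item() == 1.0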

# single-rank runner
def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig) -> None:
    # Ensure GPUs are available and we have enough of them
@@ -274,8 +315,10 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
        func = a2a_async_base
    elif arg.name.startswith("a2a_async_twice"):
        func = a2a_async_twice
    elif arg.name.startswith("lazyawaitable"):
        func = lazyawaitable
    else:
-        func = a2a_sync_base
+        raise ValueError(f"Unknown benchmark name: {arg.name}")

    result = benchmark_func(
        bench_inputs=[],
18 changes: 18 additions & 0 deletions torchrec/distributed/types.py
@@ -463,6 +463,24 @@ def _wait_impl(self) -> W:
        return self._obj


class DeviceToHostTensorAwaitable(LazyAwaitable[torch.Tensor]):
    """An awaitable that waits for a tensor to be copied from device to host."""

    def __init__(self, tensor_on_device: torch.Tensor) -> None:
        super().__init__()
        # self._tensor may still hold uninitialized values at this moment; the
        # non-blocking D2H copy is only guaranteed complete once the recorded
        # event has fired
        self._tensor: torch.Tensor = tensor_on_device.to("cpu", non_blocking=True)

        # CUDA event recording the completion of the copy on the current stream
        self._event = torch.cuda.Event()
        self._event.record()

    def _wait_impl(self) -> torch.Tensor:
        # block the host until the copy has completed
        self._event.synchronize()
        return self._tensor

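A minimal usage sketch, assuming a CUDA device; LazyAwaitable triggers wait() on the first attribute access (which is why the benchmark above can call check_awaitable.item() directly), so the .item() call below is what actually runs _wait_impl() and blocks on the event. The flag_on_gpu tensor is illustrative, not part of this diff:

flag_on_gpu = torch.ones(1, device="cuda")  # e.g. a validation result produced on the GPU
awaitable = DeviceToHostTensorAwaitable(flag_on_gpu)
# ... other work overlaps here; no host-side sync has happened yet ...
assert awaitable.item() == 1.0  # first attribute access -> wait() -> _wait_impl()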

KT = TypeVar("KT")
VT_co = TypeVar("VT_co")
ParentW = TypeVar("ParentW")