Perf Model Modeling Critical Path of Sparse Arch #2896

Open
wants to merge 1 commit into base: main
74 changes: 73 additions & 1 deletion torchrec/distributed/planner/perf_models.py
@@ -7,7 +7,8 @@

# pyre-strict

from typing import cast, List
from collections import defaultdict
from typing import cast, DefaultDict, List, Optional

from torchrec.distributed.planner.types import (
Perf,
@@ -54,3 +55,74 @@ def rate(self, plan: List[ShardingOption]) -> float:
hbms[shard.rank] += cast(Storage, shard.storage).hbm

return max(hbms)


class NoopCriticalPathPerfModel(PerfModel):
"""
Models the critical path of the sparse arch. Makes the following assumptions:

1. There is a synchronization point across the ranks after each of the 4 events: Fwd/Bwd x Comms/Comp.
2. There could be additional synchronization points across ranks during communication (both fwd & bwd).
3. There could be additional synchronization points across ranks during computation (both fwd & bwd).

Args:
topology (Topology): System topology.
comms_group_keys (Optional[List[str]]): ShardingOption attributes that define additional synchronization points for communication. For example, if ranks are
assumed to synchronize after each (module, sharding type) group, this would be ["module", "sharding_type"].
comp_group_keys (Optional[List[str]]): ShardingOption attributes that define additional synchronization points for computation. For example, if ranks are
assumed to synchronize after each (module, sharding type) group, this would be ["module", "sharding_type"].
"""

def __init__(
self,
topology: Topology,
comms_group_keys: Optional[List[str]] = None,
comp_group_keys: Optional[List[str]] = None,
) -> None:
self._topology = topology
self.comms_group_keys: List[str] = comms_group_keys if comms_group_keys else []
self.comp_group_keys: List[str] = comp_group_keys if comp_group_keys else []

def rate(self, plan: List[ShardingOption]) -> float:
comms_data_fwd = defaultdict(lambda: defaultdict(float))
comms_data_bwd = defaultdict(lambda: defaultdict(float))
comp_data_fwd = defaultdict(lambda: defaultdict(float))
comp_data_bwd = defaultdict(lambda: defaultdict(float))
for so in plan:
if len(self.comms_group_keys) == 0:
comms_aggregation_group = ["default"]
else:
comms_aggregation_group = [
getattr(so, key) for key in self.comms_group_keys
]
if len(self.comp_group_keys) == 0:
comp_aggregation_group = ["default"]
else:
comp_aggregation_group = [
getattr(so, key) for key in self.comp_group_keys
]
for shard in so.shards:
rank = cast(int, shard.rank)
perf = cast(Perf, shard.perf)
comms_data_fwd[tuple(comms_aggregation_group)][rank] += perf.fwd_comms
comms_data_bwd[tuple(comms_aggregation_group)][rank] += perf.bwd_comms
comp_data_fwd[tuple(comp_aggregation_group)][rank] += perf.fwd_compute
comp_data_bwd[tuple(comp_aggregation_group)][rank] += perf.bwd_compute

# Compute the cost by summing up the max cost across all ranks for each synchronization point
def _compute_aggregated_cost(
d: DefaultDict[tuple[str, ...], DefaultDict[int, float]]
) -> float:
return sum(max(inner_dict.values()) for inner_dict in d.values())

comms_fwd_cost = _compute_aggregated_cost(comms_data_fwd)
comms_bwd_cost = _compute_aggregated_cost(comms_data_bwd)
comp_fwd_cost = _compute_aggregated_cost(comp_data_fwd)
comp_bwd_cost = _compute_aggregated_cost(comp_data_bwd)

return comms_fwd_cost + comp_fwd_cost + comms_bwd_cost + comp_bwd_cost
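
For context, here is a minimal sketch of the aggregation that rate() performs, using synthetic per-rank costs instead of real ShardingOption objects; the sharding-type groups and cost numbers below are illustrative assumptions, not values from the PR.

from collections import defaultdict
from typing import DefaultDict, Tuple

# Forward comms cost keyed by (group, rank); the groups assume grouping by sharding_type.
comms_data_fwd: DefaultDict[Tuple[str, ...], DefaultDict[int, float]] = defaultdict(
    lambda: defaultdict(float)
)
comms_data_fwd[("CW",)][0] += 0.5
comms_data_fwd[("CW",)][1] += 0.25
comms_data_fwd[("TW",)][1] += 1.0

# Each group is a synchronization point: every rank waits for the slowest one,
# so a group costs the max across ranks, and the critical path sums those maxima.
critical_path = sum(max(per_rank.values()) for per_rank in comms_data_fwd.values())
print(critical_path)  # max(0.5, 0.25) + max(1.0) = 1.5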
22 changes: 20 additions & 2 deletions torchrec/distributed/planner/tests/test_perf_models.py
@@ -10,7 +10,11 @@
import unittest
from unittest.mock import MagicMock

from torchrec.distributed.planner.perf_models import NoopPerfModel, NoopStorageModel
from torchrec.distributed.planner.perf_models import (
NoopCriticalPathPerfModel,
NoopPerfModel,
NoopStorageModel,
)
from torchrec.distributed.planner.types import (
Perf,
Shard,
@@ -22,6 +26,7 @@

class TestPerfModels(unittest.TestCase):
def setUp(self) -> None:
sharding_types = ["CW", "TW"]
self.topology = Topology(world_size=2, compute_device="cuda")
self.tables = [
ShardingOption(
@@ -30,7 +35,7 @@ def setUp(self) -> None:
module=MagicMock(),
input_lengths=MagicMock(),
batch_size=MagicMock(),
sharding_type=MagicMock(),
sharding_type=sharding_types[rank],
partition_by=MagicMock(),
compute_kernel=MagicMock(),
shards=[
@@ -60,3 +65,16 @@ def test_noop_storage_model(self) -> None:
perf_model = NoopStorageModel(self.topology)
perf_rating = perf_model.rate(self.tables)
self.assertEqual(perf_rating, 200)

def test_noop_critical_path_perf_model(self) -> None:
perf_model_default = NoopCriticalPathPerfModel(self.topology)
perf_rating_default = perf_model_default.rate(self.tables)
self.assertEqual(perf_rating_default, 2)

perf_model_sharding_type = NoopCriticalPathPerfModel(
self.topology,
comms_group_keys=["sharding_type"],
comp_group_keys=["sharding_type"],
)
perf_rating_sharding_type = perf_model_sharding_type.rate(self.tables)
self.assertEqual(perf_rating_sharding_type, 3)
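
As a hedged walkthrough of the expected ratings: assuming the collapsed part of setUp (not shown in this diff) leaves rank 0 with fwd_compute=2 and rank 1 with fwd_compute=1, with all other Perf fields zero, the assertions follow from the arithmetic below.

# Assumed per-shard costs implied by the collapsed setUp (all other costs zero):
#   table 0: sharding_type "CW", shard on rank 0, fwd_compute = 2
#   table 1: sharding_type "TW", shard on rank 1, fwd_compute = 1
fwd_compute = {("CW", 0): 2.0, ("TW", 1): 1.0}

# Default grouping: a single synchronization point, so the rating is the max across ranks.
assert max(fwd_compute.values()) == 2  # matches perf_rating_default

# Grouping by sharding_type: one synchronization point per type, and the per-group maxima add up.
by_type: dict[str, dict[int, float]] = {}
for (sharding_type, rank), cost in fwd_compute.items():
    by_type.setdefault(sharding_type, {})[rank] = cost
assert sum(max(ranks.values()) for ranks in by_type.values()) == 3  # matches perf_rating_sharding_type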