From 18d4c3472cfc5c26c8b9b80efd733b0da085ee29 Mon Sep 17 00:00:00 2001 From: Caner Gocmen Date: Fri, 18 Apr 2025 17:25:11 -0700 Subject: [PATCH] Perf Model Modeling Critical Path of Sparse Arch (#2896) Summary: Update the critical path definition in the planner logs to match what we think is the most realistic option. See the docstring for the detailed modeling. The goal is to later feed this into the proposer and partitioner to test if we can improve upon the current greedy algorithm used with the new perf model. We will also replace a [similar version](https://www.internalfb.com/code/fbsource/[50e47c413eb3d4e00facb9df71592f9aa81d8aee]/fbcode/torchrec/distributed/planner/stats.py?lines=1071) of this function in the stats.py to avoid duplicating the logic. Differential Revision: D73207877 --- torchrec/distributed/planner/perf_models.py | 74 ++++++++++++++++++- .../planner/tests/test_perf_models.py | 22 +++++- 2 files changed, 93 insertions(+), 3 deletions(-) diff --git a/torchrec/distributed/planner/perf_models.py b/torchrec/distributed/planner/perf_models.py index c52087379..784fa9f88 100644 --- a/torchrec/distributed/planner/perf_models.py +++ b/torchrec/distributed/planner/perf_models.py @@ -7,7 +7,8 @@ # pyre-strict -from typing import cast, List +from collections import defaultdict +from typing import cast, DefaultDict, List, Optional from torchrec.distributed.planner.types import ( Perf, @@ -54,3 +55,74 @@ def rate(self, plan: List[ShardingOption]) -> float: hbms[shard.rank] += cast(Storage, shard.storage).hbm return max(hbms) + + +class NoopCriticalPathPerfModel(PerfModel): + """ + Models the critical path of the sparse arch. Makes the following assumptions: + + 1. There is a synchronization point across the ranks after each of the 4 events: Fwd/Bwd x Comms/Comp. + 2. There could be additional synchronization points across ranks during communication (both fwd & bwd) + 3. There could be additional synchronization points across ranks during computation (both fwd & bwd) + + Args: + topology (Topology): System topology. + comms_group_keys (Optional[List[str]]): Additional synchronization points for communication. For example, if we assume that ranks + synchronize after each module and sharding type operation, then this would be ["module", "sharding_type"]. + comp_group_keys (Optional[List[str]]): Additional synchronization points for computation. For example, if we assume that ranks + synchronize after each module and sharding type operation, then this would be ["module", "sharding_type"]. + """ + + def __init__( + self, + topology: Topology, + comms_group_keys: Optional[List[str]] = None, + comp_group_keys: Optional[List[str]] = None, + ) -> None: + self._topology = topology + self.comms_group_keys: List[str] = comms_group_keys if comms_group_keys else [] + self.comp_group_keys: List[str] = comp_group_keys if comp_group_keys else [] + + def rate(self, plan: List[ShardingOption]) -> float: + comms_data_fwd = defaultdict(lambda: defaultdict(float)) + comms_data_bwd = defaultdict(lambda: defaultdict(float)) + comp_data_fwd = defaultdict(lambda: defaultdict(float)) + comp_data_bwd = defaultdict(lambda: defaultdict(float)) + for so in plan: + if len(self.comms_group_keys) == 0: + comms_aggregation_group = ["default"] + else: + comms_aggregation_group = [ + getattr(so, key) for key in self.comms_group_keys + ] + if len(self.comp_group_keys) == 0: + comp_aggregation_group = ["default"] + else: + comp_aggregation_group = [ + getattr(so, key) for key in self.comp_group_keys + ] + for shard in so.shards: + rank = cast(int, shard.rank) + perf = cast(Perf, shard.perf) + comms_data_fwd[tuple(comms_aggregation_group)][rank] += perf.fwd_comms + comms_data_bwd[tuple(comms_aggregation_group)][rank] += perf.bwd_comms + comp_data_fwd[tuple(comp_aggregation_group)][rank] += perf.fwd_compute + comp_data_bwd[tuple(comp_aggregation_group)][rank] += perf.bwd_compute + + # Compute the cost by looking at the summing up the max cost across all ranks for each synchronization point + def _compute_aggregated_cost( + d: DefaultDict[tuple[str, ...], DefaultDict[int, float]] + ) -> float: + return sum( + { + outer_key: max(inner_dict.values()) + for outer_key, inner_dict in d.items() + }.values() + ) + + comms_fwd_cost = _compute_aggregated_cost(comms_data_fwd) + comms_bwd_cost = _compute_aggregated_cost(comms_data_bwd) + comp_fwd_cost = _compute_aggregated_cost(comp_data_fwd) + comp_bwd_sum = _compute_aggregated_cost(comp_data_bwd) + + return comms_fwd_cost + comp_fwd_cost + comms_bwd_cost + comp_bwd_sum diff --git a/torchrec/distributed/planner/tests/test_perf_models.py b/torchrec/distributed/planner/tests/test_perf_models.py index d290b6647..7a2aab7bd 100644 --- a/torchrec/distributed/planner/tests/test_perf_models.py +++ b/torchrec/distributed/planner/tests/test_perf_models.py @@ -10,7 +10,11 @@ import unittest from unittest.mock import MagicMock -from torchrec.distributed.planner.perf_models import NoopPerfModel, NoopStorageModel +from torchrec.distributed.planner.perf_models import ( + NoopCriticalPathPerfModel, + NoopPerfModel, + NoopStorageModel, +) from torchrec.distributed.planner.types import ( Perf, Shard, @@ -22,6 +26,7 @@ class TestPerfModels(unittest.TestCase): def setUp(self) -> None: + sharding_types = ["CW", "TW"] self.topology = Topology(world_size=2, compute_device="cuda") self.tables = [ ShardingOption( @@ -30,7 +35,7 @@ def setUp(self) -> None: module=MagicMock(), input_lengths=MagicMock(), batch_size=MagicMock(), - sharding_type=MagicMock(), + sharding_type=sharding_types[rank], partition_by=MagicMock(), compute_kernel=MagicMock(), shards=[ @@ -60,3 +65,16 @@ def test_noop_storage_model(self) -> None: perf_model = NoopStorageModel(self.topology) perf_rating = perf_model.rate(self.tables) self.assertEqual(perf_rating, 200) + + def test_noop_critical_path_perf_model(self) -> None: + perf_model_default = NoopCriticalPathPerfModel(self.topology) + perf_rating_default = perf_model_default.rate(self.tables) + self.assertEqual(perf_rating_default, 2) + + perf_model_sharding_type = NoopCriticalPathPerfModel( + self.topology, + comms_group_keys=["sharding_type"], + comp_group_keys=["sharding_type"], + ) + perf_rating_sharding_type = perf_model_sharding_type.rate(self.tables) + self.assertEqual(perf_rating_sharding_type, 3)