Skip to content

Commit 1f44c75

Browse files
ge0405facebook-github-bot
authored andcommitted
Fuse "states" level tensors to reduce all gather during metrics compute (#2892)
Summary: Pull Request resolved: #2892 # Contexts During S503023, we found that the metrics `compute` time can take up to **30s** on a large OMNI FM model. (Note this is not the metrics `update` every iteration. This is `compute` which will create sync and refresh metrics on tensorboard. The frequency of `compute` can be set by `compute_interval_steps` (default=100).) If Zoomer happens to capture the iteration of metrics compute, it will even make the trace file too big and can't be opened. The reason for metrics `compute` being so long was due to too many all_gather calls. See the screenshot of metrics `compute` from the above SEV. That single metrics `compute` takes 30s. And once zooming in, you will find it composed of hundreds of all_gather. {F1976937278} {F1976937301} Therefore, this diff tried to fuse the tensors before all_gather during metrics `compute`. With this diff plus `FUSE_TASKS_COMPUTATION` to fuse task-level tensors, it can reduce "RecMetricModule compute" time from 1s 287 ms to 35 ms (36X reduction) on a shrunk OMNI FM model. # What has been changed/added? 1. Add `FUSED_TASKS_AND_STATES_COMPUTATION` in RecComputeMode. When turned on, this will both fuse tasks tensors and fuse state tensors. 2. Add `fuse_state_tensors` to [Metric](https://www.internalfb.com/code/fbsource/[0c89c01039abfadd62e8ec1b34eb24b249b99f3f]/fbcode/pytorch_lightning_deprecated/metrics/torchmetrics/metric.py?lines=43) class. This will be turned on once `FUSED_TASKS_AND_STATES_COMPUTATION` is set from config. Then when a metric's `compute` is called, it will fuse (stack) state tensors before all_gather. Then reconstruct the output tensor to desired format to do reduction. 3. It is noted currently we only support fusing/stacking 1D state tensors. Therefore, for states with `List` (e.g. auc) or 2D tensor (e.g. multiclass_ne), the `FUSED_TASKS_AND_STATES_COMPUTATION` shouldn't be used or should at least fall back to either `FUSED_TASKS_COMPUTATION` or `UNFUSED_TASKS_COMPUTATION`. Reviewed By: iamzainhuda Differential Revision: D72010614 fbshipit-source-id: 5ac77088cf9737ad783ad7e1740daf6afc0224a8
1 parent ca53db2 commit 1f44c75

20 files changed

+320
-47
lines changed

torchrec/metrics/metrics_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ class RecComputeMode(Enum):
8787

8888
FUSED_TASKS_COMPUTATION = 1
8989
UNFUSED_TASKS_COMPUTATION = 2
90+
FUSED_TASKS_AND_STATES_COMPUTATION = 3
9091

9192

9293
_DEFAULT_WINDOW_SIZE = 10_000_000

torchrec/metrics/precision_session.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,10 @@ def __init__(
196196
process_group: Optional[dist.ProcessGroup] = None,
197197
**kwargs: Any,
198198
) -> None:
199-
if compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION:
199+
if compute_mode in [
200+
RecComputeMode.FUSED_TASKS_COMPUTATION,
201+
RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
202+
]:
200203
raise RecMetricException(
201204
"Fused computation is not supported for precision session-level metrics"
202205
)

torchrec/metrics/rec_metric.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,19 @@ def __init__(
134134
window_size: int,
135135
compute_on_all_ranks: bool = False,
136136
should_validate_update: bool = False,
137+
fuse_state_tensors: bool = False,
137138
process_group: Optional[dist.ProcessGroup] = None,
138139
fused_update_limit: int = 0,
139140
allow_missing_label_with_zero_weight: bool = False,
140141
*args: Any,
141142
**kwargs: Any,
142143
) -> None:
143-
super().__init__(process_group=process_group, *args, **kwargs)
144+
super().__init__(
145+
process_group=process_group,
146+
fuse_state_tensors=fuse_state_tensors,
147+
*args,
148+
**kwargs,
149+
)
144150

145151
self._my_rank = my_rank
146152
self._n_tasks = n_tasks
@@ -341,7 +347,11 @@ def __init__(
341347
# TODO(stellaya): consider to inherit from TorchMetrics.Metric or
342348
# TorchMetrics.MetricCollection.
343349
if (
344-
compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION
350+
compute_mode
351+
in [
352+
RecComputeMode.FUSED_TASKS_COMPUTATION,
353+
RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
354+
]
345355
and fused_update_limit > 0
346356
):
347357
raise ValueError(
@@ -376,7 +386,10 @@ def __init__(
376386
f"Local window size must be larger than batch size. Got local window size {self._window_size} and batch size {self._batch_size}."
377387
)
378388

379-
if compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION:
389+
if compute_mode in [
390+
RecComputeMode.FUSED_TASKS_COMPUTATION,
391+
RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
392+
]:
380393
task_per_metric = len(self._tasks)
381394
self._tasks_iter = self._fused_tasks_iter
382395
else:
@@ -385,7 +398,11 @@ def __init__(
385398

386399
for task_config in (
387400
[self._tasks]
388-
if compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION
401+
if compute_mode
402+
in [
403+
RecComputeMode.FUSED_TASKS_COMPUTATION,
404+
RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
405+
]
389406
else self._tasks
390407
):
391408
# pyre-ignore
@@ -394,13 +411,16 @@ def __init__(
394411
# according to https://github.com/python/mypy/issues/3048.
395412
# pyre-fixme[45]: Cannot instantiate abstract class `RecMetricCoputation`.
396413
metric_computation = self._computation_class(
397-
my_rank,
398-
batch_size,
399-
task_per_metric,
400-
self._window_size,
401-
compute_on_all_ranks,
402-
self._should_validate_update,
403-
process_group,
414+
my_rank=my_rank,
415+
batch_size=batch_size,
416+
n_tasks=task_per_metric,
417+
window_size=self._window_size,
418+
compute_on_all_ranks=compute_on_all_ranks,
419+
should_validate_update=self._should_validate_update,
420+
fuse_state_tensors=(
421+
compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION
422+
),
423+
process_group=process_group,
404424
**{**kwargs, **self._get_task_kwargs(task_config)},
405425
)
406426
required_inputs = self._get_task_required_inputs(task_config)
@@ -527,7 +547,10 @@ def _update(
527547
**kwargs: Dict[str, Any],
528548
) -> None:
529549
with torch.no_grad():
530-
if self._compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION:
550+
if self._compute_mode in [
551+
RecComputeMode.FUSED_TASKS_COMPUTATION,
552+
RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
553+
]:
531554
task_names = [task.name for task in self._tasks]
532555

533556
if not isinstance(predictions, torch.Tensor):

torchrec/metrics/recall_session.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,10 @@ def __init__(
235235
process_group: Optional[dist.ProcessGroup] = None,
236236
**kwargs: Any,
237237
) -> None:
238-
if compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION:
238+
if compute_mode in [
239+
RecComputeMode.FUSED_TASKS_COMPUTATION,
240+
RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
241+
]:
239242
raise RecMetricException(
240243
"Fused computation is not supported for recall session-level metrics"
241244
)

torchrec/metrics/test_utils/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,10 @@ def get_target_rec_metric_value(
291291
labels, predictions, weights, _ = parse_task_model_outputs(
292292
tasks, model_outs[i]
293293
)
294-
if target_compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION:
294+
if target_compute_mode in [
295+
RecComputeMode.FUSED_TASKS_COMPUTATION,
296+
RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
297+
]:
295298
labels = torch.stack(list(labels.values()))
296299
predictions = torch.stack(list(predictions.values()))
297300
weights = torch.stack(list(weights.values()))

torchrec/metrics/tests/test_accuracy.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class AccuracyMetricTest(unittest.TestCase):
5252
clazz: Type[RecMetric] = AccuracyMetric
5353
task_name: str = "accuracy"
5454

55-
def test_unfused_accuracy(self) -> None:
55+
def test_accuracy_unfused(self) -> None:
5656
rec_metric_value_test_launcher(
5757
target_clazz=AccuracyMetric,
5858
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
@@ -66,7 +66,7 @@ def test_unfused_accuracy(self) -> None:
6666
entry_point=metric_test_helper,
6767
)
6868

69-
def test_fused_accuracy(self) -> None:
69+
def test_accuracy_fused_tasks(self) -> None:
7070
rec_metric_value_test_launcher(
7171
target_clazz=AccuracyMetric,
7272
target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
@@ -80,6 +80,20 @@ def test_fused_accuracy(self) -> None:
8080
entry_point=metric_test_helper,
8181
)
8282

83+
def test_accuracy_fused_tasks_and_states(self) -> None:
84+
rec_metric_value_test_launcher(
85+
target_clazz=AccuracyMetric,
86+
target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
87+
test_clazz=TestAccuracyMetric,
88+
metric_name=AccuracyMetricTest.task_name,
89+
task_names=["t1", "t2", "t3"],
90+
fused_update_limit=0,
91+
compute_on_all_ranks=False,
92+
should_validate_update=False,
93+
world_size=WORLD_SIZE,
94+
entry_point=metric_test_helper,
95+
)
96+
8397

8498
class AccuracyMetricValueTest(unittest.TestCase):
8599
r"""This set of tests verify the computation logic of accuracy in several

torchrec/metrics/tests/test_cali_free_ne.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def test_cali_free_ne_unfused(self) -> None:
8989
entry_point=metric_test_helper,
9090
)
9191

92-
def test_cali_free_ne_fused(self) -> None:
92+
def test_cali_free_ne_fused_tasks(self) -> None:
9393
rec_metric_value_test_launcher(
9494
target_clazz=CaliFreeNEMetric,
9595
target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
@@ -103,7 +103,21 @@ def test_cali_free_ne_fused(self) -> None:
103103
entry_point=metric_test_helper,
104104
)
105105

106-
def test_cali_free_ne_update_fused(self) -> None:
106+
def test_cali_free_ne_fused_tasks_and_states(self) -> None:
107+
rec_metric_value_test_launcher(
108+
target_clazz=CaliFreeNEMetric,
109+
target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
110+
test_clazz=TestCaliFreeNEMetric,
111+
metric_name=CaliFreeNEMetricTest.task_name,
112+
task_names=["t1", "t2", "t3"],
113+
fused_update_limit=0,
114+
compute_on_all_ranks=False,
115+
should_validate_update=False,
116+
world_size=WORLD_SIZE,
117+
entry_point=metric_test_helper,
118+
)
119+
120+
def test_cali_free_ne_update_unfused(self) -> None:
107121
rec_metric_value_test_launcher(
108122
target_clazz=CaliFreeNEMetric,
109123
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,

torchrec/metrics/tests/test_calibration.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class CalibrationMetricTest(unittest.TestCase):
5252
clazz: Type[RecMetric] = CalibrationMetric
5353
task_name: str = "calibration"
5454

55-
def test_unfused_calibration(self) -> None:
55+
def test_calibration_unfused(self) -> None:
5656
rec_metric_value_test_launcher(
5757
target_clazz=CalibrationMetric,
5858
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
@@ -66,7 +66,7 @@ def test_unfused_calibration(self) -> None:
6666
entry_point=metric_test_helper,
6767
)
6868

69-
def test_fused_calibration(self) -> None:
69+
def test_calibration_fused_tasks(self) -> None:
7070
rec_metric_value_test_launcher(
7171
target_clazz=CalibrationMetric,
7272
target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
@@ -80,6 +80,20 @@ def test_fused_calibration(self) -> None:
8080
entry_point=metric_test_helper,
8181
)
8282

83+
def test_calibration_fused_tasks_and_states(self) -> None:
84+
rec_metric_value_test_launcher(
85+
target_clazz=CalibrationMetric,
86+
target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
87+
test_clazz=TestCalibrationMetric,
88+
metric_name=CalibrationMetricTest.task_name,
89+
task_names=["t1", "t2", "t3"],
90+
fused_update_limit=0,
91+
compute_on_all_ranks=False,
92+
should_validate_update=False,
93+
world_size=WORLD_SIZE,
94+
entry_point=metric_test_helper,
95+
)
96+
8397

8498
class CalibrationGPUSyncTest(unittest.TestCase):
8599
clazz: Type[RecMetric] = CalibrationMetric

torchrec/metrics/tests/test_ctr.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class CTRMetricTest(unittest.TestCase):
4646
clazz: Type[RecMetric] = CTRMetric
4747
task_name: str = "ctr"
4848

49-
def test_unfused_ctr(self) -> None:
49+
def test_ctr_unfused(self) -> None:
5050
rec_metric_value_test_launcher(
5151
target_clazz=CTRMetric,
5252
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
@@ -60,7 +60,7 @@ def test_unfused_ctr(self) -> None:
6060
entry_point=metric_test_helper,
6161
)
6262

63-
def test_fused_ctr(self) -> None:
63+
def test_ctr_fused_tasks(self) -> None:
6464
rec_metric_value_test_launcher(
6565
target_clazz=CTRMetric,
6666
target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
@@ -74,6 +74,20 @@ def test_fused_ctr(self) -> None:
7474
entry_point=metric_test_helper,
7575
)
7676

77+
def test_ctr_fused_tasks_and_states(self) -> None:
78+
rec_metric_value_test_launcher(
79+
target_clazz=CTRMetric,
80+
target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
81+
test_clazz=TestCTRMetric,
82+
metric_name=CTRMetricTest.task_name,
83+
task_names=["t1", "t2", "t3"],
84+
fused_update_limit=0,
85+
compute_on_all_ranks=False,
86+
should_validate_update=False,
87+
world_size=WORLD_SIZE,
88+
entry_point=metric_test_helper,
89+
)
90+
7791

7892
class CTRGPUSyncTest(unittest.TestCase):
7993
clazz: Type[RecMetric] = CTRMetric

torchrec/metrics/tests/test_hindsight_target_pr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ class TestHindsightTargetPRMetricTest(unittest.TestCase):
126126
precision_task_name: str = "hindsight_target_precision"
127127
recall_task_name: str = "hindsight_target_recall"
128128

129-
def test_unfused_hindsight_target_precision(self) -> None:
129+
def test_hindsight_target_precision_unfused(self) -> None:
130130
rec_metric_value_test_launcher(
131131
target_clazz=HindsightTargetPRMetric,
132132
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
@@ -140,7 +140,7 @@ def test_unfused_hindsight_target_precision(self) -> None:
140140
entry_point=metric_test_helper,
141141
)
142142

143-
def test_unfused_hindsight_target_recall(self) -> None:
143+
def test_hindsight_target_recall_unfused(self) -> None:
144144
rec_metric_value_test_launcher(
145145
target_clazz=HindsightTargetPRMetric,
146146
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,

torchrec/metrics/tests/test_mse.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class MSEMetricTest(unittest.TestCase):
7070
task_name: str = "mse"
7171
rmse_task_name: str = "rmse"
7272

73-
def test_unfused_mse(self) -> None:
73+
def test_mse_unfused(self) -> None:
7474
rec_metric_value_test_launcher(
7575
target_clazz=MSEMetric,
7676
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
@@ -84,7 +84,7 @@ def test_unfused_mse(self) -> None:
8484
entry_point=metric_test_helper,
8585
)
8686

87-
def test_fused_mse(self) -> None:
87+
def test_mse_fused_tasks(self) -> None:
8888
rec_metric_value_test_launcher(
8989
target_clazz=MSEMetric,
9090
target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
@@ -98,7 +98,21 @@ def test_fused_mse(self) -> None:
9898
entry_point=metric_test_helper,
9999
)
100100

101-
def test_unfused_rmse(self) -> None:
101+
def test_mse_fused_tasks_and_states(self) -> None:
102+
rec_metric_value_test_launcher(
103+
target_clazz=MSEMetric,
104+
target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
105+
test_clazz=TestMSEMetric,
106+
metric_name=MSEMetricTest.task_name,
107+
task_names=["t1", "t2", "t3"],
108+
fused_update_limit=0,
109+
compute_on_all_ranks=False,
110+
should_validate_update=False,
111+
world_size=WORLD_SIZE,
112+
entry_point=metric_test_helper,
113+
)
114+
115+
def test_rmse_unfused(self) -> None:
102116
rec_metric_value_test_launcher(
103117
target_clazz=MSEMetric,
104118
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
@@ -112,7 +126,7 @@ def test_unfused_rmse(self) -> None:
112126
entry_point=metric_test_helper,
113127
)
114128

115-
def test_fused_rmse(self) -> None:
129+
def test_rmse_fused_tasks(self) -> None:
116130
rec_metric_value_test_launcher(
117131
target_clazz=MSEMetric,
118132
target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
@@ -126,6 +140,20 @@ def test_fused_rmse(self) -> None:
126140
entry_point=metric_test_helper,
127141
)
128142

143+
def test_rmse_fused_tasks_and_states(self) -> None:
144+
rec_metric_value_test_launcher(
145+
target_clazz=MSEMetric,
146+
target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
147+
test_clazz=TestRMSEMetric,
148+
metric_name=MSEMetricTest.rmse_task_name,
149+
task_names=["t1", "t2", "t3"],
150+
fused_update_limit=0,
151+
compute_on_all_ranks=False,
152+
should_validate_update=False,
153+
world_size=WORLD_SIZE,
154+
entry_point=metric_test_helper,
155+
)
156+
129157

130158
class MSEGPUSyncTest(unittest.TestCase):
131159
clazz: Type[RecMetric] = MSEMetric

0 commit comments

Comments
 (0)