Skip to content

Back out "Fuse "states" level tensors to reduce all gather during metrics compute" #2893

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion torchrec/metrics/metrics_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ class RecComputeMode(Enum):

FUSED_TASKS_COMPUTATION = 1
UNFUSED_TASKS_COMPUTATION = 2
FUSED_TASKS_AND_STATES_COMPUTATION = 3


_DEFAULT_WINDOW_SIZE = 10_000_000
Expand Down
5 changes: 1 addition & 4 deletions torchrec/metrics/precision_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,7 @@ def __init__(
process_group: Optional[dist.ProcessGroup] = None,
**kwargs: Any,
) -> None:
if compute_mode in [
RecComputeMode.FUSED_TASKS_COMPUTATION,
RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
]:
if compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION:
raise RecMetricException(
"Fused computation is not supported for precision session-level metrics"
)
Expand Down
47 changes: 12 additions & 35 deletions torchrec/metrics/rec_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,19 +134,13 @@ def __init__(
window_size: int,
compute_on_all_ranks: bool = False,
should_validate_update: bool = False,
fuse_state_tensors: bool = False,
process_group: Optional[dist.ProcessGroup] = None,
fused_update_limit: int = 0,
allow_missing_label_with_zero_weight: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(
process_group=process_group,
fuse_state_tensors=fuse_state_tensors,
*args,
**kwargs,
)
super().__init__(process_group=process_group, *args, **kwargs)

self._my_rank = my_rank
self._n_tasks = n_tasks
Expand Down Expand Up @@ -347,11 +341,7 @@ def __init__(
# TODO(stellaya): consider to inherit from TorchMetrics.Metric or
# TorchMetrics.MetricCollection.
if (
compute_mode
in [
RecComputeMode.FUSED_TASKS_COMPUTATION,
RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
]
compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION
and fused_update_limit > 0
):
raise ValueError(
Expand Down Expand Up @@ -386,10 +376,7 @@ def __init__(
f"Local window size must be larger than batch size. Got local window size {self._window_size} and batch size {self._batch_size}."
)

if compute_mode in [
RecComputeMode.FUSED_TASKS_COMPUTATION,
RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
]:
if compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION:
task_per_metric = len(self._tasks)
self._tasks_iter = self._fused_tasks_iter
else:
Expand All @@ -398,11 +385,7 @@ def __init__(

for task_config in (
[self._tasks]
if compute_mode
in [
RecComputeMode.FUSED_TASKS_COMPUTATION,
RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
]
if compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION
else self._tasks
):
# pyre-ignore
Expand All @@ -411,16 +394,13 @@ def __init__(
# according to https://github.com/python/mypy/issues/3048.
# pyre-fixme[45]: Cannot instantiate abstract class `RecMetricCoputation`.
metric_computation = self._computation_class(
my_rank=my_rank,
batch_size=batch_size,
n_tasks=task_per_metric,
window_size=self._window_size,
compute_on_all_ranks=compute_on_all_ranks,
should_validate_update=self._should_validate_update,
fuse_state_tensors=(
compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION
),
process_group=process_group,
my_rank,
batch_size,
task_per_metric,
self._window_size,
compute_on_all_ranks,
self._should_validate_update,
process_group,
**{**kwargs, **self._get_task_kwargs(task_config)},
)
required_inputs = self._get_task_required_inputs(task_config)
Expand Down Expand Up @@ -547,10 +527,7 @@ def _update(
**kwargs: Dict[str, Any],
) -> None:
with torch.no_grad():
if self._compute_mode in [
RecComputeMode.FUSED_TASKS_COMPUTATION,
RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
]:
if self._compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION:
task_names = [task.name for task in self._tasks]

if not isinstance(predictions, torch.Tensor):
Expand Down
5 changes: 1 addition & 4 deletions torchrec/metrics/recall_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,7 @@ def __init__(
process_group: Optional[dist.ProcessGroup] = None,
**kwargs: Any,
) -> None:
if compute_mode in [
RecComputeMode.FUSED_TASKS_COMPUTATION,
RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
]:
if compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION:
raise RecMetricException(
"Fused computation is not supported for recall session-level metrics"
)
Expand Down
5 changes: 1 addition & 4 deletions torchrec/metrics/test_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,10 +291,7 @@ def get_target_rec_metric_value(
labels, predictions, weights, _ = parse_task_model_outputs(
tasks, model_outs[i]
)
if target_compute_mode in [
RecComputeMode.FUSED_TASKS_COMPUTATION,
RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
]:
if target_compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION:
labels = torch.stack(list(labels.values()))
predictions = torch.stack(list(predictions.values()))
weights = torch.stack(list(weights.values()))
Expand Down
18 changes: 2 additions & 16 deletions torchrec/metrics/tests/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class AccuracyMetricTest(unittest.TestCase):
clazz: Type[RecMetric] = AccuracyMetric
task_name: str = "accuracy"

def test_accuracy_unfused(self) -> None:
def test_unfused_accuracy(self) -> None:
rec_metric_value_test_launcher(
target_clazz=AccuracyMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
Expand All @@ -66,7 +66,7 @@ def test_accuracy_unfused(self) -> None:
entry_point=metric_test_helper,
)

def test_accuracy_fused_tasks(self) -> None:
def test_fused_accuracy(self) -> None:
rec_metric_value_test_launcher(
target_clazz=AccuracyMetric,
target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
Expand All @@ -80,20 +80,6 @@ def test_accuracy_fused_tasks(self) -> None:
entry_point=metric_test_helper,
)

def test_accuracy_fused_tasks_and_states(self) -> None:
rec_metric_value_test_launcher(
target_clazz=AccuracyMetric,
target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
test_clazz=TestAccuracyMetric,
metric_name=AccuracyMetricTest.task_name,
task_names=["t1", "t2", "t3"],
fused_update_limit=0,
compute_on_all_ranks=False,
should_validate_update=False,
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
)


class AccuracyMetricValueTest(unittest.TestCase):
r"""This set of tests verify the computation logic of accuracy in several
Expand Down
18 changes: 2 additions & 16 deletions torchrec/metrics/tests/test_cali_free_ne.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_cali_free_ne_unfused(self) -> None:
entry_point=metric_test_helper,
)

def test_cali_free_ne_fused_tasks(self) -> None:
def test_cali_free_ne_fused(self) -> None:
rec_metric_value_test_launcher(
target_clazz=CaliFreeNEMetric,
target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
Expand All @@ -103,21 +103,7 @@ def test_cali_free_ne_fused_tasks(self) -> None:
entry_point=metric_test_helper,
)

def test_cali_free_ne_fused_tasks_and_states(self) -> None:
rec_metric_value_test_launcher(
target_clazz=CaliFreeNEMetric,
target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
test_clazz=TestCaliFreeNEMetric,
metric_name=CaliFreeNEMetricTest.task_name,
task_names=["t1", "t2", "t3"],
fused_update_limit=0,
compute_on_all_ranks=False,
should_validate_update=False,
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
)

def test_cali_free_ne_update_unfused(self) -> None:
def test_cali_free_ne_update_fused(self) -> None:
rec_metric_value_test_launcher(
target_clazz=CaliFreeNEMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
Expand Down
18 changes: 2 additions & 16 deletions torchrec/metrics/tests/test_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class CalibrationMetricTest(unittest.TestCase):
clazz: Type[RecMetric] = CalibrationMetric
task_name: str = "calibration"

def test_calibration_unfused(self) -> None:
def test_unfused_calibration(self) -> None:
rec_metric_value_test_launcher(
target_clazz=CalibrationMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
Expand All @@ -66,7 +66,7 @@ def test_calibration_unfused(self) -> None:
entry_point=metric_test_helper,
)

def test_calibration_fused_tasks(self) -> None:
def test_fused_calibration(self) -> None:
rec_metric_value_test_launcher(
target_clazz=CalibrationMetric,
target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
Expand All @@ -80,20 +80,6 @@ def test_calibration_fused_tasks(self) -> None:
entry_point=metric_test_helper,
)

def test_calibration_fused_tasks_and_states(self) -> None:
rec_metric_value_test_launcher(
target_clazz=CalibrationMetric,
target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
test_clazz=TestCalibrationMetric,
metric_name=CalibrationMetricTest.task_name,
task_names=["t1", "t2", "t3"],
fused_update_limit=0,
compute_on_all_ranks=False,
should_validate_update=False,
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
)


class CalibrationGPUSyncTest(unittest.TestCase):
clazz: Type[RecMetric] = CalibrationMetric
Expand Down
18 changes: 2 additions & 16 deletions torchrec/metrics/tests/test_ctr.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class CTRMetricTest(unittest.TestCase):
clazz: Type[RecMetric] = CTRMetric
task_name: str = "ctr"

def test_ctr_unfused(self) -> None:
def test_unfused_ctr(self) -> None:
rec_metric_value_test_launcher(
target_clazz=CTRMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
Expand All @@ -60,7 +60,7 @@ def test_ctr_unfused(self) -> None:
entry_point=metric_test_helper,
)

def test_ctr_fused_tasks(self) -> None:
def test_fused_ctr(self) -> None:
rec_metric_value_test_launcher(
target_clazz=CTRMetric,
target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
Expand All @@ -74,20 +74,6 @@ def test_ctr_fused_tasks(self) -> None:
entry_point=metric_test_helper,
)

def test_ctr_fused_tasks_and_states(self) -> None:
rec_metric_value_test_launcher(
target_clazz=CTRMetric,
target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
test_clazz=TestCTRMetric,
metric_name=CTRMetricTest.task_name,
task_names=["t1", "t2", "t3"],
fused_update_limit=0,
compute_on_all_ranks=False,
should_validate_update=False,
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
)


class CTRGPUSyncTest(unittest.TestCase):
clazz: Type[RecMetric] = CTRMetric
Expand Down
4 changes: 2 additions & 2 deletions torchrec/metrics/tests/test_hindsight_target_pr.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class TestHindsightTargetPRMetricTest(unittest.TestCase):
precision_task_name: str = "hindsight_target_precision"
recall_task_name: str = "hindsight_target_recall"

def test_hindsight_target_precision_unfused(self) -> None:
def test_unfused_hindsight_target_precision(self) -> None:
rec_metric_value_test_launcher(
target_clazz=HindsightTargetPRMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
Expand All @@ -140,7 +140,7 @@ def test_hindsight_target_precision_unfused(self) -> None:
entry_point=metric_test_helper,
)

def test_hindsight_target_recall_unfused(self) -> None:
def test_unfused_hindsight_target_recall(self) -> None:
rec_metric_value_test_launcher(
target_clazz=HindsightTargetPRMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
Expand Down
36 changes: 4 additions & 32 deletions torchrec/metrics/tests/test_mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class MSEMetricTest(unittest.TestCase):
task_name: str = "mse"
rmse_task_name: str = "rmse"

def test_mse_unfused(self) -> None:
def test_unfused_mse(self) -> None:
rec_metric_value_test_launcher(
target_clazz=MSEMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
Expand All @@ -84,7 +84,7 @@ def test_mse_unfused(self) -> None:
entry_point=metric_test_helper,
)

def test_mse_fused_tasks(self) -> None:
def test_fused_mse(self) -> None:
rec_metric_value_test_launcher(
target_clazz=MSEMetric,
target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
Expand All @@ -98,21 +98,7 @@ def test_mse_fused_tasks(self) -> None:
entry_point=metric_test_helper,
)

def test_mse_fused_tasks_and_states(self) -> None:
rec_metric_value_test_launcher(
target_clazz=MSEMetric,
target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
test_clazz=TestMSEMetric,
metric_name=MSEMetricTest.task_name,
task_names=["t1", "t2", "t3"],
fused_update_limit=0,
compute_on_all_ranks=False,
should_validate_update=False,
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
)

def test_rmse_unfused(self) -> None:
def test_unfused_rmse(self) -> None:
rec_metric_value_test_launcher(
target_clazz=MSEMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
Expand All @@ -126,7 +112,7 @@ def test_rmse_unfused(self) -> None:
entry_point=metric_test_helper,
)

def test_rmse_fused_tasks(self) -> None:
def test_fused_rmse(self) -> None:
rec_metric_value_test_launcher(
target_clazz=MSEMetric,
target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
Expand All @@ -140,20 +126,6 @@ def test_rmse_fused_tasks(self) -> None:
entry_point=metric_test_helper,
)

def test_rmse_fused_tasks_and_states(self) -> None:
rec_metric_value_test_launcher(
target_clazz=MSEMetric,
target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
test_clazz=TestRMSEMetric,
metric_name=MSEMetricTest.rmse_task_name,
task_names=["t1", "t2", "t3"],
fused_update_limit=0,
compute_on_all_ranks=False,
should_validate_update=False,
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
)


class MSEGPUSyncTest(unittest.TestCase):
clazz: Type[RecMetric] = MSEMetric
Expand Down
Loading
Loading