From 5391139396300a77c8de5da4cc8b5d11ffd68f80 Mon Sep 17 00:00:00 2001
From: Evgenii Kolpakov
Date: Thu, 8 May 2025 01:06:12 -0700
Subject: [PATCH] Add context manager to use next batch context for postprocs (#2939)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2939

Small refactor to reduce code repetition when setting the pipelined postprocs'
context to the next batch's context and reverting it afterwards.

Differential Revision: D73824600
---
 .../train_pipeline/train_pipelines.py        | 40 ++++++-------------
 torchrec/distributed/train_pipeline/utils.py | 25 ++++++++++++
 2 files changed, 37 insertions(+), 28 deletions(-)

diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index 3a8073d10..1268190fa 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -58,6 +58,7 @@
     StageOut,
     StageOutputWithEvent,
     TrainPipelineContext,
+    use_context_for_postprocs,
 )
 from torchrec.distributed.types import Awaitable
 from torchrec.pt2.checks import is_torchdynamo_compiling
@@ -792,19 +793,9 @@ def start_sparse_data_dist(
             with self._stream_context(self._data_dist_stream):
                 _wait_for_batch(batch, self._memcpy_stream)
 
-                original_contexts = [p.get_context() for p in self._pipelined_postprocs]
-                # Temporarily set context for next iter to populate cache
-                for postproc_mod in self._pipelined_postprocs:
-                    postproc_mod.set_context(context)
-
-                _start_data_dist(self._pipelined_modules, batch, context)
-
-                # Restore context for model fwd
-                for module, context in zip(
-                    self._pipelined_postprocs, original_contexts
-                ):
-                    module.set_context(context)
+                with use_context_for_postprocs(self._pipelined_postprocs, context):
+                    _start_data_dist(self._pipelined_modules, batch, context)
 
     def wait_sparse_data_dist(self, context: TrainPipelineContext) -> None:
         """
@@ -1325,22 +1316,15 @@ def start_sparse_data_dist(
             return
 
         # Temporarily set context for next iter to populate cache
-        original_contexts = [p.get_context() for p in self._pipelined_postprocs]
-        for postproc_mod in self._pipelined_postprocs:
-            postproc_mod.set_context(context)
-
-        with record_function(f"## start_sparse_data_dist {context.index} ##"):
-            with self._stream_context(self._data_dist_stream):
-                _wait_for_events(batch, context, self._data_dist_stream)
-                model_input = self.extract_model_input_from_batch(batch)
-                _start_data_dist(self._pipelined_modules, model_input, context)
-                event = torch.get_device_module(self._device).Event()
-                event.record()
-                context.events.append(event)
-
-        # Restore context for model forward
-        for module, context in zip(self._pipelined_postprocs, original_contexts):
-            module.set_context(context)
+        with use_context_for_postprocs(self._pipelined_postprocs, context):
+            with record_function(f"## start_sparse_data_dist {context.index} ##"):
+                with self._stream_context(self._data_dist_stream):
+                    _wait_for_events(batch, context, self._data_dist_stream)
+                    model_input = self.extract_model_input_from_batch(batch)
+                    _start_data_dist(self._pipelined_modules, model_input, context)
+                    event = torch.get_device_module(self._device).Event()
+                    event.record()
+                    context.events.append(event)
 
     def start_embedding_lookup(
         self,
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index 2df5bc3c3..6bfeccb48 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -7,6 +7,7 @@
 
 # pyre-strict
 
+import contextlib
 import copy
 import itertools
 import logging
@@ -21,6 +22,7 @@
     Callable,
     cast,
     Dict,
+    Generator,
     Generic,
     Iterable,
     Iterator,
@@ -1797,6 +1799,28 @@ def _prefetch_embeddings(
     return data_per_sharded_module
 
 
+@contextlib.contextmanager
+def use_context_for_postprocs(
+    pipelined_postprocs: List[PipelinedPostproc],
+    next_batch_context: TrainPipelineContext,
+) -> Generator[None, None, None]:
+    """
+    Temporarily set pipelined postproc context for next iter to populate cache.
+    """
+    # Save original context for model fwd
+    original_contexts = [p.get_context() for p in pipelined_postprocs]
+
+    # Temporarily set context for next iter to populate cache
+    for postproc_mod in pipelined_postprocs:
+        postproc_mod.set_context(next_batch_context)
+
+    yield
+
+    # Restore context for model fwd
+    for module, context in zip(pipelined_postprocs, original_contexts):
+        module.set_context(context)
+
+
 class SparseDataDistUtil(Generic[In]):
     """
     Helper class exposing methods for sparse data dist and prefetch pipelining.
@@ -1808,6 +1832,7 @@ class SparseDataDistUtil(Generic[In]):
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
         prefetch_stream (Optional[torch.cuda.Stream]): Stream on which model prefetch runs
             Defaults to `None`. This needs to be passed in to enable prefetch pipelining.
+        pipeline_postproc (bool): whether to pipeline postproc modules. Defaults to `False`.
 
     Example::
         sdd = SparseDataDistUtil(
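
Usage sketch (illustrative, not part of the diff): how the new
`use_context_for_postprocs` context manager is meant to be called. The helper
name `warm_postproc_caches` and its `...` placeholder body are hypothetical,
and the import path simply mirrors the utils module touched by this diff; it
may differ in other torchrec versions.

    from typing import List

    from torchrec.distributed.train_pipeline.utils import (
        PipelinedPostproc,
        TrainPipelineContext,
        use_context_for_postprocs,
    )


    def warm_postproc_caches(
        postprocs: List[PipelinedPostproc],
        next_batch_context: TrainPipelineContext,
    ) -> None:
        # Inside the `with` block every pipelined postproc points at the next
        # batch's context, so work started here (e.g. sparse data dist)
        # populates that context's cache.
        with use_context_for_postprocs(postprocs, next_batch_context):
            ...  # e.g. kick off data dist for the next batch (placeholder)
        # On exit the original contexts are restored, so the current batch's
        # model forward sees unchanged postproc state.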