diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py index 3b2b94196..9e14fd89d 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py @@ -2155,7 +2155,8 @@ def gpu_postproc(x: StageOut) -> StageOut: name="start_sparse_data_dist", runnable=sdd.start_sparse_data_dist, stream=sdd.data_dist_stream, - fill_callback=sdd.wait_sparse_data_dist, + fill_callback=sdd.fill_callback, + data_exhausted_callback=sdd.data_exhausted_callback, ), ] pipeline = StagedTrainPipeline( @@ -2222,7 +2223,8 @@ def gpu_postproc(x: StageOut) -> StageOut: name="start_sparse_data_dist", runnable=sdd.start_sparse_data_dist, stream=sdd.data_dist_stream, - fill_callback=sdd.wait_sparse_data_dist, + fill_callback=sdd.fill_callback, + data_exhausted_callback=sdd.data_exhausted_callback, ), ] @@ -2333,7 +2335,8 @@ def test_model_detach(self) -> None: name="start_sparse_data_dist", runnable=sdd.start_sparse_data_dist, stream=sdd.data_dist_stream, - fill_callback=sdd.wait_sparse_data_dist, + fill_callback=sdd.fill_callback, + data_exhausted_callback=sdd.data_exhausted_callback, ), ] @@ -2526,7 +2529,8 @@ def gpu_postproc(x: StageOut) -> StageOut: name="start_sparse_data_dist", runnable=sdd.start_sparse_data_dist, stream=sdd.data_dist_stream, - fill_callback=sdd.wait_sparse_data_dist, + fill_callback=sdd.fill_callback, + data_exhausted_callback=sdd.data_exhausted_callback, ), PipelineStage( name="prefetch", diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index 3a8073d10..08e2dd4b9 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -58,6 +58,7 @@ StageOut, StageOutputWithEvent, TrainPipelineContext, + use_context_for_postprocs, ) from torchrec.distributed.types import Awaitable from torchrec.pt2.checks import is_torchdynamo_compiling @@ -792,19 +793,9 @@ def start_sparse_data_dist( with self._stream_context(self._data_dist_stream): _wait_for_batch(batch, self._memcpy_stream) - original_contexts = [p.get_context() for p in self._pipelined_postprocs] - # Temporarily set context for next iter to populate cache - for postproc_mod in self._pipelined_postprocs: - postproc_mod.set_context(context) - - _start_data_dist(self._pipelined_modules, batch, context) - - # Restore context for model fwd - for module, context in zip( - self._pipelined_postprocs, original_contexts - ): - module.set_context(context) + with use_context_for_postprocs(self._pipelined_postprocs, context): + _start_data_dist(self._pipelined_modules, batch, context) def wait_sparse_data_dist(self, context: TrainPipelineContext) -> None: """ @@ -1325,22 +1316,15 @@ def start_sparse_data_dist( return # Temporarily set context for next iter to populate cache - original_contexts = [p.get_context() for p in self._pipelined_postprocs] - for postproc_mod in self._pipelined_postprocs: - postproc_mod.set_context(context) - - with record_function(f"## start_sparse_data_dist {context.index} ##"): - with self._stream_context(self._data_dist_stream): - _wait_for_events(batch, context, self._data_dist_stream) - model_input = self.extract_model_input_from_batch(batch) - _start_data_dist(self._pipelined_modules, model_input, context) - event = torch.get_device_module(self._device).Event() - event.record() - context.events.append(event) - - # Restore context for model forward - for module, context in zip(self._pipelined_postprocs, original_contexts): - module.set_context(context) + with use_context_for_postprocs(self._pipelined_postprocs, context): + with record_function(f"## start_sparse_data_dist {context.index} ##"): + with self._stream_context(self._data_dist_stream): + _wait_for_events(batch, context, self._data_dist_stream) + model_input = self.extract_model_input_from_batch(batch) + _start_data_dist(self._pipelined_modules, model_input, context) + event = torch.get_device_module(self._device).Event() + event.record() + context.events.append(event) def start_embedding_lookup( self, @@ -1727,8 +1711,6 @@ def _run_with_event( inputs: Optional[In], stream: torch.Stream, ) -> StageOutputWithEvent: - if inputs is None: - return (None, None) with self._stream_context(stream): # If there is no previous event, data is entering the pipeline if event is not None: @@ -1783,12 +1765,19 @@ def _run_stage( assert batch_to_wait_with_event is not None batch_to_wait, event = batch_to_wait_with_event - new_result = self._run_with_event( - runnable=stage.runnable, - event=event, - inputs=batch_to_wait, - stream=stage.stream, - ) + if batch_to_wait is not None: + new_result = self._run_with_event( + runnable=stage.runnable, + event=event, + inputs=batch_to_wait, + stream=stage.stream, + ) + else: + new_result = (None, None) + if ( + data_exhausted_callback := stage.data_exhausted_callback + ) is not None: + data_exhausted_callback() self._stage_outputs[batch_offset] = new_result if self._debug_mode: diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index 2df5bc3c3..b49f5b89e 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -7,10 +7,11 @@ # pyre-strict +import contextlib import copy import itertools import logging -from collections import defaultdict, OrderedDict +from collections import defaultdict, deque, OrderedDict from contextlib import AbstractContextManager from dataclasses import dataclass, field @@ -20,7 +21,9 @@ Any, Callable, cast, + Deque, Dict, + Generator, Generic, Iterable, Iterator, @@ -35,6 +38,7 @@ import torch from torch import distributed as dist +from torch.utils.hooks import RemovableHandle from torchrec.distributed.types import LazyAwaitable if not torch._running_with_deploy(): @@ -167,12 +171,17 @@ class PipelineStage: stream (torch.cuda.streams.Stream): Stream to run on. Often each stage has a unique stream, but having different pipelines share a stream provides more synchronization semantics. + fill_callback (Optional[Callable[[], None]])) - optional step to run after the main + runnable during filling the pipeline + data_exhausted_callback (Optional[Callable[[], None]])) - optional callback to run + when data is ehxausted """ name: str runnable: RunnableType stream: torch.Stream fill_callback: Optional[Callable[[], None]] = None + data_exhausted_callback: Optional[Callable[[], None]] = None @dataclass @@ -1762,7 +1771,7 @@ def _prefetch_embeddings( assert forward._name in context.input_dist_tensors_requests request = context.input_dist_tensors_requests.pop(forward._name) assert isinstance(request, Awaitable) - with record_function("## wait_sparse_data_dist ##"): + with record_function(f"## _prefetch_embeddings {context.index} ##"): # Finish waiting on the dist_stream, # in case some delayed stream scheduling happens during the wait() call. with stream_context(data_dist_stream): @@ -1797,6 +1806,28 @@ def _prefetch_embeddings( return data_per_sharded_module +@contextlib.contextmanager +def use_context_for_postprocs( + pipelined_postprocs: List[PipelinedPostproc], + next_batch_context: TrainPipelineContext, +) -> Generator[None, None, None]: + """ + Temporarily set pipelined postproc context for next iter to populate cache. + """ + # Save original context for model fwd + original_contexts = [p.get_context() for p in pipelined_postprocs] + + # Temporarily set context for next iter to populate cache + for postproc_mod in pipelined_postprocs: + postproc_mod.set_context(next_batch_context) + + yield + + # Restore context for model fwd + for module, context in zip(pipelined_postprocs, original_contexts): + module.set_context(context) + + class SparseDataDistUtil(Generic[In]): """ Helper class exposing methods for sparse data dist and prefetch pipelining. @@ -1808,6 +1839,7 @@ class SparseDataDistUtil(Generic[In]): apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules. prefetch_stream (Optional[torch.cuda.Stream]): Stream on which model prefetch runs Defaults to `None`. This needs to be passed in to enable prefetch pipelining. + pipeline_postproc (bool): whether to pipeline postproc modules. Defaults to `False`. Example:: sdd = SparseDataDistUtil( @@ -1840,6 +1872,11 @@ class SparseDataDistUtil(Generic[In]): return StagedTrainPipeline(pipeline_stages=pipeline) """ + _TRAIN_CONTEXT_VERSION = 1 + # Convenience flag to perform additional assertions on contexts + # to make sure contexts are advancing correctly. + _WITH_CONTEXT_ASSERTIONS = False + def __init__( self, model: torch.nn.Module, @@ -1850,20 +1887,24 @@ def __init__( super().__init__() self.model = model self.data_dist_stream = data_dist_stream - self.prefetch_stream = prefetch_stream self.apply_jit = apply_jit - self.context = ( - PrefetchTrainPipelineContext(version=0) - if prefetch_stream - else TrainPipelineContext(version=0) - ) + self.prefetch_stream = prefetch_stream + self._next_index: int = 0 + self._contexts: Deque[TrainPipelineContext] = deque() self.initialized = False self._pipelined_modules: List[ShardedModule] = [] self._pipelined_postprocs: List[PipelinedPostproc] = [] - # pyre-ignore - self.fwd_hook = None + self.fwd_hook: Optional[RemovableHandle] = None self._device: torch.device = data_dist_stream.device + self._stream_context: Callable[ + [Optional[torch.Stream]], torch.cuda.StreamContext + ] = ( + torch.get_device_module(self._device).stream + if self._device.type in ["cuda", "mtia"] + else torch.cuda.stream + ) + # pyre-ignore self._original_forwards: List[Callable[..., Any]] = [] self._original_kjt_dist_forwards: List[ @@ -1872,7 +1913,7 @@ def __init__( self._pipelined_forward: Type[BaseForward[TrainPipelineContext]] = cast( Type[BaseForward[TrainPipelineContext]], - (PrefetchPipelinedForward if prefetch_stream else PipelinedForward), + (PrefetchPipelinedForward if self._with_prefetch else PipelinedForward), ) self._default_stream: Optional[torch.Stream] = ( @@ -1881,6 +1922,128 @@ def __init__( else None ) + @property + def _with_prefetch(self) -> bool: + return self.prefetch_stream is not None + + # === Debugging helpers === # + @property + def _have_pipelined_modules(self) -> bool: + return len(self._pipelined_modules) > 0 + + @property + def _have_pipelined_postprocs(self) -> bool: + return len(self._pipelined_postprocs) > 0 + + def _pipelined_modules_fqns(self) -> Set[str]: + return {module.forward._name for module in self._pipelined_modules} + + def _pipelined_postprocs_fqns(self) -> Set[str]: + return {module._fqn for module in self._pipelined_postprocs} + + # === Debugging helpers === # + + # ==== Context management === # + # In short: version=1 contexts essentially represent "passing of time" + # and have one-to-one correspondence to batches. "Monolithic" torchrec pipelines + # (e.g. TrainPipelineSparseDist) explicitly manage batches and contexts together + # (see TrainPipelineSparseDist.enqueue_batch), however StagedTrainPipeline abstracts + # that away + supports stages that don't require contexts (in fact, SDD is the only one) + # So we just manage contexts and batches together in lockstep. `forward_hook` is critical + # for that, check the comment there. + # + # Essentially, StagedTrainPipeline during a single `progress` call runs each stage + # for a different batch, keeping the stage outputs in a `_stage_outputs` list, and + # advancing the list at the beginning of the `progress`. + # Tricky part is that SparseDataDistUtil might be participating in TWO stages: + # * "main" with start_data_dist -> wait_data_dist pair for `runnable` and `fill_callback` + # * "prefetch" with prefetch -> load_prefetch for `runnable` and `fill_callback` + # + # For this to work, we: + # (1) need to manage contexts in a lockstep with batch advancing through stages (_advance_context) + # (2) perform various actions (start dist, wait dist, etc.) against the correct contexts + # ("named" contexts below and how they are used in start/wait sparse_dist, prefetch, etc.) + # (3) set contexts for the _pipelined_modules and _pipelined_postprocs to the "current batch context" + # for the model to run correctly (_set_module_context) + # + # SDD Util (both with and without prefetch) gets away with just two contexts: + # * context[0] is always the "current batch" context - used for prefetch (here) and model forward (outside this class) + # * context[1] is always the "next batch" context - used for start/wait_sparse_data_dist + + def _create_context(self, index: int) -> TrainPipelineContext: + version = self._TRAIN_CONTEXT_VERSION + return ( + PrefetchTrainPipelineContext(index=index, version=version) + if self._with_prefetch + else TrainPipelineContext(index=index, version=version) + ) + + def _add_context(self) -> None: + self._contexts.append(self._create_context(self._next_index)) + self._next_index += 1 + + def _advance_context(self) -> None: + self._contexts.popleft() + self._add_context() + self._set_module_context(self._context_for_model_forward()) + + def _set_module_context(self, context: TrainPipelineContext) -> None: + for module in self._pipelined_modules: + module.forward.set_context(context) + + for postproc_module in self._pipelined_postprocs: + # This ensures that next iter model fwd uses cached results + postproc_module.set_context(context) + + # ====== "Named" contexts - to make it clearer which contexts are used for which operation ====== # + # This is purely convenience methods, feel free to remove if they get in the way + def _current_context(self) -> TrainPipelineContext: + return self._contexts[0] + + def _next_context(self) -> TrainPipelineContext: + return self._contexts[1] + + def _context_for_model_forward(self) -> TrainPipelineContext: + context = self._current_context() + if self._WITH_CONTEXT_ASSERTIONS and self.initialized: + target_fqns = self._pipelined_modules_fqns() + if self._with_prefetch: + assert isinstance(context, PrefetchTrainPipelineContext) + assert context.module_input_post_prefetch.keys() == target_fqns + assert context.module_contexts_post_prefetch.keys() == target_fqns + else: + assert context.input_dist_tensors_requests.keys() == target_fqns + assert context.module_contexts.keys() == target_fqns + return context + + def _start_dist_context(self) -> TrainPipelineContext: + return self._next_context() + + def _wait_dist_context(self) -> TrainPipelineContext: + # Note: see comment on the forward_hook in _initialize method + context = self._next_context() + if self._WITH_CONTEXT_ASSERTIONS and self.initialized: + if self._have_pipelined_modules: + assert ( + len(context.fused_splits_awaitables) > 0 + ), f"fused_splits_awaitables was empty on {context.index=} - was start_sparse_data_dist called?" + return context + + def _prefetch_context(self) -> PrefetchTrainPipelineContext: + ctx = self._current_context() + assert isinstance( + ctx, PrefetchTrainPipelineContext + ), "Pass prefetch_stream into SparseDataDistUtil to use prefetch_context()" + if self._WITH_CONTEXT_ASSERTIONS and self.initialized: + target_fqns = self._pipelined_modules_fqns() + assert ctx.input_dist_tensors_requests.keys() == target_fqns + assert ctx.module_contexts.keys() == target_fqns + return ctx + + # ====== End "Named" contexts ====== # + + # === End context management === # + def detach(self) -> torch.nn.Module: """ Removes sparse data dist (SDD) pipelining from model forward and input dist. @@ -1888,7 +2051,8 @@ def detach(self) -> torch.nn.Module: detach() can be called at any point, and inflight batches do not need to be flushed before calling it. Calling pipeline.progress() will re-attach the model - to the pipeline and the pipeline will progress normally from the point it was detached (i.e. inflight batches will be kept when calling detach). + to the pipeline and the pipeline will progress normally from the point it was + detached (i.e. inflight batches will be kept when calling detach). While the model is detached, it is equivalent to the model before passing to the pipeline, so forward and backward passes, and optimizer updates can be @@ -1909,112 +2073,141 @@ def detach(self) -> torch.nn.Module: self.initialized = False return self.model - def start_sparse_data_dist(self, batch: In) -> In: - if not self.initialized: - # Step 1: Pipeline input dist in trec sharded modules - # TODO (yhshin): support postproc modules for `StagedTrainPipeline` - ( - self._pipelined_modules, - self.model, - self._original_forwards, - self._pipelined_postprocs, - _, - ) = _rewrite_model( - model=self.model, - context=self.context, - dist_stream=self.data_dist_stream, - batch=batch, - apply_jit=self.apply_jit, - pipelined_forward=self._pipelined_forward, - default_stream=self._default_stream, - ) - # initializes input dist, so we can override input dist forwards - _start_data_dist(self._pipelined_modules, batch, self.context) - self._original_kjt_dist_forwards = _override_input_dist_forwards( - self._pipelined_modules - ) + def _initialize(self, batch: In) -> None: + # Step 0: initialize two initial contexts + self._contexts.append(self._create_context(-1)) # throwaway context + self._add_context() + # Step 1: Pipeline input dist in trec sharded modules + ( + self._pipelined_modules, + self.model, + self._original_forwards, + self._pipelined_postprocs, + _, + ) = _rewrite_model( + model=self.model, + context=self._next_context(), + dist_stream=self.data_dist_stream, + batch=batch, + apply_jit=self.apply_jit, + pipelined_forward=self._pipelined_forward, + default_stream=self._default_stream, + ) + # Setting the stage for the first batch + # initialize input dist + _start_data_dist(self._pipelined_modules, batch, self._start_dist_context()) + # so we can override input dist forwards + self._original_kjt_dist_forwards = _override_input_dist_forwards( + self._pipelined_modules + ) - # Step 2: Register post-forward hook to wait SDD - def forward_hook( - module: torch.nn.Module, - input: Union[torch.Tensor, Tuple[torch.Tensor]], - output: Union[torch.Tensor, Tuple[torch.Tensor]], - ) -> None: - if self.prefetch_stream is not None: - # Need to load prefetch before wait_sparse_data_dist - self.load_prefetch() - self.wait_sparse_data_dist() + # Step 2: Register post-forward hook to wait SDD and advance contexts + def forward_hook( + module: torch.nn.Module, + input: Union[torch.Tensor, Tuple[torch.Tensor]], + output: Union[torch.Tensor, Tuple[torch.Tensor]], + ) -> None: + # Note: tricky part - a bit delicate choreography between + # StagedPipeline and this class (see D59786807 for details) + # wait_dist need to be called as post_forward hook + # at the end of the batch N, so that the data is awaited + # before start of the next batch. + self.wait_sparse_data_dist() + # _advance_context should be called after wait_sparse_data_dist, + # but before start_data_dist for the next batch + # which means right here, and nowhere else + self._advance_context() + # ... this can be made more explicit by adding dedicated hooks for "batch start"/"batch end" events + # to the StagedPipeline, PipelineStage and this class, but hook seems to be doing an adequate job for now + + self.fwd_hook = self.model.register_forward_hook(forward_hook) + + self.initialized = True + + def fill_callback(self) -> None: + """ + Used by StagedTrainPipeline during only during initial pipeline filling. + + At that part, model.forward is not executed, so forward hook is not called. + """ + self.wait_sparse_data_dist() + self._advance_context() - self.fwd_hook = self.model.register_forward_hook(forward_hook) + def data_exhausted_callback(self) -> None: + """ + Called by StagedTrainPipeline when all batches were processed. + """ + if self.fwd_hook is not None: + self.fwd_hook.remove() + self.fwd_hook = None - self.initialized = True + def start_sparse_data_dist(self, batch: In) -> In: + if not self.initialized: + self._initialize(batch) - _start_data_dist(self._pipelined_modules, batch, self.context) + next_ctx = self._start_dist_context() + with record_function(f"## start_sparse_data_dist {next_ctx.index} ##"): + with use_context_for_postprocs(self._pipelined_postprocs, next_ctx): + _start_data_dist(self._pipelined_modules, batch, next_ctx) return batch def wait_sparse_data_dist(self) -> None: - with record_function("## wait_sparse_data_dist ##"): - with torch.get_device_module(self._device).stream(self.data_dist_stream): - self.context.module_contexts = ( - self.context.module_contexts_next_batch.copy() - ) - self.context.input_dist_tensors_requests.clear() - for names, awaitable in self.context.fused_splits_awaitables: + """ + Waits on the input dist splits requests to get the input dist tensors requests, + and populates the context with them. + """ + context = self._wait_dist_context() + with record_function(f"## wait_sparse_data_dist {context.index} ##"): + with self._stream_context(self.data_dist_stream): + for names, awaitable in context.fused_splits_awaitables: for name, request in zip(names, awaitable.wait()): - self.context.input_dist_tensors_requests[name] = request + context.input_dist_tensors_requests[name] = request + # these won't be used by the rest of the pipeline, so just deleting them to free + # the memory they occupy + context.input_dist_splits_requests.clear() + context.fused_splits_awaitables.clear() def prefetch(self, batch: In) -> In: """ Waits for input dist to finish, then prefetches data. """ assert isinstance( - self.context, PrefetchTrainPipelineContext + self._prefetch_context(), PrefetchTrainPipelineContext ), "Pass prefetch_stream into SparseDataDistUtil to use prefetch() as a stage" - self.context.module_input_post_prefetch_next_batch.clear() - # pyre-ignore - self.context.module_contexts_post_prefetch_next_batch.clear() + prefetch_context: PrefetchTrainPipelineContext = self._prefetch_context() - data_per_pipelined_module = _prefetch_embeddings( - batch, - # pyre-ignore - self.context, - self._pipelined_modules, - self._device, - torch.get_device_module(self._device).stream, - self.data_dist_stream, - self._default_stream, - ) - for sharded_module in self._pipelined_modules: - forward = sharded_module.forward - data = data_per_pipelined_module[forward._name] - # pyre-ignore [16] - self.context.module_input_post_prefetch_next_batch[forward._name] = data - self.context.module_contexts_post_prefetch_next_batch[forward._name] = ( - self.context.module_contexts.pop(forward._name) + with self._stream_context(self.prefetch_stream): + data_per_pipelined_module = _prefetch_embeddings( + batch, + prefetch_context, + self._pipelined_modules, + self._device, + self._stream_context, # pyre-ignore[6]: stream_context in _prefetch_embeddings seems to be mis-typed + self.data_dist_stream, + self._default_stream, ) + for sharded_module in self._pipelined_modules: + forward = sharded_module.forward + data = data_per_pipelined_module[forward._name] + prefetch_context.module_input_post_prefetch[forward._name] = data + prefetch_context.module_contexts_post_prefetch[forward._name] = ( + prefetch_context.module_contexts.pop(forward._name) + ) return batch def load_prefetch(self) -> None: + """ + DEPRECATED: exists for backward compatibility + """ assert isinstance( - self.context, PrefetchTrainPipelineContext + self._prefetch_context(), PrefetchTrainPipelineContext ), "Pass prefetch_stream into SparseDataDistUtil to use load_prefetch()" - self.context.module_input_post_prefetch.clear() - # pyre-ignore - self.context.module_contexts_post_prefetch.clear() - - with record_function("## load_sharded_module_prefetch ##"): - with torch.get_device_module(self._device).stream(self.prefetch_stream): - for sharded_module in self._pipelined_modules: - forward = sharded_module.forward - assert isinstance(forward, PrefetchPipelinedForward) - self.context.module_input_post_prefetch[forward._name] = ( - self.context.module_input_post_prefetch_next_batch[ - forward._name - ] - ) - self.context.module_contexts_post_prefetch[forward._name] = ( - self.context.module_contexts_post_prefetch_next_batch[ - forward._name - ] - ) + """ + Version=0 did + module_input_post_prefetch = module_input_post_prefetch_for_next_batch + module_contexts_post_prefetch = module_contexts_post_prefetch_for_next_batch + with version=1, there's nothing to do - they are managed at a context level, + so this is essentially done by _advance_context + prefetch above + """ + pass