diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py index b8494fa9e..9a22c80e0 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py @@ -2165,7 +2165,8 @@ def gpu_postproc(x: StageOut) -> StageOut: name="start_sparse_data_dist", runnable=sdd.start_sparse_data_dist, stream=sdd.data_dist_stream, - fill_callback=sdd.wait_sparse_data_dist, + fill_callback=sdd.wait_sdd_fill_callback, + data_exhausted_callback=sdd.data_exhausted_callback, ), ] pipeline = StagedTrainPipeline( @@ -2232,7 +2233,8 @@ def gpu_postproc(x: StageOut) -> StageOut: name="start_sparse_data_dist", runnable=sdd.start_sparse_data_dist, stream=sdd.data_dist_stream, - fill_callback=sdd.wait_sparse_data_dist, + fill_callback=sdd.wait_sdd_fill_callback, + data_exhausted_callback=sdd.data_exhausted_callback, ), ] @@ -2343,7 +2345,8 @@ def test_model_detach(self) -> None: name="start_sparse_data_dist", runnable=sdd.start_sparse_data_dist, stream=sdd.data_dist_stream, - fill_callback=sdd.wait_sparse_data_dist, + fill_callback=sdd.wait_sdd_fill_callback, + data_exhausted_callback=sdd.data_exhausted_callback, ), ] @@ -2536,7 +2539,8 @@ def gpu_postproc(x: StageOut) -> StageOut: name="start_sparse_data_dist", runnable=sdd.start_sparse_data_dist, stream=sdd.data_dist_stream, - fill_callback=sdd.wait_sparse_data_dist, + fill_callback=sdd.wait_sdd_fill_callback, + data_exhausted_callback=sdd.data_exhausted_callback, ), PipelineStage( name="prefetch", diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index b8be13994..3cb2af364 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -1707,11 +1707,9 @@ def _run_with_event( self, runnable: RunnableType, event: Optional[torch.Event], - inputs: Optional[In], + inputs: In, stream: torch.Stream, ) -> StageOutputWithEvent: - if inputs is None: - return (None, None) with self._stream_context(stream): # If there is no previous event, data is entering the pipeline if event is not None: @@ -1755,6 +1753,11 @@ def _run_stage( """ stage = self._pipeline_stages[stage_idx] + if self._debug_mode: + logger.info( + f"Running ## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##", + ) + with record_function( f"## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##" ): @@ -1766,23 +1769,40 @@ def _run_stage( assert batch_to_wait_with_event is not None batch_to_wait, event = batch_to_wait_with_event - new_result = self._run_with_event( - runnable=stage.runnable, - event=event, - inputs=batch_to_wait, - stream=stage.stream, - ) + if batch_to_wait is not None: + if self._debug_mode: + logger.info( + f"Executing ## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##", + ) + new_result = self._run_with_event( + runnable=stage.runnable, + event=event, + inputs=batch_to_wait, + stream=stage.stream, + ) + else: + if self._debug_mode: + logger.info( + f"Skipping due to None ## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##", + ) + new_result = (None, None) + if ( + data_exhausted_callback := stage.data_exhausted_callback + ) is not None: + data_exhausted_callback() self._stage_outputs[batch_offset] = new_result if self._debug_mode: 
logger.info( - f"Running ## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##", + f"Finshed ## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##", ) if fill and (fill_callback := stage.fill_callback) is not None: if self._debug_mode: - logger.info(f"Finished callback for {stage.name}") + logger.info(f"Started callback for {stage.name}") fill_callback() + if self._debug_mode: + logger.info(f"Finished callback for {stage.name}") return new_result @@ -1868,6 +1888,9 @@ def progress( self._num_steps += 1 + if self._debug_mode: + logger.info(f"Starting pipeline step {self._num_steps}") + for stage_idx in range(self.num_stages): stage_output_idx = self.num_stages - 1 - stage_idx self._run_stage( @@ -1888,6 +1911,8 @@ def progress( self.flush_end() return self.progress(dataloader_iter) + if self._debug_mode: + logger.info(f"Finished pipeline step {self._num_steps}") return out diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index 2a561f80a..e70308435 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -12,7 +12,7 @@ import copy import itertools import logging -from collections import defaultdict, OrderedDict +from collections import defaultdict, deque, OrderedDict from contextlib import AbstractContextManager from dataclasses import dataclass, field @@ -22,6 +22,7 @@ Any, Callable, cast, + Deque, Dict, Generator, Generic, @@ -38,6 +39,7 @@ import torch from torch import distributed as dist +from torch.utils.hooks import RemovableHandle from torchrec.distributed.types import LazyAwaitable if not torch._running_with_deploy(): @@ -170,12 +172,17 @@ class PipelineStage: stream (torch.cuda.streams.Stream): Stream to run on. Often each stage has a unique stream, but having different pipelines share a stream provides more synchronization semantics. + fill_callback (Optional[Callable[[], None]])) - optional step to run after the main + runnable during filling the pipeline + data_exhausted_callback (Optional[Callable[[], None]])) - optional callback to run + when data is ehxausted """ name: str runnable: RunnableType stream: torch.Stream fill_callback: Optional[Callable[[], None]] = None + data_exhausted_callback: Optional[Callable[[], None]] = None class BaseArgInfoStep(abc.ABC): @@ -1802,7 +1809,7 @@ def _prefetch_embeddings( assert forward._name in context.input_dist_tensors_requests request = context.input_dist_tensors_requests.pop(forward._name) assert isinstance(request, Awaitable) - with record_function("## wait_sparse_data_dist ##"): + with record_function(f"## _prefetch_embeddings {context.index} ##"): # Finish waiting on the dist_stream, # in case some delayed stream scheduling happens during the wait() call. with stream_context(data_dist_stream): @@ -1890,7 +1897,7 @@ class SparseDataDistUtil(Generic[In]): name="start_sparse_data_dist", runnable=sdd.start_sparse_data_dist, stream=sdd.data_dist_stream, - fill_callback=sdd.wait_sparse_data_dist, + fill_callback=sdd.wait_sdd_fill_callback, ), PipelineStage( name="prefetch", @@ -1903,6 +1910,11 @@ class SparseDataDistUtil(Generic[In]): return StagedTrainPipeline(pipeline_stages=pipeline) """ + _TRAIN_CONTEXT_VERSION = 1 + # Convenience flag to perform additional assertions on contexts + # to make sure contexts are advancing correctly. 
+ _WITH_CONTEXT_ASSERTIONS = False + def __init__( self, model: torch.nn.Module, @@ -1913,20 +1925,24 @@ def __init__( super().__init__() self.model = model self.data_dist_stream = data_dist_stream - self.prefetch_stream = prefetch_stream self.apply_jit = apply_jit - self.context = ( - PrefetchTrainPipelineContext(version=0) - if prefetch_stream - else TrainPipelineContext(version=0) - ) + self.prefetch_stream = prefetch_stream + self._next_index: int = 0 + self._contexts: Deque[TrainPipelineContext] = deque() self.initialized = False self._pipelined_modules: List[ShardedModule] = [] self._pipelined_postprocs: List[PipelinedPostproc] = [] - # pyre-ignore - self.fwd_hook = None + self.fwd_hook: Optional[RemovableHandle] = None self._device: torch.device = data_dist_stream.device + self._stream_context: Callable[ + [Optional[torch.Stream]], torch.cuda.StreamContext + ] = ( + torch.get_device_module(self._device).stream + if self._device.type in ["cuda", "mtia"] + else torch.cuda.stream + ) + # pyre-ignore self._original_forwards: List[Callable[..., Any]] = [] self._original_kjt_dist_forwards: List[ @@ -1935,7 +1951,7 @@ def __init__( self._pipelined_forward: Type[BaseForward[TrainPipelineContext]] = cast( Type[BaseForward[TrainPipelineContext]], - (PrefetchPipelinedForward if prefetch_stream else PipelinedForward), + (PrefetchPipelinedForward if self._with_prefetch else PipelinedForward), ) self._default_stream: Optional[torch.Stream] = ( @@ -1943,6 +1959,193 @@ def __init__( if self._device.type in ["cuda", "mtia"] else None ) + # When the data iterator is exhausted, contexts should continue advancing until + # reaching the end (i.e. no longer called from the StagedTrainPipeline), + # however normal invariants no longer apply (e.g. module_contexts might be empty + # before the prefetch stage). Currently, all actions (`prefetch`, `start/wait_sparse_data_dist`) + # tolerate lack of data from the previous stage - so the context assertions are mostly + # correctness invariants. However, if that changes, having invariants monitored/enforced + # during the exhaustion phase might become necessary. + self._exhausting_mode = False + + @property + def _with_prefetch(self) -> bool: + return self.prefetch_stream is not None + + def _is_reattaching(self) -> bool: + return len(self._contexts) > 0 + + def should_assert_context_invariants(self, ctx: TrainPipelineContext) -> bool: + return ( + self._WITH_CONTEXT_ASSERTIONS + and self.initialized + and not self._exhausting_mode + and ( + ctx.index is not None and ctx.index >= 0 + ) # "fake contexts" to support pipeline initialization + ) + + # === Debugging helpers === # + @property + def _have_pipelined_modules(self) -> bool: + return len(self._pipelined_modules) > 0 + + @property + def _have_pipelined_postprocs(self) -> bool: + return len(self._pipelined_postprocs) > 0 + + def _pipelined_modules_fqns(self) -> Set[str]: + return {module.forward._name for module in self._pipelined_modules} + + def _pipelined_postprocs_fqns(self) -> Set[str]: + return {module._fqn for module in self._pipelined_postprocs} + + # === Debugging helpers === # + + # ==== Context management === # + # In short: version=1 contexts essentially represent "passing of time" + # and have a one-to-one correspondence to batches. "Monolithic" torchrec pipelines + # (e.g.
TrainPipelineSparseDist) explicitly manage batches and contexts together + # (see TrainPipelineSparseDist.enqueue_batch), however StagedTrainPipeline abstracts + # that away and supports stages that don't require contexts (in fact, SDD is the only one that does). + # So we just manage contexts and batches together in lockstep - via _advance_context calls. + # + # Essentially, StagedTrainPipeline during a single `progress` call runs each stage + # for a different batch, keeping the stage outputs in a `_stage_outputs` list, and + # advancing the list at the beginning of the `progress`. + # The tricky part is that SparseDataDistUtil might be participating in TWO stages: + # * "main" with start_data_dist -> wait_data_dist pair for `runnable` and `fill_callback` + # * "prefetch" with prefetch -> load_prefetch for `runnable` and `fill_callback` + # + # For this to work, we: + # (1) need to manage contexts in lockstep with batches advancing through stages (_advance_context) + # (2) perform various actions (start dist, wait dist, etc.) against the correct contexts + # ("named" contexts below and how they are used in start/wait sparse_dist, prefetch, etc.) + # (3) set contexts for the _pipelined_modules and _pipelined_postprocs to the "current batch context" + # for the model to run correctly (_set_module_context) + # + # SDD Util uses two or three contexts, depending on whether prefetch is present + # * context[0] is always the "current batch" context - used for model forward (outside this class) + # * context[1] is used for prefetch if prefetch is present, and for start/wait_sparse_data_dist if not + # * context[2] is used for start/wait_sparse_data_dist if prefetch is present + + def _create_context(self, index: int) -> TrainPipelineContext: + version = self._TRAIN_CONTEXT_VERSION + return ( + PrefetchTrainPipelineContext(index=index, version=version) + if self._with_prefetch + else TrainPipelineContext(index=index, version=version) + ) + + def _add_context(self) -> None: + self._contexts.append(self._create_context(self._next_index)) + self._next_index += 1 + + def _advance_context(self) -> None: + self._assert_contexts_count() + self._contexts.popleft() + self._add_context() + self._set_module_context(self._context_for_model_forward()) + + def _set_module_context(self, context: TrainPipelineContext) -> None: + for module in self._pipelined_modules: + module.forward.set_context(context) + + for postproc_module in self._pipelined_postprocs: + # This ensures that next iter model fwd uses cached results + postproc_module.set_context(context) + + def _assert_contexts_count(self) -> None: + if not self._WITH_CONTEXT_ASSERTIONS: + return + contexts_len = len(self._contexts) + expected = 3 if self._with_prefetch else 2 + assert ( + contexts_len == expected + ), f"Expected to have {expected} contexts, but had {contexts_len}" + + # ====== "Named" contexts - to make it clearer which contexts are used for which operation ====== # + # These are purely convenience methods; feel free to remove them if they get in the way + def _current_context(self) -> TrainPipelineContext: + return self._contexts[0] + + def _assert_input_dist_tensors( + self, context: TrainPipelineContext, expected_fqns: Set[str] + ) -> None: + specified_keys = context.input_dist_tensors_requests.keys() + assert ( + specified_keys == expected_fqns + ), f"Context(idx:{context.index}).input_dist_tensors_requests {specified_keys} != pipelined modules fqns {expected_fqns}" + + def _assert_module_contexts( + self, context: TrainPipelineContext, expected_fqns: Set[str] + ) -> None: +
specified_keys = context.module_contexts.keys() + assert ( + specified_keys == expected_fqns + ), f"Context(idx:{context.index}).module_contexts {specified_keys} != pipelined modules fqns {expected_fqns}" + + def _assert_module_contexts_post_prefetch( + self, context: PrefetchTrainPipelineContext, expected_fqns: Set[str] + ) -> None: + specified_keys = context.module_contexts_post_prefetch.keys() + assert ( + specified_keys == expected_fqns + ), f"Context(idx:{context.index}).module_contexts_post_prefetch {specified_keys} != pipelined modules fqns {expected_fqns}" + + def _assert_module_input_post_prefetch( + self, context: PrefetchTrainPipelineContext, expected_fqns: Set[str] + ) -> None: + specified_keys = context.module_input_post_prefetch.keys() + assert ( + specified_keys == expected_fqns + ), f"Context(idx:{context.index}).module_input_post_prefetch {specified_keys} != pipelined modules fqns {expected_fqns}" + + def _context_for_model_forward(self) -> TrainPipelineContext: + ctx = self._current_context() + if self.should_assert_context_invariants(ctx): + target_fqns = self._pipelined_modules_fqns() + if self._with_prefetch: + assert isinstance(ctx, PrefetchTrainPipelineContext) + self._assert_module_input_post_prefetch(ctx, target_fqns) + self._assert_module_contexts_post_prefetch(ctx, target_fqns) + else: + self._assert_input_dist_tensors(ctx, target_fqns) + self._assert_module_contexts(ctx, target_fqns) + return ctx + + def _start_dist_context(self) -> TrainPipelineContext: + if self._with_prefetch: + ctx = self._contexts[2] + else: + ctx = self._contexts[1] + + return ctx + + def _wait_dist_context(self) -> TrainPipelineContext: + # Note: see comment on the forward_hook in _initialize method + ctx = self._start_dist_context() + if self.should_assert_context_invariants(ctx): + if self._have_pipelined_modules: + assert ( + len(ctx.fused_splits_awaitables) > 0 + ), f"fused_splits_awaitables was empty on {ctx.index=} - was start_sparse_data_dist called?" + return ctx + + def _prefetch_context(self) -> PrefetchTrainPipelineContext: + ctx = self._contexts[1] + assert isinstance( + ctx, PrefetchTrainPipelineContext + ), "Pass prefetch_stream into SparseDataDistUtil to use prefetch_context()" + if self.should_assert_context_invariants(ctx): + target_fqns = self._pipelined_modules_fqns() + self._assert_input_dist_tensors(ctx, target_fqns) + self._assert_module_contexts(ctx, target_fqns) + return ctx + + # ====== End "Named" contexts ====== # + + # === End context management === # def detach(self) -> torch.nn.Module: """ @@ -1951,7 +2154,8 @@ def detach(self) -> torch.nn.Module: detach() can be called at any point, and inflight batches do not need to be flushed before calling it. Calling pipeline.progress() will re-attach the model - to the pipeline and the pipeline will progress normally from the point it was detached (i.e. inflight batches will be kept when calling detach). + to the pipeline and the pipeline will progress normally from the point it was + detached (i.e. inflight batches will be kept when calling detach). 
While the model is detached, it is equivalent to the model before passing to the pipeline, so forward and backward passes, and optimizer updates can be @@ -1972,112 +2176,150 @@ def detach(self) -> torch.nn.Module: self.initialized = False return self.model - def start_sparse_data_dist(self, batch: In) -> In: - if not self.initialized: - # Step 1: Pipeline input dist in trec sharded modules - # TODO (yhshin): support postproc modules for `StagedTrainPipeline` - ( - self._pipelined_modules, - self.model, - self._original_forwards, - self._pipelined_postprocs, - _, - ) = _rewrite_model( - model=self.model, - context=self.context, - dist_stream=self.data_dist_stream, - batch=batch, - apply_jit=self.apply_jit, - pipelined_forward=self._pipelined_forward, - default_stream=self._default_stream, - ) - # initializes input dist, so we can override input dist forwards - _start_data_dist(self._pipelined_modules, batch, self.context) - self._original_kjt_dist_forwards = _override_input_dist_forwards( - self._pipelined_modules - ) + def _initialize_or_reattach(self, batch: In) -> None: + # Step 0: Handle differences between initialization and reattaching + if self._is_reattaching(): + # if reattaching, contexts are already there, so we want to use + # the current context for model forward - as if continuing to run normally + context_for_rewrite = self._current_context() + else: + # if initializing, no contexts are present, so we add them: + if self._with_prefetch: + self._contexts.append(self._create_context(-2)) # throwaway context + self._contexts.append(self._create_context(-1)) # throwaway context + self._add_context() # actual context to be used for everything in the initial iteration + context_for_rewrite = self._contexts[-1] - # Step 2: Register post-forward hook to wait SDD - def forward_hook( - module: torch.nn.Module, - input: Union[torch.Tensor, Tuple[torch.Tensor]], - output: Union[torch.Tensor, Tuple[torch.Tensor]], - ) -> None: - if self.prefetch_stream is not None: - # Need to load prefetch before wait_sparse_data_dist - self.load_prefetch() - self.wait_sparse_data_dist() + self._assert_contexts_count() - self.fwd_hook = self.model.register_forward_hook(forward_hook) + # Step 1: Pipeline input dist in trec sharded modules + ( + self._pipelined_modules, + self.model, + self._original_forwards, + self._pipelined_postprocs, + _, + ) = _rewrite_model( + model=self.model, + context=context_for_rewrite, + dist_stream=self.data_dist_stream, + batch=batch, + apply_jit=self.apply_jit, + pipelined_forward=self._pipelined_forward, + default_stream=self._default_stream, + ) + # Setting the stage for the first batch + # initialize input dist + _start_data_dist(self._pipelined_modules, batch, self._start_dist_context()) + # so we can override input dist forwards + self._original_kjt_dist_forwards = _override_input_dist_forwards( + self._pipelined_modules + ) + + # Step 2: Register post-forward hook to wait SDD and advance contexts + def forward_hook( + module: torch.nn.Module, + input: Union[torch.Tensor, Tuple[torch.Tensor]], + output: Union[torch.Tensor, Tuple[torch.Tensor]], + ) -> None: + # Note: tricky part - a somewhat delicate choreography between + # StagedTrainPipeline and this class + # (see https://github.com/pytorch/torchrec/pull/2239 for details) + # wait_dist needs to be called as a post-forward hook + # at the end of batch N, so that the data is awaited + # before the start of the next batch.
+ self.wait_sparse_data_dist() + # _advance_context should be called after wait_sparse_data_dist, + # but before start_data_dist for the next batch + # which means right here, and nowhere else + self._advance_context() + # ... this can be made more explicit by adding dedicated hooks for "batch start"/"batch end" events + # to the StagedTrainPipeline, PipelineStage and this class, but the hook seems to be doing an adequate job for now + + self.fwd_hook = self.model.register_forward_hook(forward_hook) + + self.initialized = True + + def wait_sdd_fill_callback(self) -> None: + """ + Used by StagedTrainPipeline only during initial pipeline filling. + + At that point, model.forward is not executed, so the forward hook is not called. + """ + self.wait_sparse_data_dist() + self._advance_context() + + def data_exhausted_callback(self) -> None: + """ + Called by StagedTrainPipeline when all batches have been processed. + """ + self._exhausting_mode = True - self.initialized = True + def start_sparse_data_dist(self, batch: In) -> In: + if not self.initialized: + self._initialize_or_reattach(batch) - _start_data_dist(self._pipelined_modules, batch, self.context) + ctx = self._start_dist_context() + with record_function(f"## start_sparse_data_dist {ctx.index} ##"): + with use_context_for_postprocs(self._pipelined_postprocs, ctx): + _start_data_dist(self._pipelined_modules, batch, ctx) return batch def wait_sparse_data_dist(self) -> None: - with record_function("## wait_sparse_data_dist ##"): - with torch.get_device_module(self._device).stream(self.data_dist_stream): - self.context.module_contexts = ( - self.context.module_contexts_next_batch.copy() - ) - self.context.input_dist_tensors_requests.clear() - for names, awaitable in self.context.fused_splits_awaitables: + """ + Waits on the input dist splits requests to get the input dist tensors requests, + and populates the context with them. + """ + ctx = self._wait_dist_context() + with record_function(f"## wait_sparse_data_dist {ctx.index} ##"): + with self._stream_context(self.data_dist_stream): + for names, awaitable in ctx.fused_splits_awaitables: for name, request in zip(names, awaitable.wait()): - self.context.input_dist_tensors_requests[name] = request + ctx.input_dist_tensors_requests[name] = request + # these won't be used by the rest of the pipeline, so just delete them to free + # the memory they occupy + ctx.input_dist_splits_requests.clear() + ctx.fused_splits_awaitables.clear() def prefetch(self, batch: In) -> In: """ Waits for input dist to finish, then prefetches data.
""" assert isinstance( - self.context, PrefetchTrainPipelineContext + self._prefetch_context(), PrefetchTrainPipelineContext ), "Pass prefetch_stream into SparseDataDistUtil to use prefetch() as a stage" - self.context.module_input_post_prefetch_next_batch.clear() - # pyre-ignore - self.context.module_contexts_post_prefetch_next_batch.clear() + ctx: PrefetchTrainPipelineContext = self._prefetch_context() - data_per_pipelined_module = _prefetch_embeddings( - batch, - # pyre-ignore - self.context, - self._pipelined_modules, - self._device, - torch.get_device_module(self._device).stream, - self.data_dist_stream, - self._default_stream, - ) - for sharded_module in self._pipelined_modules: - forward = sharded_module.forward - data = data_per_pipelined_module[forward._name] - # pyre-ignore [16] - self.context.module_input_post_prefetch_next_batch[forward._name] = data - self.context.module_contexts_post_prefetch_next_batch[forward._name] = ( - self.context.module_contexts.pop(forward._name) + with self._stream_context(self.prefetch_stream): + data_per_pipelined_module = _prefetch_embeddings( + batch, + ctx, + self._pipelined_modules, + self._device, + self._stream_context, + self.data_dist_stream, + self._default_stream, ) + # TODO (eugenykolpakov): investigate if these can be moved outside of the `with stream_context(...)` block + # This might impact memory fragmentation (since CUDA caching allocator is stream-aware), + # so need to check how memory behaves with different streams + for sharded_module in self._pipelined_modules: + forward = sharded_module.forward + data = data_per_pipelined_module[forward._name] + ctx.module_input_post_prefetch[forward._name] = data + ctx.module_contexts_post_prefetch[forward._name] = ( + ctx.module_contexts.pop(forward._name) + ) return batch def load_prefetch(self) -> None: - assert isinstance( - self.context, PrefetchTrainPipelineContext - ), "Pass prefetch_stream into SparseDataDistUtil to use load_prefetch()" - self.context.module_input_post_prefetch.clear() - # pyre-ignore - self.context.module_contexts_post_prefetch.clear() - - with record_function("## load_sharded_module_prefetch ##"): - with torch.get_device_module(self._device).stream(self.prefetch_stream): - for sharded_module in self._pipelined_modules: - forward = sharded_module.forward - assert isinstance(forward, PrefetchPipelinedForward) - self.context.module_input_post_prefetch[forward._name] = ( - self.context.module_input_post_prefetch_next_batch[ - forward._name - ] - ) - self.context.module_contexts_post_prefetch[forward._name] = ( - self.context.module_contexts_post_prefetch_next_batch[ - forward._name - ] - ) + """ + DEPRECATED: exists for backward compatibility + """ + # Version=0 did + # module_input_post_prefetch = module_input_post_prefetch_for_next_batch + # module_contexts_post_prefetch = module_contexts_post_prefetch_for_next_batch + # with version=1, there's nothing to do - they are managed at a context level, + # so this is essentially done by _advance_context + prefetch above + pass