From 0fb9ab20c65be5f53922844d47158627b19fc402 Mon Sep 17 00:00:00 2001
From: zhaozheng <976525070@qq.com>
Date: Thu, 11 Sep 2025 23:31:41 +0800
Subject: [PATCH] train_pipeline: overlap sparse data dist preparation with
 backward in a worker thread

---
 .../train_pipeline/train_pipelines.py        | 46 +++++++++++++-------
 torchrec/distributed/train_pipeline/utils.py | 13 ++++++++++++-
 2 files changed, 43 insertions(+), 16 deletions(-)

diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index aad1c3ea9..2304313db 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -60,6 +60,8 @@
+    _fuse_input_dist_splits,
     _prefetch_embeddings,
+    _prepare_data_dist,
     _rewrite_model,
     _start_data_dist,
     _start_embedding_lookup,
     _to_device,
     _wait_for_batch,
@@ -646,6 +648,12 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
                 logger.info("fill_pipeline: failed to load batch i+1")
                 return
 
+    def _data_processing_worker(self) -> None:
+        # Runs on self._data_processing_executor: prepare batch i+1's input_dist
+        # off the main thread; its comms are fused on the main thread in progress().
+        if len(self.batches) >= 2:
+            self.start_sparse_data_dist(self.batches[1], self.contexts[1], async_op=True)
+
     def _wait_for_batch(self) -> None:
         batch_id = self.contexts[0].index if len(self.contexts) > 0 else "?"
         with record_function(f"## wait_for_batch {batch_id} ##"):
@@ -688,10 +696,6 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # the input_dist of batches[0] has been invoked in previous iter. TODO: fact check
         self._wait_for_batch()
 
-        if len(self.batches) >= 2:
-            # invoke splits all_to_all comms (first part of input_dist)
-            self.start_sparse_data_dist(self.batches[1], self.contexts[1])
-
         if not self._enqueue_batch_after_forward:
             # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here
             self.enqueue_batch(dataloader_iter)
@@ -701,20 +705,29 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         self._state = PipelineState.CALL_FWD
         losses, output = self._model_fwd(self.batches[0])
 
-        if self._enqueue_batch_after_forward:
-            # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
-            # Start this step after the forward of batch i, so that the H2D copy doesn't compete
-            # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
-            self.enqueue_batch(dataloader_iter)
-
-        if len(self.batches) >= 2:
-            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
-            self.wait_sparse_data_dist(self.contexts[1])
+        async_op = True
+        if async_op:
+            # overlap batch i+1's input_dist preparation with the backward pass
+            _data_processing_future = self._data_processing_executor.submit(
+                self._data_processing_worker,
+            )
+        else:
+            self._data_processing_worker()
 
         if self._model.training:
             # backward
             self._state = PipelineState.CALL_BWD
             self._backward(losses)
 
+        if async_op:
+            _data_processing_future.result()
+        if len(self.batches) >= 2:
+            # invoke the fused splits all_to_all comms on the main thread
+            _fuse_input_dist_splits(self.contexts[1])
+
+        if self._enqueue_batch_after_forward:
+            # batch i+2: load data and copy to gpu after the backward of batch i
+            self.enqueue_batch(dataloader_iter)
+
         self.sync_embeddings(
             self._model,
@@ -840,7 +853,7 @@ def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
         return batch
 
     def start_sparse_data_dist(
-        self, batch: Optional[In], context: TrainPipelineContext
+        self, batch: Optional[In], context: TrainPipelineContext, async_op: bool = False
     ) -> None:
         """
         Waits for batch to finish getting copied to GPU, then starts the input dist.
@@ -853,8 +866,11 @@ def start_sparse_data_dist(
 
         # Temporarily set context for next iter to populate cache
         with use_context_for_postprocs(self._pipelined_postprocs, context):
-            _start_data_dist(self._pipelined_modules, batch, context)
+            if async_op:
+                _prepare_data_dist(self._pipelined_modules, batch, context)
+            else:
+                _start_data_dist(self._pipelined_modules, batch, context)
 
     def wait_sparse_data_dist(self, context: TrainPipelineContext) -> None:
         """
         Waits on the input dist splits requests to get the input dist tensors requests,
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index de030ad46..a1b6b1a4b 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -118,8 +118,11 @@ def _wait_for_events(
     ), f"{type(batch)} must implement Multistreamable interface"
     batch.record_stream(stream)
 
 
-def _start_data_dist(
+# Only the computation and blocking-wait parts are moved off the main thread;
+# the collective communication stays synchronous on the main thread, because
+# torch.distributed communication is not thread-safe.
+def _prepare_data_dist(
     pipelined_modules: List[ShardedModule],
     batch: Pipelineable,
     context: TrainPipelineContext,
@@ -157,6 +160,14 @@ def _start_data_dist(
         context.input_dist_splits_requests[forward.name] = module.input_dist(
             module_ctx, *args, **kwargs
         )
+
+
+def _start_data_dist(
+    pipelined_modules: List[ShardedModule],
+    batch: Pipelineable,
+    context: TrainPipelineContext,
+) -> None:
+    _prepare_data_dist(pipelined_modules, batch, context)
     _fuse_input_dist_splits(context)
 
 
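Note (not part of the patch): progress() now submits _data_processing_worker to
self._data_processing_executor, but the patch does not show where that executor
is created. A minimal sketch of the missing initialization, assuming it lives in
the pipeline's __init__ (the attribute name comes from the patch; the placement
and worker count are assumptions):

    from concurrent.futures import ThreadPoolExecutor

    # Hypothetical __init__ addition: a single worker suffices, since only one
    # batch's input_dist preparation is in flight per iteration.
    self._data_processing_executor = ThreadPoolExecutor(max_workers=1)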
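The ordering progress() relies on is: submit the preparation work, run backward,
join the worker, then issue the fused splits comms from the main thread. A
self-contained sketch of that overlap pattern (names are illustrative, not
torchrec APIs):

    from concurrent.futures import ThreadPoolExecutor

    def prepare_data_dist() -> None:
        ...  # computation and blocking waits; safe off the main thread

    def run_backward() -> None:
        ...  # autograd work on the main thread

    def fuse_input_dist_splits() -> None:
        ...  # collective comms; not thread-safe, keep on the main thread

    executor = ThreadPoolExecutor(max_workers=1)
    future = executor.submit(prepare_data_dist)  # overlaps with backward
    run_backward()
    future.result()             # join the worker before issuing comms
    fuse_input_dist_splits()    # collectives stay on the main thread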