diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index aad1c3ea9..e0488080f 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -60,6 +60,7 @@
     _prefetch_embeddings,
+    _prepare_data_dist,
     _rewrite_model,
     _start_data_dist,
     _start_embedding_lookup,
     _to_device,
     _wait_for_batch,
@@ -459,6 +460,12 @@ def __init__(
             not is_torchdynamo_compiling()
         ), "Train Pipelines rely on cuda streams, which is not supported by Dynamo"

+        # Worker thread used to overlap input_dist preparation with the backward pass.
+        self._data_processing_executor = ThreadPoolExecutor(
+            max_workers=1,
+            thread_name_prefix="data_processor_rank",
+        )
+
         # pyre-ignore
         self._stream_context = (
             torch.get_device_module(self._device).stream
@@ -618,7 +625,7 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
         """
-        # pipeline is already filled with max capacity (2)
-        if len(self.batches) >= 2:
+        # pipeline is already filled with max capacity (3)
+        if len(self.batches) >= 3:
             return

         # executes last batch in pipeline, when there is only one batch in the pipeline
@@ -646,6 +653,11 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
                 logger.info("fill_pipeline: failed to load batch i+1")
                 return

+    def _data_processing_worker(self) -> None:
+        # Prepare batch i+2's input_dist; splits fusing is deferred to the main thread.
+        if len(self.batches) >= 3:
+            self.start_sparse_data_dist(self.batches[2], self.contexts[2], async_op=True)
+
     def _wait_for_batch(self) -> None:
         batch_id = self.contexts[0].index if len(self.contexts) > 0 else "?"
         with record_function(f"## wait_for_batch {batch_id} ##"):
@@ -688,10 +700,6 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # the input_dist of batches[0] has be invoked in previous iter. TODO: fact check
         self._wait_for_batch()

-        if len(self.batches) >= 2:
-            # invoke splits all_to_all comms (first part of input_dist)
-            self.start_sparse_data_dist(self.batches[1], self.contexts[1])
-
         if not self._enqueue_batch_after_forward:
             # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here
             self.enqueue_batch(dataloader_iter)
@@ -700,21 +708,33 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         with record_function(f"## forward {self.contexts[0].index} ##"):
             self._state = PipelineState.CALL_FWD
             losses, output = self._model_fwd(self.batches[0])
-
-        if self._enqueue_batch_after_forward:
-            # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
-            # Start this step after the forward of batch i, so that the H2D copy doesn't compete
-            # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
-            self.enqueue_batch(dataloader_iter)
-
         if len(self.batches) >= 2:
-            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
            self.wait_sparse_data_dist(self.contexts[1])

+        # Prepare the input_dist of batch i+2 on the worker thread so that it
+        # overlaps with the backward pass of batch i.
+        async_op = True
+        if async_op:
+            _data_processing_future = self._data_processing_executor.submit(
+                self._data_processing_worker,
+            )
+        else:
+            self._data_processing_worker()
+
         if self._model.training:
             # backward
             self._state = PipelineState.CALL_BWD
             self._backward(losses)

+        if async_op:
+            # Wait for the worker thread to finish preparing batch i+2, then fuse
+            # its input_dist splits requests on the main thread.
+            _data_processing_future.result()
+        if len(self.batches) >= 3:
+            _fuse_input_dist_splits(self.contexts[2])
+
+        # batch i+2: load data and copy to gpu
+        self.enqueue_batch(dataloader_iter)
+
         self.sync_embeddings(
             self._model,
@@ -840,7 +860,7 @@ def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
         return batch

     def start_sparse_data_dist(
-        self, batch: Optional[In], context: TrainPipelineContext
+        self, batch: Optional[In], context: TrainPipelineContext, async_op: bool = False
     ) -> None:
         """
         Waits for batch to finish getting copied to GPU, then starts the input dist.
@@ -853,8 +873,12 @@ def start_sparse_data_dist(

         # Temporarily set context for next iter to populate cache
         with use_context_for_postprocs(self._pipelined_postprocs, context):
-            _start_data_dist(self._pipelined_modules, batch, context)
+            if async_op:
+                # Prepare only; the splits requests are fused later on the main thread.
+                _prepare_data_dist(self._pipelined_modules, batch, context)
+            else:
+                _start_data_dist(self._pipelined_modules, batch, context)

     def wait_sparse_data_dist(self, context: TrainPipelineContext) -> None:
         """
         Waits on the input dist splits requests to get the input dist tensors requests,
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index de030ad46..a1b6b1a4b 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -118,8 +118,10 @@ def _wait_for_events(
         ), f"{type(batch)} must implement Multistreamable interface"
         batch.record_stream(stream)


-def _start_data_dist(
+# Only the computation and blocking-wait parts are moved to the worker thread; the
+# communication itself stays synchronous because Torch's communication is not thread-safe.
+def _prepare_data_dist(
     pipelined_modules: List[ShardedModule],
     batch: Pipelineable,
     context: TrainPipelineContext,
@@ -157,6 +159,14 @@ def _start_data_dist(
             context.input_dist_splits_requests[forward.name] = module.input_dist(
                 module_ctx, *args, **kwargs
             )
+
+
+def _start_data_dist(
+    pipelined_modules: List[ShardedModule],
+    batch: Pipelineable,
+    context: TrainPipelineContext,
+) -> None:
+    _prepare_data_dist(pipelined_modules, batch, context)
     _fuse_input_dist_splits(context)
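
Illustrative sketch (not part of the diff) of the overlap pattern introduced above. The names prepare_next_batch, issue_collective, backward, and train_step are placeholders, not torchrec APIs: only the preparation work is submitted to the single worker thread, and the main thread joins the future before issuing any collective communication.

    from concurrent.futures import ThreadPoolExecutor

    def prepare_next_batch(batch):
        # Stand-in for _prepare_data_dist: compute splits and blocking waits only.
        return [len(x) for x in batch]

    def issue_collective(splits):
        # Stand-in for the fused splits all_to_all, issued from the main thread.
        print("all_to_all over splits:", splits)

    def backward(loss):
        # Stand-in for the current batch's backward pass.
        pass

    executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="data_processor")

    def train_step(current_loss, next_batch):
        # Overlap: prepare batch i+2 on the worker thread while batch i runs backward.
        future = executor.submit(prepare_next_batch, next_batch)
        backward(current_loss)
        splits = future.result()   # join the worker thread
        issue_collective(splits)   # communication stays on the main thread

    train_step(current_loss=0.0, next_batch=[[1, 2], [3]])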