
Commit 19836c1

che-sh authored and facebook-github-bot committed
Add support for pipelined postproc
Summary: Add support for pipelined postprocs to the SparseDataDistUtil. This allows pipelined postprocs in the StagedTrainPipeline.

Differential Revision: D73824601
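For context, a minimal sketch of the setup this unlocks: a postproc stage running ahead of `start_sparse_data_dist` in a StagedTrainPipeline. The second PipelineStage is wired as in the updated tests below; everything else (import paths, the SparseDataDistUtil constructor arguments, and the placeholders `sharded_model`, `gpu_postproc`, `dataloader_iter`) is an assumption for illustration, not code from this commit.

```python
# Hedged sketch, assuming torchrec's StagedTrainPipeline API as exercised by
# the tests below. sharded_model, gpu_postproc, and dataloader_iter are
# placeholders; the SparseDataDistUtil constructor arguments are assumptions.
import torch

from torchrec.distributed.train_pipeline import (
    PipelineStage,
    SparseDataDistUtil,
    StagedTrainPipeline,
)

sdd = SparseDataDistUtil(
    model=sharded_model,
    data_dist_stream=torch.cuda.Stream(),
)

pipeline = StagedTrainPipeline(
    pipeline_stages=[
        # Pipelined postproc: runs on its own stream, one batch ahead.
        PipelineStage(
            name="gpu_postproc",
            runnable=gpu_postproc,  # e.g. def gpu_postproc(x: StageOut) -> StageOut
            stream=torch.cuda.Stream(),
        ),
        # Sparse data dist, wired exactly as in the updated tests below.
        PipelineStage(
            name="start_sparse_data_dist",
            runnable=sdd.start_sparse_data_dist,
            stream=sdd.data_dist_stream,
            fill_callback=sdd.wait_sdd_fill_callback,
            data_exhausted_callback=sdd.data_exhausted_callback,
        ),
    ]
)

out = pipeline.progress(dataloader_iter)  # one call advances every stage by one batch
```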
1 parent da3104a · commit 19836c1

3 files changed: +380 −115 lines


torchrec/distributed/train_pipeline/tests/test_train_pipelines.py

Lines changed: 8 additions & 4 deletions
```diff
@@ -2053,7 +2053,8 @@ def gpu_postproc(x: StageOut) -> StageOut:
                 name="start_sparse_data_dist",
                 runnable=sdd.start_sparse_data_dist,
                 stream=sdd.data_dist_stream,
-                fill_callback=sdd.wait_sparse_data_dist,
+                fill_callback=sdd.wait_sdd_fill_callback,
+                data_exhausted_callback=sdd.data_exhausted_callback,
             ),
         ]
         pipeline = StagedTrainPipeline(
@@ -2120,7 +2121,8 @@ def gpu_postproc(x: StageOut) -> StageOut:
                 name="start_sparse_data_dist",
                 runnable=sdd.start_sparse_data_dist,
                 stream=sdd.data_dist_stream,
-                fill_callback=sdd.wait_sparse_data_dist,
+                fill_callback=sdd.wait_sdd_fill_callback,
+                data_exhausted_callback=sdd.data_exhausted_callback,
             ),
         ]

@@ -2231,7 +2233,8 @@ def test_model_detach(self) -> None:
                 name="start_sparse_data_dist",
                 runnable=sdd.start_sparse_data_dist,
                 stream=sdd.data_dist_stream,
-                fill_callback=sdd.wait_sparse_data_dist,
+                fill_callback=sdd.wait_sdd_fill_callback,
+                data_exhausted_callback=sdd.data_exhausted_callback,
             ),
         ]

@@ -2424,7 +2427,8 @@ def gpu_postproc(x: StageOut) -> StageOut:
                 name="start_sparse_data_dist",
                 runnable=sdd.start_sparse_data_dist,
                 stream=sdd.data_dist_stream,
-                fill_callback=sdd.wait_sparse_data_dist,
+                fill_callback=sdd.wait_sdd_fill_callback,
+                data_exhausted_callback=sdd.data_exhausted_callback,
             ),
             PipelineStage(
                 name="prefetch",
```

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 33 additions & 10 deletions
@@ -1621,8 +1621,6 @@ def _run_with_event(
16211621
inputs: Optional[In],
16221622
stream: torch.Stream,
16231623
) -> StageOutputWithEvent:
1624-
if inputs is None:
1625-
return (None, None)
16261624
with self._stream_context(stream):
16271625
# If there is no previous event, data is entering the pipeline
16281626
if event is not None:
@@ -1666,6 +1664,11 @@ def _run_stage(
16661664
"""
16671665
stage = self._pipeline_stages[stage_idx]
16681666

1667+
if self._debug_mode:
1668+
logger.info(
1669+
f"Running ## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##",
1670+
)
1671+
16691672
with record_function(
16701673
f"## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##"
16711674
):
@@ -1677,23 +1680,38 @@ def _run_stage(
16771680
assert batch_to_wait_with_event is not None
16781681
batch_to_wait, event = batch_to_wait_with_event
16791682

1680-
new_result = self._run_with_event(
1681-
runnable=stage.runnable,
1682-
event=event,
1683-
inputs=batch_to_wait,
1684-
stream=stage.stream,
1685-
)
1683+
if batch_to_wait is not None:
1684+
logger.info(
1685+
f"Executing ## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##",
1686+
)
1687+
new_result = self._run_with_event(
1688+
runnable=stage.runnable,
1689+
event=event,
1690+
inputs=batch_to_wait,
1691+
stream=stage.stream,
1692+
)
1693+
else:
1694+
logger.info(
1695+
f"Skipping due to None ## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##",
1696+
)
1697+
new_result = (None, None)
1698+
if (
1699+
data_exhausted_callback := stage.data_exhausted_callback
1700+
) is not None:
1701+
data_exhausted_callback()
16861702

16871703
self._stage_outputs[batch_offset] = new_result
16881704
if self._debug_mode:
16891705
logger.info(
1690-
f"Running ## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##",
1706+
f"Finshed ## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##",
16911707
)
16921708

16931709
if fill and (fill_callback := stage.fill_callback) is not None:
16941710
if self._debug_mode:
1695-
logger.info(f"Finished callback for {stage.name}")
1711+
logger.info(f"Started callback for {stage.name}")
16961712
fill_callback()
1713+
if self._debug_mode:
1714+
logger.info(f"Finished callback for {stage.name}")
16971715

16981716
return new_result
16991717

@@ -1779,6 +1797,9 @@ def progress(
17791797

17801798
self._num_steps += 1
17811799

1800+
if self._debug_mode:
1801+
logger.info(f"Starting pipeline step {self._num_steps}")
1802+
17821803
for stage_idx in range(self.num_stages):
17831804
stage_output_idx = self.num_stages - 1 - stage_idx
17841805
self._run_stage(
@@ -1799,6 +1820,8 @@ def progress(
17991820
self.flush_end()
18001821
return self.progress(dataloader_iter)
18011822

1823+
if self._debug_mode:
1824+
logger.info(f"Finished pipeline step {self._num_steps}")
18021825
return out
18031826

18041827
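The control-flow change in `_run_stage` reduces to this: a None input no longer short-circuits inside `_run_with_event`; instead the stage itself skips execution and fires its `data_exhausted_callback`. A self-contained toy of that flow follows; `run_stage_toy` is illustrative, not torchrec code.

```python
# Hedged, self-contained toy of the end-of-data flow added above: when a
# stage's input is None, the runnable is skipped and the stage's
# data_exhausted_callback (if any) fires instead.
from typing import Callable, Optional


def run_stage_toy(
    runnable: Callable[[int], int],
    inputs: Optional[int],
    data_exhausted_callback: Optional[Callable[[], None]] = None,
) -> Optional[int]:
    if inputs is not None:
        return runnable(inputs)
    # Input is None: the dataloader ran out upstream. Skip the runnable and
    # let the stage react (e.g. stop waiting on future fill callbacks).
    if data_exhausted_callback is not None:
        data_exhausted_callback()
    return None


# Usage: the last real batch flows through; the None that follows triggers
# the callback exactly once for this stage.
print(run_stage_toy(lambda x: x + 1, 41))  # -> 42
print(run_stage_toy(lambda x: x + 1, None, lambda: print("exhausted")))  # -> None
```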
