Skip to content

Commit a7ec180

Browse files
che-sh authored and facebook-github-bot committed
Add support for pipelined postproc (#2940)
Summary: Pull Request resolved: #2940 Add support for pipelined postprocs to the SparseDataDistUtil - this allows pipelined postprocs in the StagedTrainPipeline. Differential Revision: D73824601
1 parent ddcd2b9 commit a7ec180

File tree

3 files changed

+383
-116
lines changed

3 files changed

+383
-116
lines changed

torchrec/distributed/train_pipeline/tests/test_train_pipelines.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2053,7 +2053,8 @@ def gpu_postproc(x: StageOut) -> StageOut:
20532053
name="start_sparse_data_dist",
20542054
runnable=sdd.start_sparse_data_dist,
20552055
stream=sdd.data_dist_stream,
2056-
fill_callback=sdd.wait_sparse_data_dist,
2056+
fill_callback=sdd.wait_sdd_fill_callback,
2057+
data_exhausted_callback=sdd.data_exhausted_callback,
20572058
),
20582059
]
20592060
pipeline = StagedTrainPipeline(
@@ -2120,7 +2121,8 @@ def gpu_postproc(x: StageOut) -> StageOut:
21202121
name="start_sparse_data_dist",
21212122
runnable=sdd.start_sparse_data_dist,
21222123
stream=sdd.data_dist_stream,
2123-
fill_callback=sdd.wait_sparse_data_dist,
2124+
fill_callback=sdd.wait_sdd_fill_callback,
2125+
data_exhausted_callback=sdd.data_exhausted_callback,
21242126
),
21252127
]
21262128

@@ -2231,7 +2233,8 @@ def test_model_detach(self) -> None:
22312233
name="start_sparse_data_dist",
22322234
runnable=sdd.start_sparse_data_dist,
22332235
stream=sdd.data_dist_stream,
2234-
fill_callback=sdd.wait_sparse_data_dist,
2236+
fill_callback=sdd.wait_sdd_fill_callback,
2237+
data_exhausted_callback=sdd.data_exhausted_callback,
22352238
),
22362239
]
22372240

@@ -2424,7 +2427,8 @@ def gpu_postproc(x: StageOut) -> StageOut:
24242427
name="start_sparse_data_dist",
24252428
runnable=sdd.start_sparse_data_dist,
24262429
stream=sdd.data_dist_stream,
2427-
fill_callback=sdd.wait_sparse_data_dist,
2430+
fill_callback=sdd.wait_sdd_fill_callback,
2431+
data_exhausted_callback=sdd.data_exhausted_callback,
24282432
),
24292433
PipelineStage(
24302434
name="prefetch",

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1624,11 +1624,9 @@ def _run_with_event(
16241624
self,
16251625
runnable: RunnableType,
16261626
event: Optional[torch.Event],
1627-
inputs: Optional[In],
1627+
inputs: In,
16281628
stream: torch.Stream,
16291629
) -> StageOutputWithEvent:
1630-
if inputs is None:
1631-
return (None, None)
16321630
with self._stream_context(stream):
16331631
# If there is no previous event, data is entering the pipeline
16341632
if event is not None:
@@ -1672,6 +1670,11 @@ def _run_stage(
16721670
"""
16731671
stage = self._pipeline_stages[stage_idx]
16741672

1673+
if self._debug_mode:
1674+
logger.info(
1675+
f"Running ## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##",
1676+
)
1677+
16751678
with record_function(
16761679
f"## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##"
16771680
):
@@ -1683,23 +1686,40 @@ def _run_stage(
16831686
assert batch_to_wait_with_event is not None
16841687
batch_to_wait, event = batch_to_wait_with_event
16851688

1686-
new_result = self._run_with_event(
1687-
runnable=stage.runnable,
1688-
event=event,
1689-
inputs=batch_to_wait,
1690-
stream=stage.stream,
1691-
)
1689+
if batch_to_wait is not None:
1690+
if self._debug_mode:
1691+
logger.info(
1692+
f"Executing ## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##",
1693+
)
1694+
new_result = self._run_with_event(
1695+
runnable=stage.runnable,
1696+
event=event,
1697+
inputs=batch_to_wait,
1698+
stream=stage.stream,
1699+
)
1700+
else:
1701+
if self._debug_mode:
1702+
logger.info(
1703+
f"Skipping due to None ## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##",
1704+
)
1705+
new_result = (None, None)
1706+
if (
1707+
data_exhausted_callback := stage.data_exhausted_callback
1708+
) is not None:
1709+
data_exhausted_callback()
16921710

16931711
self._stage_outputs[batch_offset] = new_result
16941712
if self._debug_mode:
16951713
logger.info(
1696-
f"Running ## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##",
1714+
f"Finshed ## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##",
16971715
)
16981716

16991717
if fill and (fill_callback := stage.fill_callback) is not None:
17001718
if self._debug_mode:
1701-
logger.info(f"Finished callback for {stage.name}")
1719+
logger.info(f"Started callback for {stage.name}")
17021720
fill_callback()
1721+
if self._debug_mode:
1722+
logger.info(f"Finished callback for {stage.name}")
17031723

17041724
return new_result
17051725

@@ -1785,6 +1805,9 @@ def progress(
17851805

17861806
self._num_steps += 1
17871807

1808+
if self._debug_mode:
1809+
logger.info(f"Starting pipeline step {self._num_steps}")
1810+
17881811
for stage_idx in range(self.num_stages):
17891812
stage_output_idx = self.num_stages - 1 - stage_idx
17901813
self._run_stage(
@@ -1805,6 +1828,8 @@ def progress(
18051828
self.flush_end()
18061829
return self.progress(dataloader_iter)
18071830

1831+
if self._debug_mode:
1832+
logger.info(f"Finished pipeline step {self._num_steps}")
18081833
return out
18091834

18101835

0 commit comments

Comments
 (0)