Skip to content

Add support for pipelined postproc #2940

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number | Diff line number | Diff line change
Expand Up @@ -2155,7 +2155,8 @@ def gpu_postproc(x: StageOut) -> StageOut:
name="start_sparse_data_dist",
runnable=sdd.start_sparse_data_dist,
stream=sdd.data_dist_stream,
fill_callback=sdd.wait_sparse_data_dist,
fill_callback=sdd.fill_callback,
data_exhausted_callback=sdd.data_exhausted_callback,
),
]
pipeline = StagedTrainPipeline(
Expand Down Expand Up @@ -2222,7 +2223,8 @@ def gpu_postproc(x: StageOut) -> StageOut:
name="start_sparse_data_dist",
runnable=sdd.start_sparse_data_dist,
stream=sdd.data_dist_stream,
fill_callback=sdd.wait_sparse_data_dist,
fill_callback=sdd.fill_callback,
data_exhausted_callback=sdd.data_exhausted_callback,
),
]

Expand Down Expand Up @@ -2333,7 +2335,8 @@ def test_model_detach(self) -> None:
name="start_sparse_data_dist",
runnable=sdd.start_sparse_data_dist,
stream=sdd.data_dist_stream,
fill_callback=sdd.wait_sparse_data_dist,
fill_callback=sdd.fill_callback,
data_exhausted_callback=sdd.data_exhausted_callback,
),
]

Expand Down Expand Up @@ -2526,7 +2529,8 @@ def gpu_postproc(x: StageOut) -> StageOut:
name="start_sparse_data_dist",
runnable=sdd.start_sparse_data_dist,
stream=sdd.data_dist_stream,
fill_callback=sdd.wait_sparse_data_dist,
fill_callback=sdd.fill_callback,
data_exhausted_callback=sdd.data_exhausted_callback,
),
PipelineStage(
name="prefetch",
Expand Down
61 changes: 25 additions & 36 deletions torchrec/distributed/train_pipeline/train_pipelines.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -58,6 +58,7 @@
StageOut,
StageOutputWithEvent,
TrainPipelineContext,
use_context_for_postprocs,
)
from torchrec.distributed.types import Awaitable
from torchrec.pt2.checks import is_torchdynamo_compiling
Expand Down Expand Up @@ -792,19 +793,9 @@ def start_sparse_data_dist(
with self._stream_context(self._data_dist_stream):
_wait_for_batch(batch, self._memcpy_stream)

original_contexts = [p.get_context() for p in self._pipelined_postprocs]

# Temporarily set context for next iter to populate cache
for postproc_mod in self._pipelined_postprocs:
postproc_mod.set_context(context)

_start_data_dist(self._pipelined_modules, batch, context)

# Restore context for model fwd
for module, context in zip(
self._pipelined_postprocs, original_contexts
):
module.set_context(context)
with use_context_for_postprocs(self._pipelined_postprocs, context):
_start_data_dist(self._pipelined_modules, batch, context)

def wait_sparse_data_dist(self, context: TrainPipelineContext) -> None:
"""
Expand Down Expand Up @@ -1325,22 +1316,15 @@ def start_sparse_data_dist(
return

# Temporarily set context for next iter to populate cache
original_contexts = [p.get_context() for p in self._pipelined_postprocs]
for postproc_mod in self._pipelined_postprocs:
postproc_mod.set_context(context)

with record_function(f"## start_sparse_data_dist {context.index} ##"):
with self._stream_context(self._data_dist_stream):
_wait_for_events(batch, context, self._data_dist_stream)
model_input = self.extract_model_input_from_batch(batch)
_start_data_dist(self._pipelined_modules, model_input, context)
event = torch.get_device_module(self._device).Event()
event.record()
context.events.append(event)

# Restore context for model forward
for module, context in zip(self._pipelined_postprocs, original_contexts):
module.set_context(context)
with use_context_for_postprocs(self._pipelined_postprocs, context):
with record_function(f"## start_sparse_data_dist {context.index} ##"):
with self._stream_context(self._data_dist_stream):
_wait_for_events(batch, context, self._data_dist_stream)
model_input = self.extract_model_input_from_batch(batch)
_start_data_dist(self._pipelined_modules, model_input, context)
event = torch.get_device_module(self._device).Event()
event.record()
context.events.append(event)

def start_embedding_lookup(
self,
Expand Down Expand Up @@ -1727,8 +1711,6 @@ def _run_with_event(
inputs: Optional[In],
stream: torch.Stream,
) -> StageOutputWithEvent:
if inputs is None:
return (None, None)
with self._stream_context(stream):
# If there is no previous event, data is entering the pipeline
if event is not None:
Expand Down Expand Up @@ -1783,12 +1765,19 @@ def _run_stage(
assert batch_to_wait_with_event is not None
batch_to_wait, event = batch_to_wait_with_event

new_result = self._run_with_event(
runnable=stage.runnable,
event=event,
inputs=batch_to_wait,
stream=stage.stream,
)
if batch_to_wait is not None:
new_result = self._run_with_event(
runnable=stage.runnable,
event=event,
inputs=batch_to_wait,
stream=stage.stream,
)
else:
new_result = (None, None)
if (
data_exhausted_callback := stage.data_exhausted_callback
) is not None:
data_exhausted_callback()

self._stage_outputs[batch_offset] = new_result
if self._debug_mode:
Expand Down
Loading
Loading