Allow a context manager to be called around apply_jit #2927

Open · wants to merge 1 commit into main
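Before the diffs, a minimal caller-side sketch of what the new hook enables: a context manager passed as `apply_jit_context` is entered right before the pipeline applies `torch.jit.script` to the non-pipelined (unsharded) modules and exited once scripting finishes. The `time_jit_scripting` helper and the `build_pipeline` wrapper below are illustrative assumptions, not part of this PR; only the `apply_jit_context` keyword argument is.

```python
import time
from contextlib import contextmanager
from typing import Generator

import torch

from torchrec.distributed.train_pipeline import TrainPipelineSparseDist


@contextmanager
def time_jit_scripting() -> Generator[None, None, None]:
    # Entered just before torch.jit.script is applied to the unsharded
    # modules, exited once scripting has finished.
    start = time.perf_counter()
    try:
        yield
    finally:
        print(f"apply_jit took {time.perf_counter() - start:.3f}s")


def build_pipeline(
    sharded_model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
) -> TrainPipelineSparseDist:
    # `sharded_model` is assumed to already be wrapped in
    # DistributedModelParallel and placed on the target device.
    return TrainPipelineSparseDist(
        model=sharded_model,
        optimizer=optimizer,
        device=torch.device("cuda"),
        apply_jit=True,
        apply_jit_context=time_jit_scripting(),
    )
```

Note that an already-constructed context manager instance is passed (not a factory), which matches how the new unit test below invokes it.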
@@ -10,7 +10,8 @@
import copy
import enum
import unittest
from typing import List
from contextlib import contextmanager
from typing import Generator, List
from unittest.mock import MagicMock

import torch
@@ -43,6 +44,29 @@ class ModelType(enum.Enum):


class TrainPipelineUtilsTest(TrainPipelineSparseDistTestBase):
# pyre-fixme[56]: Pyre was not able to infer the type of argument
@unittest.skipIf(
not torch.cuda.is_available(),
"Not enough GPUs, this test requires at least one GPU",
)
def test_rewrite_model_apply_jit(self) -> None:
@contextmanager
def apply_jit_context(events: list[str]) -> Generator[None, None, None]:
events.append("__enter__")
yield
events.append("__exit__")

events = []
_rewrite_model(
model=self._setup_model(),
context=TrainPipelineContext(),
dist_stream=None,
apply_jit=True,
apply_jit_context=apply_jit_context(events),
)

self.assertEqual(events, ["__enter__", "__exit__"])

# pyre-fixme[56]: Pyre was not able to infer the type of argument
@unittest.skipIf(
not torch.cuda.is_available(),
31 changes: 30 additions & 1 deletion torchrec/distributed/train_pipeline/train_pipelines.py
@@ -393,6 +393,8 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
(applicable to 2D sharding only)
if set and DMP collection is enabled for 2D sharding,
sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
apply_jit_context (Optional[ContextManager]): a context manager that
will surround the application of the JIT
"""

# The PipelinedForward class that is used in _rewrite_model
@@ -413,13 +415,15 @@ def __init__(
] = None,
dmp_collection_sync_interval_batches: Optional[int] = 1,
enqueue_batch_after_forward: bool = False,
apply_jit_context: Optional[ContextManager[None]] = None,
) -> None:
self._model = model
self._optimizer = optimizer
self._device = device
self._execute_all_batches = execute_all_batches
self._apply_jit = apply_jit
self._enqueue_batch_after_forward = enqueue_batch_after_forward
self._apply_jit_context = apply_jit_context

if device.type == "cuda":
# use two data streams to support two concurrent batches
@@ -716,6 +720,7 @@ def _pipeline_model(
apply_jit=self._apply_jit,
pipelined_forward=pipelined_forward,
pipeline_postproc=self._pipeline_postproc,
apply_jit_context=self._apply_jit_context,
)
# initializes input dist, so we can override input dist forwards
self.start_sparse_data_dist(batch, context)
@@ -904,6 +909,8 @@ class TrainPipelineFusedSparseDist(TrainPipelineSparseDist[In, Out]):
TODO: pipeline_postproc, custom_model_fwd, strict
use_emb_lookup_stream (bool): if true, invoke the compute_and_output_dist
(for batch i+1) using a new stream, else re-using the data_dist stream
apply_jit_context (Optional[ContextManager]): a context manager that will surround the
application of the JIT
"""

# The PipelinedForward class that is used in _rewrite_model
@@ -922,6 +929,7 @@ def __init__(
] = None,
strict: bool = False,
emb_lookup_stream: str = "data_dist", # new, current, data_dist (default)
apply_jit_context: Optional[ContextManager[None]] = None,
) -> None:
super().__init__(
model=model,
@@ -932,6 +940,7 @@ def __init__(
context_type=EmbeddingTrainPipelineContext,
pipeline_postproc=pipeline_postproc,
custom_model_fwd=custom_model_fwd,
apply_jit_context=apply_jit_context,
)
if emb_lookup_stream == "new":
self._emb_lookup_stream: Optional[torch.Stream] = (
@@ -1066,6 +1075,8 @@ class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
(applicable to 2D sharding only)
if set and DMP collection is enabled for 2D sharding,
sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
apply_jit_context (Optional[ContextManager]): a context manager that will surround the
application of the JIT
"""

# The PipelinedForward class that is used in _rewrite_model
@@ -1086,6 +1097,7 @@ def __init__(
] = None,
strict: bool = False,
dmp_collection_sync_interval_batches: Optional[int] = 1,
apply_jit_context: Optional[ContextManager[None]] = None,
) -> None:
super().__init__(
model=model,
@@ -1097,6 +1109,7 @@ def __init__(
pipeline_postproc=pipeline_postproc,
custom_model_fwd=custom_model_fwd,
dmp_collection_sync_interval_batches=dmp_collection_sync_interval_batches,
apply_jit_context=apply_jit_context,
)
self._start_batch = start_batch
self._stash_gradients = stash_gradients
@@ -1378,6 +1391,8 @@ class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
execute_all_batches (bool): executes remaining batches in pipeline after
exhausting dataloader iterator.
apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
apply_jit_context (Optional[ContextManager]): a context manager that will surround the
application of the JIT
"""

# The PipelinedForward class that is used in _rewrite_model
@@ -1394,6 +1409,7 @@ def __init__(
custom_model_fwd: Optional[
Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
] = None,
apply_jit_context: Optional[ContextManager[None]] = None,
) -> None:
super().__init__(
model=model,
@@ -1404,6 +1420,7 @@ def __init__(
context_type=PrefetchTrainPipelineContext,
pipeline_postproc=pipeline_postproc,
custom_model_fwd=custom_model_fwd,
apply_jit_context=apply_jit_context,
)
self._context = PrefetchTrainPipelineContext(version=0)
self._prefetch_stream: Optional[torch.Stream] = (
@@ -1535,6 +1552,8 @@ class EvalPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
device (torch.device): device where device transfer, sparse data dist, and
forward/backward pass will happen.
apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
apply_jit_context (Optional[ContextManager]): a context manager that
will surround the application of the JIT
"""

# The PipelinedForward class that is used in _rewrite_model
@@ -1546,8 +1565,16 @@ def __init__(
optimizer: torch.optim.Optimizer,
device: torch.device,
apply_jit: bool = False,
apply_jit_context: Optional[ContextManager[None]] = None,
) -> None:
super().__init__(model, optimizer, device, True, apply_jit)
super().__init__(
model,
optimizer,
device,
True,
apply_jit,
apply_jit_context=apply_jit_context,
)
self._batch_loader: Optional[DataLoadingThread[In]] = None

def __del__(self) -> None:
@@ -1909,6 +1936,7 @@ def __init__(
custom_model_fwd: Optional[
Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
] = None,
apply_jit_context: Optional[ContextManager[None]] = None,
) -> None:
super().__init__(
model,
@@ -1919,6 +1947,7 @@ def __init__(
context_type,
pipeline_postproc,
custom_model_fwd,
apply_jit_context=apply_jit_context,
)

torch._logging.set_logs(compiled_autograd_verbose=True)
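Every pipeline class in train_pipelines.py now accepts the keyword and forwards it to `TrainPipelineSparseDist.__init__`, which hands it to `_rewrite_model` when the model is pipelined. As one more hedged sketch, here is a hypothetical construction of `EvalPipelineSparseDist` that wraps the scripting step in a profiler range; the `build_eval_pipeline` helper and the placeholder model/optimizer arguments are assumptions, and the import path is assumed to match the other pipelines.

```python
import torch
from torch.profiler import record_function

from torchrec.distributed.train_pipeline import EvalPipelineSparseDist


def build_eval_pipeline(
    sharded_model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
) -> EvalPipelineSparseDist:
    # record_function is a stock PyTorch context manager; outside an
    # active profiler session it is effectively a no-op, so this only
    # annotates traces when profiling is enabled.
    return EvalPipelineSparseDist(
        model=sharded_model,
        optimizer=optimizer,
        device=torch.device("cuda"),
        apply_jit=True,
        apply_jit_context=record_function("apply_jit: torch.jit.script"),
    )
```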
16 changes: 11 additions & 5 deletions torchrec/distributed/train_pipeline/utils.py
@@ -13,7 +13,7 @@
import itertools
import logging
from collections import defaultdict, OrderedDict
from contextlib import AbstractContextManager
from contextlib import AbstractContextManager, nullcontext
from dataclasses import dataclass, field

from itertools import chain
@@ -22,6 +22,7 @@
Any,
Callable,
cast,
ContextManager,
Dict,
Generator,
Generic,
@@ -1540,6 +1541,7 @@ def _rewrite_model( # noqa C901
pipelined_forward: Type[BaseForward[TrainPipelineContext]] = PipelinedForward,
pipeline_postproc: bool = False,
default_stream: Optional[torch.Stream] = None,
apply_jit_context: Optional[ContextManager[None]] = None,
) -> Tuple[
List[ShardedModule],
torch.nn.Module,
@@ -1643,10 +1645,14 @@

# JIT script unsharded modules if applicable.
if apply_jit:
graph_model = torch.fx.GraphModule(model, graph)
_jit_modules(graph_model, "")
if isinstance(input_model, DistributedModelParallel):
input_model.module = graph_model
if apply_jit_context is None:
apply_jit_context = nullcontext()

with apply_jit_context:
graph_model = torch.fx.GraphModule(model, graph)
_jit_modules(graph_model, "")
if isinstance(input_model, DistributedModelParallel):
input_model.module = graph_model

if non_pipelined_sharded_modules:
logger.warn(
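The `nullcontext` fallback added to `_rewrite_model` above is the whole mechanism: when no `apply_jit_context` is supplied, the `with` block degenerates to a no-op, so the scripting path never branches. Below is a self-contained sketch of that pattern with hypothetical names (`run_in_optional_context`, `Recorder`), mirroring the enter/exit assertions in the new unit test.

```python
from contextlib import nullcontext
from typing import Callable, ContextManager, List, Optional, TypeVar

T = TypeVar("T")


def run_in_optional_context(
    fn: Callable[[], T],
    ctx: Optional[ContextManager[None]] = None,
) -> T:
    # Substitute a no-op context when none is provided, the same
    # fallback _rewrite_model now applies to apply_jit_context.
    with ctx if ctx is not None else nullcontext():
        return fn()


if __name__ == "__main__":
    events: List[str] = []

    class Recorder:
        def __enter__(self) -> None:
            events.append("__enter__")

        def __exit__(self, *exc_info: object) -> bool:
            events.append("__exit__")
            return False

    run_in_optional_context(lambda: events.append("body"), Recorder())
    assert events == ["__enter__", "body", "__exit__"]

    # Omitting the context manager still works, thanks to nullcontext().
    run_in_optional_context(lambda: None)
```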