diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py
index 358da4d33..08b66d3f5 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py
@@ -10,7 +10,8 @@
 import copy
 import enum
 import unittest
-from typing import List
+from contextlib import contextmanager
+from typing import Generator, List
 from unittest.mock import MagicMock
 
 import torch
@@ -43,6 +44,29 @@ class ModelType(enum.Enum):
 
 
 class TrainPipelineUtilsTest(TrainPipelineSparseDistTestBase):
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_rewrite_model_apply_jit(self) -> None:
+        @contextmanager
+        def apply_jit_context(events: list[str]) -> Generator[None, None, None]:
+            events.append("__enter__")
+            yield
+            events.append("__exit__")
+
+        events = []
+        _rewrite_model(
+            model=self._setup_model(),
+            context=TrainPipelineContext(),
+            dist_stream=None,
+            apply_jit=True,
+            apply_jit_context=apply_jit_context(events),
+        )
+
+        self.assertEqual(events, ["__enter__", "__exit__"])
+
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @unittest.skipIf(
         not torch.cuda.is_available(),
diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index b8be13994..5026f9008 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -393,6 +393,8 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
             (applicable to 2D sharding only)
             if set and DMP collection is enabled for 2D sharding, sync DMPs
             every N batches (default to 1, i.e. every batch, None to disable)
+        apply_jit_context (Optional[ContextManager]): a context manager that
+            will surround the application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -413,6 +415,7 @@ def __init__(
         ] = None,
         dmp_collection_sync_interval_batches: Optional[int] = 1,
         enqueue_batch_after_forward: bool = False,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
@@ -420,6 +423,7 @@ def __init__(
         self._execute_all_batches = execute_all_batches
         self._apply_jit = apply_jit
         self._enqueue_batch_after_forward = enqueue_batch_after_forward
+        self._apply_jit_context = apply_jit_context
 
         if device.type == "cuda":
             # use two data streams to support two concurrent batches
@@ -716,6 +720,7 @@ def _pipeline_model(
             apply_jit=self._apply_jit,
             pipelined_forward=pipelined_forward,
             pipeline_postproc=self._pipeline_postproc,
+            apply_jit_context=self._apply_jit_context,
         )
         # initializes input dist, so we can override input dist forwards
         self.start_sparse_data_dist(batch, context)
@@ -904,6 +909,8 @@ class TrainPipelineFusedSparseDist(TrainPipelineSparseDist[In, Out]):
         TODO: pipeline_postproc, custom_model_fwd, strict
         use_emb_lookuo_stream (bool): if true invoke the compute_and_output_dist
             (for batch i+1) using a new stream, else re-using the data_dist stream
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -922,6 +929,7 @@ def __init__(
         ] = None,
         strict: bool = False,
         emb_lookup_stream: str = "data_dist",  # new, current, data_dist (default)
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model=model,
@@ -932,6 +940,7 @@ def __init__(
             context_type=EmbeddingTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
         if emb_lookup_stream == "new":
             self._emb_lookup_stream: Optional[torch.Stream] = (
@@ -1066,6 +1075,8 @@ class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
             (applicable to 2D sharding only)
             if set and DMP collection is enabled for 2D sharding, sync DMPs
             every N batches (default to 1, i.e. every batch, None to disable)
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -1086,6 +1097,7 @@ def __init__(
         ] = None,
         strict: bool = False,
         dmp_collection_sync_interval_batches: Optional[int] = 1,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model=model,
@@ -1097,6 +1109,7 @@ def __init__(
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
             dmp_collection_sync_interval_batches=dmp_collection_sync_interval_batches,
+            apply_jit_context=apply_jit_context,
         )
         self._start_batch = start_batch
         self._stash_gradients = stash_gradients
@@ -1378,6 +1391,8 @@ class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         execute_all_batches (bool): executes remaining batches in pipeline after
             exhausting dataloader iterator.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -1394,6 +1409,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model=model,
@@ -1404,6 +1420,7 @@ def __init__(
             context_type=PrefetchTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
         self._context = PrefetchTrainPipelineContext(version=0)
         self._prefetch_stream: Optional[torch.Stream] = (
@@ -1535,6 +1552,8 @@ class EvalPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         device (torch.device): device where device transfer, sparse data dist, and
             forward/backward pass will happen.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (Optional[ContextManager]): a context manager that
+            will surround the application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -1546,8 +1565,16 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         device: torch.device,
         apply_jit: bool = False,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
-        super().__init__(model, optimizer, device, True, apply_jit)
+        super().__init__(
+            model,
+            optimizer,
+            device,
+            True,
+            apply_jit,
+            apply_jit_context=apply_jit_context,
+        )
         self._batch_loader: Optional[DataLoadingThread[In]] = None
 
     def __del__(self) -> None:
@@ -1909,6 +1936,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model,
@@ -1919,6 +1947,7 @@ def __init__(
             context_type,
             pipeline_postproc,
             custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
 
         torch._logging.set_logs(compiled_autograd_verbose=True)
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index 2a561f80a..953b8832e 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -13,7 +13,7 @@
 import itertools
 import logging
 from collections import defaultdict, OrderedDict
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, nullcontext
 from dataclasses import dataclass, field
 
 from itertools import chain
@@ -22,6 +22,7 @@
     Any,
     Callable,
     cast,
+    ContextManager,
     Dict,
     Generator,
     Generic,
@@ -1540,6 +1541,7 @@ def _rewrite_model(  # noqa C901
     pipelined_forward: Type[BaseForward[TrainPipelineContext]] = PipelinedForward,
     pipeline_postproc: bool = False,
     default_stream: Optional[torch.Stream] = None,
+    apply_jit_context: Optional[ContextManager[None]] = None,
 ) -> Tuple[
     List[ShardedModule],
     torch.nn.Module,
@@ -1643,10 +1645,14 @@ def _rewrite_model(  # noqa C901
 
     # JIT script unsharded modules if applicable.
     if apply_jit:
-        graph_model = torch.fx.GraphModule(model, graph)
-        _jit_modules(graph_model, "")
-        if isinstance(input_model, DistributedModelParallel):
-            input_model.module = graph_model
+        if apply_jit_context is None:
+            apply_jit_context = nullcontext()
+
+        with apply_jit_context:
+            graph_model = torch.fx.GraphModule(model, graph)
+            _jit_modules(graph_model, "")
+            if isinstance(input_model, DistributedModelParallel):
+                input_model.module = graph_model
 
     if non_pipelined_sharded_modules:
         logger.warn(
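
Usage note (not part of the patch): the sketch below shows how a caller might supply the new apply_jit_context argument, for example to log and time the JIT scripting of the non-pipelined modules. Per the change to _rewrite_model, the context manager is entered once around the torch.jit.script application when the pipeline rewrites the model, and a nullcontext is used when the argument is omitted. Here model, optimizer, and device are placeholders for an existing sharded model, its optimizer, and the target device, and log_jit_application is a hypothetical helper.

    # Hypothetical usage sketch: wrap JIT application in a logging/timing
    # context via the new apply_jit_context argument. `model`, `optimizer`,
    # and `device` are placeholders for an existing sharded model, its
    # optimizer, and the target device.
    import logging
    import time
    from contextlib import contextmanager
    from typing import Generator

    from torchrec.distributed.train_pipeline import TrainPipelineSparseDist

    logger = logging.getLogger(__name__)

    @contextmanager
    def log_jit_application() -> Generator[None, None, None]:
        # Entered right before torch.jit.script is applied to the unsharded
        # modules inside _rewrite_model, and exited right after.
        start = time.perf_counter()
        logger.info("Applying JIT to unsharded modules")
        try:
            yield
        finally:
            logger.info("JIT applied in %.2fs", time.perf_counter() - start)

    pipeline = TrainPipelineSparseDist(
        model=model,
        optimizer=optimizer,
        device=device,
        apply_jit=True,
        apply_jit_context=log_jit_application(),
    )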