Commit a927a5c

kddnewton authored and facebook-github-bot committed
Allow a context manager to be called around apply_jit (#2927)
Summary: When running torch.jit.script on the various forward functions, you can run into issues if any other utilities are interacting with the function definitions. For example, if another JIT is running, it needs to be disabled throughout this process. This commit adds the ability to pass an apply_jit_context context manager wherever apply_jit is currently accepted; the context manager is entered around the application of the torch JIT.

Reviewed By: SonicField

Differential Revision: D73781040
1 parent 949278c commit a927a5c
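
For illustration, here is a minimal sketch of how the new hook could be wired up. It assumes the usual torchrec.distributed.train_pipeline import path and that a sharded model, its optimizer, and a device already exist; the print calls merely stand in for whatever other JIT utility needs to be paused while torch.jit.script runs.

from contextlib import contextmanager
from typing import Generator

from torchrec.distributed.train_pipeline import TrainPipelineSparseDist


@contextmanager
def pause_other_jit() -> Generator[None, None, None]:
    # Entered just before torch.jit.script is applied to the
    # non-pipelined (unsharded) modules inside _rewrite_model.
    print("pausing other JIT")
    try:
        yield
    finally:
        # Exited once scripting has finished, even if it raised.
        print("resuming other JIT")


pipeline = TrainPipelineSparseDist(
    model=model,          # assumed: an already-sharded model
    optimizer=optimizer,  # assumed: the model's optimizer
    device=device,        # assumed: e.g. torch.device("cuda")
    apply_jit=True,
    apply_jit_context=pause_other_jit(),
)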

File tree: 3 files changed (+66, -7 lines)

torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py

Lines changed: 25 additions & 1 deletion
@@ -10,7 +10,8 @@
 import copy
 import enum
 import unittest
-from typing import List
+from contextlib import contextmanager
+from typing import Generator, List
 from unittest.mock import MagicMock

 import torch
@@ -43,6 +44,29 @@ class ModelType(enum.Enum):


 class TrainPipelineUtilsTest(TrainPipelineSparseDistTestBase):
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_rewrite_model_apply_jit(self) -> None:
+        @contextmanager
+        def apply_jit_context(events: list[str]) -> Generator[None, None, None]:
+            events.append("__enter__")
+            yield
+            events.append("__exit__")
+
+        events = []
+        _rewrite_model(
+            model=self._setup_model(),
+            context=TrainPipelineContext(),
+            dist_stream=None,
+            apply_jit=True,
+            apply_jit_context=apply_jit_context(events),
+        )
+
+        self.assertEqual(events, ["__enter__", "__exit__"])
+
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @unittest.skipIf(
         not torch.cuda.is_available(),

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 30 additions & 1 deletion
@@ -392,6 +392,8 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
             (applicable to 2D sharding only)
             if set and DMP collection is enabled for 2D sharding,
             sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
+        apply_jit_context (Optional[ContextManager]): a context manager that
+            will surround the application of the JIT
     """

     # The PipelinedForward class that is used in _rewrite_model
@@ -412,13 +414,15 @@ def __init__(
         ] = None,
         dmp_collection_sync_interval_batches: Optional[int] = 1,
         enqueue_batch_after_forward: bool = False,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
         self._device = device
         self._execute_all_batches = execute_all_batches
         self._apply_jit = apply_jit
         self._enqueue_batch_after_forward = enqueue_batch_after_forward
+        self._apply_jit_context = apply_jit_context

         if device.type == "cuda":
             # use two data streams to support two concurrent batches
@@ -716,6 +720,7 @@ def _pipeline_model(
             apply_jit=self._apply_jit,
             pipelined_forward=pipelined_forward,
             pipeline_postproc=self._pipeline_postproc,
+            apply_jit_context=self._apply_jit_context,
         )
         # initializes input dist, so we can override input dist forwards
         self.start_sparse_data_dist(batch, context)
@@ -914,6 +919,8 @@ class TrainPipelineFusedSparseDist(TrainPipelineSparseDist[In, Out]):
         TODO: pipeline_postproc, custom_model_fwd, strict
         use_emb_lookuo_stream (bool): if true invoke the compute_and_output_dist
             (for batch i+1) using a new stream, else re-using the data_dist stream
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """

     # The PipelinedForward class that is used in _rewrite_model
@@ -932,6 +939,7 @@ def __init__(
         ] = None,
         strict: bool = False,
         emb_lookup_stream: str = "data_dist",  # new, current, data_dist (default)
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model=model,
@@ -942,6 +950,7 @@ def __init__(
             context_type=EmbeddingTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
         if emb_lookup_stream == "new":
             self._emb_lookup_stream: Optional[torch.Stream] = (
@@ -1076,6 +1085,8 @@ class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
             (applicable to 2D sharding only)
             if set and DMP collection is enabled for 2D sharding,
             sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """

     # The PipelinedForward class that is used in _rewrite_model
@@ -1096,6 +1107,7 @@ def __init__(
         ] = None,
         strict: bool = False,
         dmp_collection_sync_interval_batches: Optional[int] = 1,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model=model,
@@ -1107,6 +1119,7 @@ def __init__(
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
             dmp_collection_sync_interval_batches=dmp_collection_sync_interval_batches,
+            apply_jit_context=apply_jit_context,
         )
         self._start_batch = start_batch
         self._stash_gradients = stash_gradients
@@ -1395,6 +1408,8 @@ class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         execute_all_batches (bool): executes remaining batches in pipeline after
             exhausting dataloader iterator.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """

     # The PipelinedForward class that is used in _rewrite_model
@@ -1411,6 +1426,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model=model,
@@ -1421,6 +1437,7 @@
             context_type=PrefetchTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
         self._context = PrefetchTrainPipelineContext(version=0)
         self._prefetch_stream: Optional[torch.Stream] = (
@@ -1552,6 +1569,8 @@ class EvalPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         device (torch.device): device where device transfer, sparse data dist, and
             forward/backward pass will happen.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (Optional[ContextManager]): a context manager that
+            will surround the application of the JIT
     """

     # The PipelinedForward class that is used in _rewrite_model
@@ -1563,8 +1582,16 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         device: torch.device,
         apply_jit: bool = False,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
-        super().__init__(model, optimizer, device, True, apply_jit)
+        super().__init__(
+            model,
+            optimizer,
+            device,
+            True,
+            apply_jit,
+            apply_jit_context=apply_jit_context,
+        )
         self._batch_loader: Optional[DataLoadingThread[In]] = None

     def __del__(self) -> None:
@@ -1926,6 +1953,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model,
@@ -1936,6 +1964,7 @@
             context_type,
             pipeline_postproc,
             custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )

         torch._logging.set_logs(compiled_autograd_verbose=True)

torchrec/distributed/train_pipeline/utils.py

Lines changed: 11 additions & 5 deletions
@@ -11,7 +11,7 @@
 import itertools
 import logging
 from collections import defaultdict, OrderedDict
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, nullcontext
 from dataclasses import dataclass, field

 from itertools import chain
@@ -20,6 +20,7 @@
     Any,
     Callable,
     cast,
+    ContextManager,
     Dict,
     Generic,
     Iterable,
@@ -1537,6 +1538,7 @@ def _rewrite_model(  # noqa C901
     pipelined_forward: Type[BaseForward[TrainPipelineContext]] = PipelinedForward,
     pipeline_postproc: bool = False,
     default_stream: Optional[torch.Stream] = None,
+    apply_jit_context: Optional[ContextManager[None]] = None,
 ) -> Tuple[
     List[ShardedModule],
     torch.nn.Module,
@@ -1640,10 +1642,14 @@ def _rewrite_model(  # noqa C901

     # JIT script unsharded modules if applicable.
     if apply_jit:
-        graph_model = torch.fx.GraphModule(model, graph)
-        _jit_modules(graph_model, "")
-        if isinstance(input_model, DistributedModelParallel):
-            input_model.module = graph_model
+        if apply_jit_context is None:
+            apply_jit_context = nullcontext()
+
+        with apply_jit_context:
+            graph_model = torch.fx.GraphModule(model, graph)
+            _jit_modules(graph_model, "")
+            if isinstance(input_model, DistributedModelParallel):
+                input_model.module = graph_model

     if non_pipelined_sharded_modules:
         logger.warn(
