
Commit 9c67b0c

kddnewton authored and facebook-github-bot committed
Allow a context manager to be called around apply_jit (#2927)
Summary: When running torch.jit.script on the various forward functions, you can run into issues if there are any other utilities interacting with the function definitions. For example, if you have another JIT running, you need to disable it throughout this process. This commit adds the ability to pass an apply_jit_context context manager wherever apply_jit is currently accepted; that context manager is entered around the application of the torch JIT.

Differential Revision: D73781040
1 parent f03bd79 commit 9c67b0c
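
For illustration, here is a minimal sketch of the kind of context manager this change enables, assuming some other JIT or tracing utility has to be paused while torch.jit.script runs. The _other_jit_enabled flag is a hypothetical stand-in for that utility and is not part of this commit:

from contextlib import contextmanager
from typing import Generator

# Hypothetical stand-in for the "other JIT" mentioned in the summary; in practice
# this would be whatever utility interacts with the forward function definitions.
_other_jit_enabled = True


@contextmanager
def pause_other_jit() -> Generator[None, None, None]:
    global _other_jit_enabled
    previous, _other_jit_enabled = _other_jit_enabled, False
    try:
        yield  # torch.jit.script is applied while the other JIT is disabled
    finally:
        _other_jit_enabled = previous

An instance of this context manager can then be passed as apply_jit_context to any of the pipelines changed below.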

File tree: 3 files changed, +66 -7 lines


torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py

+25-1
@@ -10,7 +10,8 @@
 import copy
 import enum
 import unittest
-from typing import List
+from contextlib import contextmanager
+from typing import Generator, List
 from unittest.mock import MagicMock
 
 import torch
@@ -42,6 +43,29 @@ class ModelType(enum.Enum):
 
 
 class TrainPipelineUtilsTest(TrainPipelineSparseDistTestBase):
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_rewrite_model_apply_jit(self) -> None:
+        @contextmanager
+        def apply_jit_context(events: list[str]) -> Generator[None, None, None]:
+            events.append("__enter__")
+            yield
+            events.append("__exit__")
+
+        events = []
+        _rewrite_model(
+            model=self._setup_model(),
+            context=TrainPipelineContext(),
+            dist_stream=None,
+            apply_jit=True,
+            apply_jit_context=apply_jit_context(events),
+        )
+
+        self.assertEqual(events, ["__enter__", "__exit__"])
+
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @unittest.skipIf(
         not torch.cuda.is_available(),

torchrec/distributed/train_pipeline/train_pipelines.py

+30-1
@@ -11,6 +11,7 @@
 import contextlib
 import logging
 from collections import deque
+from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import (
     Any,
@@ -319,6 +320,9 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         return output
 
 
+_apply_jit_context_default: ContextManager[None] = nullcontext()
+
+
 class TrainPipelineSparseDist(TrainPipeline[In, Out]):
     """
     This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with
@@ -344,6 +348,8 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
         execute_all_batches (bool): executes remaining batches in pipeline after
             exhausting dataloader iterator.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -362,12 +368,14 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
         self._device = device
         self._execute_all_batches = execute_all_batches
         self._apply_jit = apply_jit
+        self._apply_jit_context = apply_jit_context
 
         if device.type == "cuda":
             # use two data streams to support two concurrent batches
@@ -643,6 +651,7 @@ def _pipeline_model(
             apply_jit=self._apply_jit,
             pipelined_forward=pipelined_forward,
             pipeline_postproc=self._pipeline_postproc,
+            apply_jit_context=self._apply_jit_context,
         )
         # initializes input dist, so we can override input dist forwards
         self.start_sparse_data_dist(batch, context)
@@ -993,6 +1002,8 @@ class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
         start_batch (int): batch to begin semi-sync training. Typically small period of synchronous training reduces early stage NEX.
         stash_gradients (bool): if True, will store gradients for each parameter to insure true "Semi-Sync"
            training. If False, will update dense optimizer as soon as gradients available (naive "Semi-Sync)
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -1012,6 +1023,7 @@ def __init__(
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
         strict: bool = False,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
     ) -> None:
         super().__init__(
             model=model,
@@ -1022,6 +1034,7 @@
             context_type=EmbeddingTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
         self._start_batch = start_batch
         self._stash_gradients = stash_gradients
@@ -1305,6 +1318,8 @@ class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         execute_all_batches (bool): executes remaining batches in pipeline after
             exhausting dataloader iterator.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -1321,6 +1336,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
     ) -> None:
         super().__init__(
             model=model,
@@ -1331,6 +1347,7 @@
             context_type=PrefetchTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
         self._context = PrefetchTrainPipelineContext(version=0)
         self._prefetch_stream: Optional[torch.Stream] = (
@@ -1462,6 +1479,8 @@ class EvalPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         device (torch.device): device where device transfer, sparse data dist, and
             forward/backward pass will happen.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -1473,8 +1492,16 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         device: torch.device,
         apply_jit: bool = False,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
     ) -> None:
-        super().__init__(model, optimizer, device, True, apply_jit)
+        super().__init__(
+            model,
+            optimizer,
+            device,
+            True,
+            apply_jit,
+            apply_jit_context=apply_jit_context,
+        )
         self._batch_loader: Optional[DataLoadingThread[In]] = None
 
     def __del__(self) -> None:
@@ -1836,6 +1863,7 @@ def __init__(
         custom_model_fwd: Optional[
            Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
        ] = None,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
     ) -> None:
         super().__init__(
             model,
@@ -1846,6 +1874,7 @@
             context_type,
             pipeline_postproc,
             custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
 
         torch._logging.set_logs(compiled_autograd_verbose=True)
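
The constructors above all thread the new keyword through to _rewrite_model unchanged. A hedged usage sketch, assuming an already-sharded model, its optimizer, and the pause_other_jit helper sketched under the commit message (none of these objects are defined in this commit):

import torch
from torchrec.distributed.train_pipeline import TrainPipelineSparseDist

pipeline = TrainPipelineSparseDist(
    model=model,          # assumed: an already-sharded model
    optimizer=optimizer,  # assumed: the optimizer driving the dense parameters
    device=torch.device("cuda"),
    apply_jit=True,
    apply_jit_context=pause_other_jit(),  # entered around torch.jit.script
)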

torchrec/distributed/train_pipeline/utils.py

+11-5
@@ -11,7 +11,7 @@
 import itertools
 import logging
 from collections import defaultdict, OrderedDict
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, nullcontext
 from dataclasses import dataclass, field
 
 from itertools import chain
@@ -20,6 +20,7 @@
     Any,
     Callable,
     cast,
+    ContextManager,
     Dict,
     Generic,
     Iterable,
@@ -1480,6 +1481,9 @@ def _pipeline_detach_model(
         setattr(model, postproc_mod.fqn, postproc_mod.postproc_module)
 
 
+_rewrite_model_apply_jit_context_default: ContextManager[None] = nullcontext()
+
+
 # pyre-ignore[3] Return type must be specified as type that does not contain
 def _rewrite_model(  # noqa C901
     model: torch.nn.Module,
@@ -1490,6 +1494,7 @@ def _rewrite_model(  # noqa C901
     pipelined_forward: Type[BaseForward[TrainPipelineContext]] = PipelinedForward,
     pipeline_postproc: bool = False,
     default_stream: Optional[torch.Stream] = None,
+    apply_jit_context: ContextManager[None] = _rewrite_model_apply_jit_context_default,
 ) -> Tuple[
     List[ShardedModule],
     torch.nn.Module,
@@ -1598,10 +1603,11 @@ def _rewrite_model(  # noqa C901
 
     # JIT script unsharded modules if applicable.
     if apply_jit:
-        graph_model = torch.fx.GraphModule(model, graph)
-        _jit_modules(graph_model, "")
-        if isinstance(input_model, DistributedModelParallel):
-            input_model.module = graph_model
+        with apply_jit_context:
+            graph_model = torch.fx.GraphModule(model, graph)
+            _jit_modules(graph_model, "")
+            if isinstance(input_model, DistributedModelParallel):
+                input_model.module = graph_model
 
     if non_pipelined_sharded_modules:
         logger.warn(
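
Since the shared default is nullcontext(), callers that never pass apply_jit_context enter a no-op context and the JIT-scripting block behaves exactly as before. A trivial sketch of that default path, using only the standard library:

from contextlib import nullcontext

# Mirrors _rewrite_model_apply_jit_context_default: entering a nullcontext is a
# no-op, so existing callers of _rewrite_model see no behavior change.
apply_jit_context = nullcontext()

with apply_jit_context:
    # In _rewrite_model, building the FX GraphModule and calling _jit_modules
    # happens inside this block when apply_jit=True.
    pass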
