 import contextlib
 import logging
 from collections import deque
+from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import (
     Any,
@@ -319,6 +320,9 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         return output


+_apply_jit_context_default: ContextManager[None] = nullcontext()
+
+
 class TrainPipelineSparseDist(TrainPipeline[In, Out]):
     """
     This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with
@@ -344,6 +348,8 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
         execute_all_batches (bool): executes remaining batches in pipeline after
             exhausting dataloader iterator.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """

     # The PipelinedForward class that is used in _rewrite_model
@@ -362,12 +368,14 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
         self._device = device
         self._execute_all_batches = execute_all_batches
         self._apply_jit = apply_jit
+        self._apply_jit_context = apply_jit_context

         if device.type == "cuda":
             # use two data streams to support two concurrent batches
@@ -643,6 +651,7 @@ def _pipeline_model(
             apply_jit=self._apply_jit,
             pipelined_forward=pipelined_forward,
             pipeline_postproc=self._pipeline_postproc,
+            apply_jit_context=self._apply_jit_context,
         )
         # initializes input dist, so we can override input dist forwards
         self.start_sparse_data_dist(batch, context)
@@ -993,6 +1002,8 @@ class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
         start_batch (int): batch to begin semi-sync training. Typically small period of synchronous training reduces early stage NEX.
         stash_gradients (bool): if True, will store gradients for each parameter to insure true "Semi-Sync"
             training. If False, will update dense optimizer as soon as gradients available (naive "Semi-Sync)
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """

     # The PipelinedForward class that is used in _rewrite_model
@@ -1012,6 +1023,7 @@ def __init__(
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
         strict: bool = False,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
     ) -> None:
         super().__init__(
             model=model,
@@ -1022,6 +1034,7 @@ def __init__(
             context_type=EmbeddingTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
         self._start_batch = start_batch
         self._stash_gradients = stash_gradients
@@ -1305,6 +1318,8 @@ class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         execute_all_batches (bool): executes remaining batches in pipeline after
             exhausting dataloader iterator.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """

     # The PipelinedForward class that is used in _rewrite_model
@@ -1321,6 +1336,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
     ) -> None:
         super().__init__(
             model=model,
@@ -1331,6 +1347,7 @@ def __init__(
             context_type=PrefetchTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
         self._context = PrefetchTrainPipelineContext(version=0)
         self._prefetch_stream: Optional[torch.Stream] = (
@@ -1462,6 +1479,8 @@ class EvalPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         device (torch.device): device where device transfer, sparse data dist, and
             forward/backward pass will happen.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """

     # The PipelinedForward class that is used in _rewrite_model
@@ -1473,8 +1492,16 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         device: torch.device,
         apply_jit: bool = False,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
     ) -> None:
-        super().__init__(model, optimizer, device, True, apply_jit)
+        super().__init__(
+            model,
+            optimizer,
+            device,
+            True,
+            apply_jit,
+            apply_jit_context=apply_jit_context,
+        )
         self._batch_loader: Optional[DataLoadingThread[In]] = None

     def __del__(self) -> None:
@@ -1836,6 +1863,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
     ) -> None:
         super().__init__(
             model,
@@ -1846,6 +1874,7 @@ def __init__(
             context_type,
             pipeline_postproc,
             custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )

         torch._logging.set_logs(compiled_autograd_verbose=True)
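
For reference, a minimal usage sketch of the new argument (not part of the diff above): apply_jit_context lets callers wrap the torch.jit.script step in a context manager of their choosing, and defaults to a shared contextlib.nullcontext(), i.e. a no-op. The import path, helper names, and the model/optimizer wiring below are assumptions for illustration only, not taken from this change.

# Hypothetical sketch: build a TrainPipelineSparseDist whose JIT scripting of
# non-pipelined modules runs with Python warnings suppressed.
import warnings
from contextlib import contextmanager
from typing import Iterator

import torch
from torchrec.distributed.train_pipeline import TrainPipelineSparseDist  # assumed import path


@contextmanager
def suppress_jit_warnings() -> Iterator[None]:
    # The pipeline enters this context around its torch.jit.script call.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        yield


def build_pipeline(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> TrainPipelineSparseDist:
    # model/optimizer are assumed to be a sharded model and its optimizer,
    # set up as usual for TrainPipelineSparseDist.
    return TrainPipelineSparseDist(
        model=model,
        optimizer=optimizer,
        device=device,
        apply_jit=True,
        apply_jit_context=suppress_jit_warnings(),
    )

Leaving apply_jit_context unset preserves the previous behavior, since the default nullcontext() does nothing on enter or exit.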