@@ -392,6 +392,8 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
            (applicable to 2D sharding only)
            if set and DMP collection is enabled for 2D sharding,
            sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
+        apply_jit_context (Optional[ContextManager]): a context manager that
+            will surround the application of the JIT
    """

    # The PipelinedForward class that is used in _rewrite_model
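For context, here is a minimal sketch (not part of this diff) of what a caller-supplied apply_jit_context could look like; the name scripting_guard is hypothetical, and the assumption is that the pipeline enters this context around its torch.jit.script application when apply_jit=True:

import contextlib
from typing import Iterator

@contextlib.contextmanager
def scripting_guard() -> Iterator[None]:  # hypothetical caller-side helper
    # anything that should bracket JIT application: logging, flags, state toggles
    print("applying torch.jit.script to non-pipelined modules")
    try:
        yield
    finally:
        print("JIT application finished")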
@@ -412,13 +414,15 @@ def __init__(
        ] = None,
        dmp_collection_sync_interval_batches: Optional[int] = 1,
        enqueue_batch_after_forward: bool = False,
+        apply_jit_context: Optional[ContextManager[None]] = None,
    ) -> None:
        self._model = model
        self._optimizer = optimizer
        self._device = device
        self._execute_all_batches = execute_all_batches
        self._apply_jit = apply_jit
        self._enqueue_batch_after_forward = enqueue_batch_after_forward
+        self._apply_jit_context = apply_jit_context

        if device.type == "cuda":
            # use two data streams to support two concurrent batches
@@ -716,6 +720,7 @@ def _pipeline_model(
            apply_jit=self._apply_jit,
            pipelined_forward=pipelined_forward,
            pipeline_postproc=self._pipeline_postproc,
+            apply_jit_context=self._apply_jit_context,
        )
        # initializes input dist, so we can override input dist forwards
        self.start_sparse_data_dist(batch, context)
@@ -914,6 +919,8 @@ class TrainPipelineFusedSparseDist(TrainPipelineSparseDist[In, Out]):
        TODO: pipeline_postproc, custom_model_fwd, strict
        use_emb_lookup_stream (bool): if True, invoke compute_and_output_dist
            (for batch i+1) on a new stream; otherwise reuse the data_dist stream
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
    """

    # The PipelinedForward class that is used in _rewrite_model
@@ -932,6 +939,7 @@ def __init__(
        ] = None,
        strict: bool = False,
        emb_lookup_stream: str = "data_dist",  # new, current, data_dist (default)
+        apply_jit_context: Optional[ContextManager[None]] = None,
    ) -> None:
        super().__init__(
            model=model,
@@ -942,6 +950,7 @@ def __init__(
            context_type=EmbeddingTrainPipelineContext,
            pipeline_postproc=pipeline_postproc,
            custom_model_fwd=custom_model_fwd,
+            apply_jit_context=apply_jit_context,
        )
        if emb_lookup_stream == "new":
            self._emb_lookup_stream: Optional[torch.Stream] = (
@@ -1076,6 +1085,8 @@ class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
            (applicable to 2D sharding only)
            if set and DMP collection is enabled for 2D sharding,
            sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
    """

    # The PipelinedForward class that is used in _rewrite_model
@@ -1096,6 +1107,7 @@ def __init__(
        ] = None,
        strict: bool = False,
        dmp_collection_sync_interval_batches: Optional[int] = 1,
+        apply_jit_context: Optional[ContextManager[None]] = None,
    ) -> None:
        super().__init__(
            model=model,
@@ -1107,6 +1119,7 @@ def __init__(
            pipeline_postproc=pipeline_postproc,
            custom_model_fwd=custom_model_fwd,
            dmp_collection_sync_interval_batches=dmp_collection_sync_interval_batches,
+            apply_jit_context=apply_jit_context,
        )
        self._start_batch = start_batch
        self._stash_gradients = stash_gradients
@@ -1395,6 +1408,8 @@ class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
        execute_all_batches (bool): executes remaining batches in pipeline after
            exhausting dataloader iterator.
        apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
    """

    # The PipelinedForward class that is used in _rewrite_model
@@ -1411,6 +1426,7 @@ def __init__(
        custom_model_fwd: Optional[
            Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
        ] = None,
+        apply_jit_context: Optional[ContextManager[None]] = None,
    ) -> None:
        super().__init__(
            model=model,
@@ -1421,6 +1437,7 @@ def __init__(
            context_type=PrefetchTrainPipelineContext,
            pipeline_postproc=pipeline_postproc,
            custom_model_fwd=custom_model_fwd,
+            apply_jit_context=apply_jit_context,
        )
        self._context = PrefetchTrainPipelineContext(version=0)
        self._prefetch_stream: Optional[torch.Stream] = (
@@ -1552,6 +1569,8 @@ class EvalPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
        device (torch.device): device where device transfer, sparse data dist, and
            forward/backward pass will happen.
        apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (Optional[ContextManager]): a context manager that
+            will surround the application of the JIT
    """

    # The PipelinedForward class that is used in _rewrite_model
@@ -1563,8 +1582,16 @@ def __init__(
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        apply_jit: bool = False,
+        apply_jit_context: Optional[ContextManager[None]] = None,
    ) -> None:
-        super().__init__(model, optimizer, device, True, apply_jit)
+        super().__init__(
+            model,
+            optimizer,
+            device,
+            True,
+            apply_jit,
+            apply_jit_context=apply_jit_context,
+        )
        self._batch_loader: Optional[DataLoadingThread[In]] = None

    def __del__(self) -> None:
@@ -1926,6 +1953,7 @@ def __init__(
        custom_model_fwd: Optional[
            Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
        ] = None,
+        apply_jit_context: Optional[ContextManager[None]] = None,
    ) -> None:
        super().__init__(
            model,
@@ -1936,6 +1964,7 @@ def __init__(
            context_type,
            pipeline_postproc,
            custom_model_fwd,
+            apply_jit_context=apply_jit_context,
        )

        torch._logging.set_logs(compiled_autograd_verbose=True)
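Taken together, a hedged usage sketch of the new keyword, assuming model, optimizer, and device are already constructed and reusing the hypothetical scripting_guard helper sketched near the top of this diff:

pipeline = TrainPipelineSparseDist(
    model=model,
    optimizer=optimizer,
    device=device,
    apply_jit=True,
    apply_jit_context=scripting_guard(),  # assumed to be entered when the JIT is applied
)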