@@ -1621,8 +1621,6 @@ def _run_with_event(
1621
1621
inputs : Optional [In ],
1622
1622
stream : torch .Stream ,
1623
1623
) -> StageOutputWithEvent :
1624
- if inputs is None :
1625
- return (None , None )
1626
1624
with self ._stream_context (stream ):
1627
1625
# If there is no previous event, data is entering the pipeline
1628
1626
if event is not None :
@@ -1666,6 +1664,11 @@ def _run_stage(
1666
1664
"""
1667
1665
stage = self ._pipeline_stages [stage_idx ]
1668
1666
1667
+ if self ._debug_mode :
1668
+ logger .info (
1669
+ f"Running ## Pipeline Stage { stage_idx } : { stage .name } for batch { batch_offset + self ._num_steps } ##" ,
1670
+ )
1671
+
1669
1672
with record_function (
1670
1673
f"## Pipeline Stage { stage_idx } : { stage .name } for batch { batch_offset + self ._num_steps } ##"
1671
1674
):
@@ -1677,23 +1680,38 @@ def _run_stage(
1677
1680
assert batch_to_wait_with_event is not None
1678
1681
batch_to_wait , event = batch_to_wait_with_event
1679
1682
1680
- new_result = self ._run_with_event (
1681
- runnable = stage .runnable ,
1682
- event = event ,
1683
- inputs = batch_to_wait ,
1684
- stream = stage .stream ,
1685
- )
1683
+ if batch_to_wait is not None :
1684
+ logger .info (
1685
+ f"Executing ## Pipeline Stage { stage_idx } : { stage .name } for batch { batch_offset + self ._num_steps } ##" ,
1686
+ )
1687
+ new_result = self ._run_with_event (
1688
+ runnable = stage .runnable ,
1689
+ event = event ,
1690
+ inputs = batch_to_wait ,
1691
+ stream = stage .stream ,
1692
+ )
1693
+ else :
1694
+ logger .info (
1695
+ f"Skipping due to None ## Pipeline Stage { stage_idx } : { stage .name } for batch { batch_offset + self ._num_steps } ##" ,
1696
+ )
1697
+ new_result = (None , None )
1698
+ if (
1699
+ data_exhausted_callback := stage .data_exhausted_callback
1700
+ ) is not None :
1701
+ data_exhausted_callback ()
1686
1702
1687
1703
self ._stage_outputs [batch_offset ] = new_result
1688
1704
if self ._debug_mode :
1689
1705
logger .info (
1690
- f"Running ## Pipeline Stage { stage_idx } : { stage .name } for batch { batch_offset + self ._num_steps } ##" ,
1706
+ f"Finshed ## Pipeline Stage { stage_idx } : { stage .name } for batch { batch_offset + self ._num_steps } ##" ,
1691
1707
)
1692
1708
1693
1709
if fill and (fill_callback := stage .fill_callback ) is not None :
1694
1710
if self ._debug_mode :
1695
- logger .info (f"Finished callback for { stage .name } " )
1711
+ logger .info (f"Started callback for { stage .name } " )
1696
1712
fill_callback ()
1713
+ if self ._debug_mode :
1714
+ logger .info (f"Finished callback for { stage .name } " )
1697
1715
1698
1716
return new_result
1699
1717
@@ -1779,6 +1797,9 @@ def progress(
1779
1797
1780
1798
self ._num_steps += 1
1781
1799
1800
+ if self ._debug_mode :
1801
+ logger .info (f"Starting pipeline step { self ._num_steps } " )
1802
+
1782
1803
for stage_idx in range (self .num_stages ):
1783
1804
stage_output_idx = self .num_stages - 1 - stage_idx
1784
1805
self ._run_stage (
@@ -1799,6 +1820,8 @@ def progress(
1799
1820
self .flush_end ()
1800
1821
return self .progress (dataloader_iter )
1801
1822
1823
+ if self ._debug_mode :
1824
+ logger .info (f"Finished pipeline step { self ._num_steps } " )
1802
1825
return out
1803
1826
1804
1827
0 commit comments