@@ -1624,11 +1624,9 @@ def _run_with_event(
1624
1624
self ,
1625
1625
runnable : RunnableType ,
1626
1626
event : Optional [torch .Event ],
1627
- inputs : Optional [ In ] ,
1627
+ inputs : In ,
1628
1628
stream : torch .Stream ,
1629
1629
) -> StageOutputWithEvent :
1630
- if inputs is None :
1631
- return (None , None )
1632
1630
with self ._stream_context (stream ):
1633
1631
# If there is no previous event, data is entering the pipeline
1634
1632
if event is not None :
@@ -1672,6 +1670,11 @@ def _run_stage(
1672
1670
"""
1673
1671
stage = self ._pipeline_stages [stage_idx ]
1674
1672
1673
+ if self ._debug_mode :
1674
+ logger .info (
1675
+ f"Running ## Pipeline Stage { stage_idx } : { stage .name } for batch { batch_offset + self ._num_steps } ##" ,
1676
+ )
1677
+
1675
1678
with record_function (
1676
1679
f"## Pipeline Stage { stage_idx } : { stage .name } for batch { batch_offset + self ._num_steps } ##"
1677
1680
):
@@ -1683,23 +1686,40 @@ def _run_stage(
1683
1686
assert batch_to_wait_with_event is not None
1684
1687
batch_to_wait , event = batch_to_wait_with_event
1685
1688
1686
- new_result = self ._run_with_event (
1687
- runnable = stage .runnable ,
1688
- event = event ,
1689
- inputs = batch_to_wait ,
1690
- stream = stage .stream ,
1691
- )
1689
+ if batch_to_wait is not None :
1690
+ if self ._debug_mode :
1691
+ logger .info (
1692
+ f"Executing ## Pipeline Stage { stage_idx } : { stage .name } for batch { batch_offset + self ._num_steps } ##" ,
1693
+ )
1694
+ new_result = self ._run_with_event (
1695
+ runnable = stage .runnable ,
1696
+ event = event ,
1697
+ inputs = batch_to_wait ,
1698
+ stream = stage .stream ,
1699
+ )
1700
+ else :
1701
+ if self ._debug_mode :
1702
+ logger .info (
1703
+ f"Skipping due to None ## Pipeline Stage { stage_idx } : { stage .name } for batch { batch_offset + self ._num_steps } ##" ,
1704
+ )
1705
+ new_result = (None , None )
1706
+ if (
1707
+ data_exhausted_callback := stage .data_exhausted_callback
1708
+ ) is not None :
1709
+ data_exhausted_callback ()
1692
1710
1693
1711
self ._stage_outputs [batch_offset ] = new_result
1694
1712
if self ._debug_mode :
1695
1713
logger .info (
1696
- f"Running ## Pipeline Stage { stage_idx } : { stage .name } for batch { batch_offset + self ._num_steps } ##" ,
1714
+ f"Finished ## Pipeline Stage { stage_idx } : { stage .name } for batch { batch_offset + self ._num_steps } ##" ,
1697
1715
)
1698
1716
1699
1717
if fill and (fill_callback := stage .fill_callback ) is not None :
1700
1718
if self ._debug_mode :
1701
- logger .info (f"Finished callback for { stage .name } " )
1719
+ logger .info (f"Started callback for { stage .name } " )
1702
1720
fill_callback ()
1721
+ if self ._debug_mode :
1722
+ logger .info (f"Finished callback for { stage .name } " )
1703
1723
1704
1724
return new_result
1705
1725
@@ -1785,6 +1805,9 @@ def progress(
1785
1805
1786
1806
self ._num_steps += 1
1787
1807
1808
+ if self ._debug_mode :
1809
+ logger .info (f"Starting pipeline step { self ._num_steps } " )
1810
+
1788
1811
for stage_idx in range (self .num_stages ):
1789
1812
stage_output_idx = self .num_stages - 1 - stage_idx
1790
1813
self ._run_stage (
@@ -1805,6 +1828,8 @@ def progress(
1805
1828
self .flush_end ()
1806
1829
return self .progress (dataloader_iter )
1807
1830
1831
+ if self ._debug_mode :
1832
+ logger .info (f"Finished pipeline step { self ._num_steps } " )
1808
1833
return out
1809
1834
1810
1835
0 commit comments