fix(sft): prevent validation deadlock with FSDP #2516
Refactor the data iteration logic in the SFT training loop to handle cases where the data iterator may be exhausted. Replace the for loop with a while loop that checks for data availability using `next(data_iter, None)` and a tensor flag to ensure the training process exits gracefully when no more data is available. This change enhances robustness and prevents potential runtime errors during training.
Hey,
I'm not sure I understand when this happens.
@samsja CatDataset (cat packing) produces a different number of packed chunks per rank for variable-length data (seq_len packing drops tails). Validation runs to exhaustion (max_epochs=1), so ranks execute different numbers of FSDP forwards. Since each forward contains collectives, some ranks exit early while others are still in all-gathers.
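To make the mismatch concrete, here is a minimal sketch of how fixed-length packing can yield different chunk counts per rank. It is a simplified model, not the actual `CatDataset` code: the `pack_chunks` helper, the shard contents, and `seq_len=1024` are all hypothetical, chosen only for illustration.

```python
from typing import Dict, List


def pack_chunks(seq_lens: List[int], seq_len: int) -> int:
    """Concatenate sequences and cut into fixed-size chunks, dropping the tail.

    Hypothetical stand-in for cat packing; returns the number of full chunks.
    """
    total_tokens = sum(seq_lens)
    return total_tokens // seq_len


# Hypothetical per-rank shards: same number of samples, different token counts.
rank_shards: Dict[int, List[int]] = {
    0: [900, 1200, 700, 1300],  # 4100 tokens
    1: [300, 500, 400, 600],    # 1800 tokens
}

chunks_per_rank = {
    rank: pack_chunks(lens, seq_len=1024) for rank, lens in rank_shards.items()
}
print(chunks_per_rank)  # {0: 4, 1: 1}
```

In this toy case rank 0 would run four forwards while rank 1 runs one, which is the situation in which an unsynchronized eval loop can hang.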
We never had this issue; I'm pretty sure it's handled somewhere else. Can you show how to reproduce it?
@samsja [val] [val.data]
Fix #2515: unsynchronized iteration in `run_eval_loop` that deadlocks when ranks have different validation batch counts.

With FSDP, each forward pass is a collective. Ranks get different batch counts from variable-length data. When one rank exits the loop first, others deadlock on the all-gather.
Fix: one scalar all-reduce per batch so all ranks exit together.
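The exit condition can be sketched as a single-process simulation. This is not the PR's code: `simulate_synced_eval` is a hypothetical helper, ranks are modeled as plain iterators stepped in lock-step, and `min()` stands in for what would be `dist.all_reduce(flag, op=dist.ReduceOp.MIN)` in a real job. The choice of a MIN reduction is an assumption; the PR may combine the flag differently.

```python
from typing import Iterator, List


def simulate_synced_eval(per_rank_batches: List[List[int]]) -> List[int]:
    """Return how many batches each simulated rank processes before all exit."""
    iters: List[Iterator[int]] = [iter(b) for b in per_rank_batches]
    processed = [0] * len(iters)
    while True:
        # Each rank pulls its next batch; None means its iterator is exhausted.
        batches = [next(it, None) for it in iters]
        flags = [0 if b is None else 1 for b in batches]
        # Stand-in for the scalar all-reduce: every rank sees the same value.
        all_have_data = min(flags)
        if all_have_data == 0:
            break  # all ranks leave the loop on the same step
        for rank in range(len(iters)):
            processed[rank] += 1  # the FSDP forward would run here
    return processed


counts = simulate_synced_eval([[1, 2, 3, 4], [1]])
print(counts)  # [1, 1]
```

Under this assumption, ranks that still hold batches skip them once any rank runs out, so validation covers slightly less data in exchange for never entering a collective alone.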
Note
Medium Risk
Changes distributed validation control flow to add a per-batch `all_reduce` synchronization, which could affect validation performance or mask dataloader issues but is localized to eval.

Overview
Fixes a distributed validation deadlock where ranks could iterate different numbers of validation batches under FSDP and hang in collectives.
`run_eval_loop` now iterates with `next(..., None)` and uses a per-batch `dist.all_reduce` on a `has_data` flag so all ranks exit the eval loop together before aggregating loss/token/NaN counts.

Reviewed by Cursor Bugbot for commit c890d17.