diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index 1da7fe60af..3917f414c7 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -287,8 +287,19 @@ def run_eval_loop(data_iter): total_token_count = torch.tensor(0, dtype=torch.int64, device="cuda") nan_count = torch.tensor(0, device="cuda") + # Variable-length packing yields different per-rank batch counts. Under FSDP + # every forward is a collective, so all ranks must agree on when to stop — + # otherwise the first rank to exit deadlocks the rest in the next all-gather. + # Sync per batch and exit together as soon as any rank exhausts its iterator. + data_iter = iter(data_iter) + with torch.no_grad(): - for micro_batch in data_iter: + while True: + micro_batch = next(data_iter, None) + has_data = torch.tensor(micro_batch is not None, dtype=torch.int32, device="cuda") + dist.all_reduce(has_data, op=dist.ReduceOp.MIN) + if has_data.item() == 0: + break loss_sum, token_count = compute_loss(micro_batch) if not torch.isnan(loss_sum.detach()): total_loss_sum += loss_sum.detach()