From c890d17d3185c85ab383fa00d2d98707ab53ed38 Mon Sep 17 00:00:00 2001 From: Zach Wang Date: Sat, 16 May 2026 00:25:57 +0000 Subject: [PATCH] fix(trainer): improve data iteration in SFT training loop Refactor the data iteration logic in the SFT training loop to handle cases where the data iterator may be exhausted. Replace the for loop with a while loop that checks for data availability using `next(data_iter, None)` and a tensor flag to ensure the training process exits gracefully when no more data is available. This change enhances robustness and prevents potential runtime errors during training. --- src/prime_rl/trainer/sft/train.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index 032f3b136b..e0426ffc41 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -265,8 +265,15 @@ def run_eval_loop(data_iter): total_token_count = torch.tensor(0, dtype=torch.int64, device="cuda") nan_count = torch.tensor(0, device="cuda") + has_data = torch.ones(1, dtype=torch.int32, device="cuda") + data_iter = iter(data_iter) with torch.no_grad(): - for micro_batch in data_iter: + while True: + micro_batch = next(data_iter, None) + has_data.fill_(micro_batch is not None) + dist.all_reduce(has_data, op=dist.ReduceOp.MIN) + if has_data.item() == 0: + break loss_sum, token_count = compute_loss(micro_batch) if not torch.isnan(loss_sum.detach()): total_loss_sum += loss_sum.detach()