diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index 032f3b136b..e0426ffc41 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -265,8 +265,15 @@ def run_eval_loop(data_iter): total_token_count = torch.tensor(0, dtype=torch.int64, device="cuda") nan_count = torch.tensor(0, device="cuda") + has_data = torch.ones(1, dtype=torch.int32, device="cuda") + data_iter = iter(data_iter) with torch.no_grad(): - for micro_batch in data_iter: + while True: + micro_batch = next(data_iter, None) + has_data.fill_(micro_batch is not None) + dist.all_reduce(has_data, op=dist.ReduceOp.MIN) + if has_data.item() == 0: + break loss_sum, token_count = compute_loss(micro_batch) if not torch.isnan(loss_sum.detach()): total_loss_sum += loss_sum.detach()