Commit 32df093

🤝 validate gradient_accumulation_steps vs steps_per_generation for on-policy GRPO (#3493)
Co-authored-by: Shirin Yamani <[email protected]>
1 parent 0336e4b commit 32df093

1 file changed: +1 -1 lines changed

‎trl/trainer/grpo_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -1204,7 +1204,7 @@ def _generate_and_score_completions(
             # When using num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps
             # old_per_token_logps == per_token_logps, so we can skip it's computation here, and use
             # per_token_logps.detach() instead.
-            if self.num_iterations > 1 or self.args.steps_per_generation > self.args.gradient_accumulation_steps:
+            if self.num_iterations > 1 or self.args.gradient_accumulation_steps % self.args.steps_per_generation != 0:
                 old_per_token_logps = self._get_per_token_logps(
                     self.model, prompt_completion_ids, attention_mask, logits_to_keep, batch_size
                 )
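
To see which configurations this one-line change affects, the two conditions can be compared side by side. The following is a minimal standalone sketch, not code from TRL: the function names `needs_old_logps_before`, `needs_old_logps_after`, and `compare` are hypothetical, and the arguments simply mirror `num_iterations`, `steps_per_generation`, and `gradient_accumulation_steps` from the diff.

```python
def needs_old_logps_before(
    num_iterations: int, steps_per_generation: int, gradient_accumulation_steps: int
) -> bool:
    """Condition on the removed (-) line of the diff (hypothetical helper)."""
    return num_iterations > 1 or steps_per_generation > gradient_accumulation_steps


def needs_old_logps_after(
    num_iterations: int, steps_per_generation: int, gradient_accumulation_steps: int
) -> bool:
    """Condition on the added (+) line of the diff (hypothetical helper)."""
    return num_iterations > 1 or gradient_accumulation_steps % steps_per_generation != 0


def compare(gradient_accumulation_steps: int, candidates: tuple[int, ...]) -> list[tuple[int, bool, bool]]:
    """Return (steps_per_generation, before, after) for num_iterations == 1."""
    return [
        (
            spg,
            needs_old_logps_before(1, spg, gradient_accumulation_steps),
            needs_old_logps_after(1, spg, gradient_accumulation_steps),
        )
        for spg in candidates
    ]


# Illustrative values only: gradient_accumulation_steps = 4.
for spg, before, after in compare(4, (1, 2, 3, 4, 8)):
    print(f"steps_per_generation={spg}: before={before}, after={after}")
```

With `num_iterations == 1`, the two conditions disagree only when `steps_per_generation` is no larger than `gradient_accumulation_steps` but does not divide it evenly (here, 3 against 4). In that case the old check skipped computing `old_per_token_logps`, while the new check computes it.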
