🤝 validate gradient_accumulation_steps vs steps_per_generation for on-policy GRPO #3493

Conversation
Nice catch! I think the fix should be here instead: trl/trl/trainer/grpo_trainer.py, line 1127 in c7e3f09.

Also:

- if self.num_iterations > 1 or self.args.steps_per_generation > self.args.gradient_accumulation_steps:
+ if self.num_iterations > 1 or self.gradient_accumulation_steps % self.steps_per_generation != 0:
…n_steps % steps_per_generation != 0

If `gradient_accumulation_steps` is *not* an exact multiple of `steps_per_generation`, the final (“remainder”) slices are processed after `optimizer.step()` has already updated the weights. They therefore use an updated policy but carry importance weights that were computed under the previous policy, effectively turning those updates into unintended off-policy training.
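To make the schedule concrete, here is a minimal standalone sketch (my own helper, not TRL internals), under the assumption that completions are generated every `steps_per_generation` micro-steps and `optimizer.step()` fires every `gradient_accumulation_steps` micro-steps; it lists which micro-steps consume completions generated before the most recent weight update:

```python
# Standalone sketch (not TRL code): which micro-steps reuse completions
# that were generated before the last optimizer.step()?
def off_policy_steps(gradient_accumulation_steps: int, steps_per_generation: int, total_steps: int = 12):
    stale = []
    last_generation = None  # micro-step at which the buffered completions were generated
    last_optim = -1         # micro-step after which optimizer.step() last ran
    for step in range(total_steps):
        if step % steps_per_generation == 0:
            last_generation = step          # fresh completions are generated here
        if last_generation <= last_optim:
            stale.append(step)              # completions predate the last weight update -> off-policy
        if (step + 1) % gradient_accumulation_steps == 0:
            last_optim = step               # optimizer.step() fires after this micro-step
    return stale

print(off_policy_steps(gradient_accumulation_steps=4, steps_per_generation=3))  # [4, 5, 8]
print(off_policy_steps(gradient_accumulation_steps=4, steps_per_generation=2))  # [] (exact multiple)
```

With `gradient_accumulation_steps=4` and `steps_per_generation=3`, some slices end up being trained on after a weight update; when `gradient_accumulation_steps` is an exact multiple of `steps_per_generation`, none do.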
@qgallouedec thanks for the great catch! I have updated the logic as suggested.

What changed

Could you please take another look when you have a moment? Thanks!
Hi @HarryHsing, thanks for your contribution! So you mean this? (The key is "Processed after the optim step!!!!")
```
#                                             |  GPU 0   |  GPU 1   |
#
#                       global_step   step     <-───>  num_generations=2
#                                              <-───────> per_device_train_batch_size=3
#  grad_accum       ▲  ▲     0          0      0   0   1   1   2   2   <- Generate for the first `steps_per_generation` (prompts 0 to 11); store the completions; use the first slice to compute the loss
#     =4            |  |     0          1      3   3   4   4   5   5   <- Take the stored generations and use the second slice to compute the loss
#                   |  |
#  steps_per_gen=3  |  ▼     1          2      6   6   7   7   8   8   <- Take the stored generations and use the third slice to compute the loss
#                   ▼        1          3      9   9  10  10  11  11   <- Processed after the optim step!!!!
#
#                            2          4     12  12  13  13  14  14   <- Generate for the second `steps_per_generation` (prompts 12 to 23); store the completions; use the first slice to compute the loss
#                            2          5     15  15  16  16  17  17   <- Take the stored generations and use the second slice to compute the loss
#                                                     ...
```
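For the quantities in the table header, the size of one generation batch follows from simple arithmetic. A small sketch (my own helper, not TRL code), using per_device_train_batch_size=3, num_generations=2 and 2 GPUs as above; taking a generation batch that spans 4 micro-steps reproduces the "(prompts 0 to 11)" annotation in the first row:

```python
# Back-of-the-envelope sizes for the table header above (my own helper, not TRL code).
num_processes = 2                # GPU 0 and GPU 1
per_device_train_batch_size = 3  # completions per GPU per micro-step
num_generations = 2              # completions sampled per prompt

def prompts_per_generation_batch(steps_per_generation: int) -> int:
    completions_per_step = per_device_train_batch_size * num_processes   # 6 completions per micro-step
    prompts_per_step = completions_per_step // num_generations           # 3 distinct prompts per micro-step
    return prompts_per_step * steps_per_generation                       # prompts covered by one generation batch

print(prompts_per_generation_batch(4))  # 12 -> prompts 0..11, as in the first annotated row
print(prompts_per_generation_batch(3))  # 9  -> prompts 0..8
```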
Thanks a lot for double-checking, @shirinyamani 🙏
Oh yes! I think it totally makes sense now!

Thanks again, @shirinyamani!
I'm finding the interplay between steps_per_generation and gradient_accumulation_steps very difficult to follow, though the above diagram helps a lot. On top of that, there is num_iterations as well. Can you please explain what the point of steps_per_generation is? Why would one want to set it differently from the number of gradient accumulation steps? If there isn't really a use case, it would help greatly to get rid of one of these variables and simplify this logic a bit.
When I apply my own sampler, I noticed that, as shown here, the batch size includes the
…-policy GRPO (huggingface#3493) Co-authored-by: Shirin Yamani <[email protected]>
What does this PR do?

Adds a validation check in `GRPOConfig.__post_init__`. When `num_iterations == 1` (pure on-policy training) but `gradient_accumulation_steps` is not a multiple of `steps_per_generation`, the final remainder slices of each generation batch are consumed after `optimizer.step()`. That silently turns them into off-policy data. The new check raises a clear `ValueError` instead.

Why is it needed?
Many users raise `steps_per_generation` to reduce the frequency of `model.generate()`. If `gradient_accumulation_steps` is not an exact multiple of `steps_per_generation`, the leftover slices are processed with an updated policy, receive no importance-sampling correction, and can destabilise training. Fail-fast validation removes that hidden pitfall.
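For concreteness, here is a hedged sketch of the kind of configuration this validation targets. It assumes a TRL version in which `GRPOConfig` exposes `steps_per_generation`, and the `output_dir` value is just a placeholder:

```python
from trl import GRPOConfig

# The combination discussed in this PR: on-policy GRPO, but the accumulation
# window is not a multiple of the generation window.
config = GRPOConfig(
    output_dir="grpo-output",        # placeholder path
    num_iterations=1,                # pure on-policy training
    steps_per_generation=3,          # one generation batch serves 3 micro-steps
    gradient_accumulation_steps=4,   # optimizer.step() every 4 micro-steps
)
# 4 % 3 != 0, so the last slices of each generation batch would be consumed after
# a weight update; with this PR the mismatch is surfaced early (a ValueError per
# the description, or the adjusted guard per the review above) instead of
# silently training off-policy.
```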
Reference code location

The mismatch shows up a few lines below this condition in `grpo_trainer.py` (the `steps_per_generation > gradient_accumulation_steps` guard quoted in the review above). Because the guard only checks “greater than”, a setting like `steps_per_generation = 3`, `gradient_accumulation_steps = 4` slips through unnoticed. The final slice is processed after `optimizer.step()`, using an updated policy but with no importance-sampling correction, turning it into unintended off-policy training. Below is an example of what it means:
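As a stand-in for that example, here is a minimal standalone comparison of the two guard conditions quoted earlier in the thread (values chosen to match the 3-vs-4 setting above; not the merged TRL code):

```python
# Evaluate both guard conditions for the setting described above.
num_iterations = 1
steps_per_generation = 3
gradient_accumulation_steps = 4

# Old guard: only flags the run as "not purely on-policy" when
# steps_per_generation exceeds gradient_accumulation_steps.
old_guard = num_iterations > 1 or steps_per_generation > gradient_accumulation_steps
print(old_guard)  # False -> the 3-vs-4 mismatch is treated as on-policy

# Suggested guard: any remainder means some slices are consumed after optimizer.step().
new_guard = num_iterations > 1 or gradient_accumulation_steps % steps_per_generation != 0
print(new_guard)  # True -> the mismatch is caught
```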