
🤝 validate gradient_accumulation_steps vs steps_per_generation for on-policy GRPO #3493


Merged: 4 commits into huggingface:main, Jun 25, 2025

Conversation

@HarryHsing (Contributor) commented May 25, 2025

What does this PR do?

  • Adds an explicit validation in GRPOConfig.__post_init__
    (a sketch of such a check is shown below).
  • When num_iterations == 1 (intended as pure on-policy training) but
    gradient_accumulation_steps is not a multiple of
    steps_per_generation, the final remainder slices of each
    generation batch are consumed after optimizer.step().
    That silently turns them into off-policy data.
    The new check raises a clear ValueError instead.
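
For illustration, a minimal sketch of what such a fail-fast check could look like; the helper name, message, and exact placement are assumptions (and, later in this thread, the check is moved into the trainer instead):

# Illustrative sketch only, not the merged code.
def validate_on_policy_config(num_iterations: int,
                              gradient_accumulation_steps: int,
                              steps_per_generation: int) -> None:
    """Raise if a nominally on-policy GRPO run would silently train off-policy."""
    if num_iterations == 1 and gradient_accumulation_steps % steps_per_generation != 0:
        raise ValueError(
            f"gradient_accumulation_steps ({gradient_accumulation_steps}) must be a "
            f"multiple of steps_per_generation ({steps_per_generation}) when "
            "num_iterations == 1; otherwise the remainder slices are consumed after "
            "optimizer.step() and become off-policy."
        )

# The configuration from the diagram below fails fast:
try:
    validate_on_policy_config(num_iterations=1,
                              gradient_accumulation_steps=4,
                              steps_per_generation=3)
except ValueError as err:
    print(err)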

Why is it needed?

Many users raise steps_per_generation to reduce the frequency of
model.generate().
If gradient_accumulation_steps is not an exact multiple of
steps_per_generation, the leftover slices are processed with an
updated policy, receive no importance-sampling correction, and can
destabilise training. Fail-fast validation removes that hidden pitfall.


Reference code location

The mismatch shows up a few lines below
this condition in grpo_trainer.py:

# original code path
if self.num_iterations > 1 or \
   self.args.steps_per_generation > self.args.gradient_accumulation_steps:
    old_per_token_logps = ...

Because the guard only checks “greater than,” a setting like
steps_per_generation = 3, gradient_accumulation_steps = 4
slips through unnoticed.
The final slice is processed after optimizer.step(), using an updated policy
but with no importance-sampling correction, which turns it into unintended off-policy training. Below is what that looks like:

#                                     |   GPU 0  |   GPU 1  |
#
#                global_step   _step   <───>  num_generations = 2
#                                      <──────> per_device_train_batch_size = 3
# grad_accum   ▲  ▲   0          0     0   0   1   1   2   2   ← generate; train on slice-0  (π₀)
#    = 4          |   0          1     3   3   4   4   5   5   ← train on slice-1  (π₀)
#                 |
# steps_per_gen=3 ▼   0          2     6   6   7   7   8   8   ← train on slice-2  (π₀)
#               ▼     0          3     9   9  10  10  11  11   ← train on slice-3  (π₀)
#                                                         └── GA=4 reached → **optimizer.step() → π₁**
#
#                     1          4    12  12  13  13  14  14  ← **first off-policy slice** (data from π₀, grads on π₁)
#                     1          5    15  15  16  16  17  17  ← still off-policy
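
To make the slip-through concrete, the original condition can be evaluated with the numbers from the diagram (plain Python for illustration, not the trainer code itself):

# Values from the diagram above
num_iterations = 1
steps_per_generation = 3
gradient_accumulation_steps = 4

# Original guard: decides whether old per-token log-probs are cached for importance sampling
caches_old_logps = num_iterations > 1 or steps_per_generation > gradient_accumulation_steps
print(caches_old_logps)  # False -> nothing is cached, yet the remainder slices are trained
                         # after optimizer.step(), i.e. off-policy and uncorrected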

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec (Member) commented:

Nice catch! I think the fix should be here instead:

if self.num_iterations > 1 or self.args.steps_per_generation > self.args.gradient_accumulation_steps:

also, num_iterations=1 doesn't imply purely on-policy training. It's only the case if gradient_accumulation_steps % steps_per_generation == 0. So the right way, I think, is:

- if self.num_iterations > 1 or self.args.steps_per_generation > self.args.gradient_accumulation_steps:
+ if self.num_iterations > 1 or self.gradient_accumulation_steps % self.steps_per_generation != 0:
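
To make the distinction concrete, a small plain-Python illustration (the helper name is made up and the values are arbitrary):

def is_pure_on_policy(num_iterations, gradient_accumulation_steps, steps_per_generation):
    # num_iterations == 1 alone is not enough; gradient_accumulation_steps must also
    # be an exact multiple of steps_per_generation
    return num_iterations == 1 and gradient_accumulation_steps % steps_per_generation == 0

print(is_pure_on_policy(1, 4, 2))  # True:  every slice is trained before the optimizer step fires
print(is_pure_on_policy(1, 4, 3))  # False: 4 % 3 == 1, so a remainder slice is trained after the step
print(is_pure_on_policy(2, 4, 2))  # False: multiple iterations reuse generations by design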

…n_steps % steps_per_generation != 0

If `gradient_accumulation_steps` is *not* an exact multiple of
`steps_per_generation`, the final (“remainder”) slices are processed after
`optimizer.step()` has already updated the weights.  They therefore use an
updated policy but carry importance weights that were computed under the
previous policy, effectively turning those updates into unintended
off-policy training.
@HarryHsing HarryHsing reopened this May 27, 2025
@HarryHsing (Contributor, Author) commented May 27, 2025

Nice catch! I think the fix should be here instead:

if self.num_iterations > 1 or self.args.steps_per_generation > self.args.gradient_accumulation_steps:

also, num_iterations=1 doesn't imply purely on-policy training. It's only the case if gradient_accumulation_steps % steps_per_generation == 0. So the right way, I think, is:

- if self.num_iterations > 1 or self.args.steps_per_generation > self.args.gradient_accumulation_steps:
+ if self.num_iterations > 1 or self.gradient_accumulation_steps % self.steps_per_generation != 0:

@qgallouedec thanks for the great catch! I have updated the logic as suggested.

What changed

  • Updated the guard in grpo_trainer.py

    if self.num_iterations > 1 or \
       self.args.gradient_accumulation_steps % self.args.steps_per_generation != 0:
        old_per_token_logps = ...
  • Removed the duplicate validation that had been added in GRPOConfig.__post_init__, so the check lives in one place. A small configuration sketch illustrating the resulting on-policy condition is shown below.
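
For reference, a minimal configuration sketch under the merged logic; this assumes the standard trl.GRPOConfig fields discussed in this PR, and output_dir is just a placeholder:

from trl import GRPOConfig

# Purely on-policy: num_iterations == 1 and gradient_accumulation_steps is an exact
# multiple of steps_per_generation, so no slice is trained after optimizer.step()
# with stale data and the old per-token log-probs never need to be cached.
config = GRPOConfig(
    output_dir="grpo-on-policy",      # placeholder
    num_iterations=1,
    steps_per_generation=3,
    gradient_accumulation_steps=6,    # 6 % 3 == 0
)

# With gradient_accumulation_steps=4 instead, 4 % 3 != 0, and the updated guard now
# caches the old log-probs so the remainder slices get an importance-sampling correction.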

Could you please take another look when you have a moment? Thanks!

@shirinyamani shirinyamani self-requested a review June 25, 2025 09:57
@shirinyamani (Member) left a comment


Hi @HarryHsing,
thanks for your contribution! So you mean this? (The key is that it's processed after the optim step!)

        #                                      |   GPU 0  |   GPU 1  |
        #
        #                 global_step   step    <-───>  num_generations=2
        #                                       <-───────> per_device_train_batch_size=3
        #  grad_accum    ▲  ▲  0          0     0   0   1   1   2   2   <- Generate for the first `steps_per_generation` (prompts 0 to 11); store the completions; use the first slice to compute the loss
        #     =4            |  0          1     3   3   4   4   5   5   <- Take the stored generations and use the second slice to compute the loss
        #                   |
        #  steps_per_gen=3  ▼  1          2     6   6   7   7   8   8   <- Take the stored generations and use the third slice to compute the loss
        #                ▼     1          3     9   9  10  10  11  11   <-  Processed after the optim step!!!!
        #
        #                      2          4    12  12  13  13  14  14   <- Generate for the second `steps_per_generation` (prompts 12 to 23); store the completions; use the first slice to compute the loss
        #                      2          5    15  15  16  16  17  17   <- Take the stored generations and use the second slice to compute the loss
        #                                          ...                                        ...

@HarryHsing (Contributor, Author) commented Jun 25, 2025

Thanks a lot for double-checking, @shirinyamani 🙏
Yes, your timeline nails the issue.
I only had to shift the off-policy slice down one row (it shows up after optimizer.step() actually fires).
Let me know if I misunderstood anything or if further tweaks are needed, happy to revise!

#                                     |   GPU 0  |   GPU 1  |
#
#                global_step   _step   <───>  num_generations = 2
#                                      <──────> per_device_train_batch_size = 3
# grad_accum   ▲  ▲   0          0     0   0   1   1   2   2   ← generate; train on slice-0  (π₀)
#    = 4          |   0          1     3   3   4   4   5   5   ← train on slice-1  (π₀)
#                 |
# steps_per_gen=3 ▼   0          2     6   6   7   7   8   8   ← train on slice-2  (π₀)
#               ▼     0          3     9   9  10  10  11  11   ← train on slice-3  (π₀)
#                                                         └── GA=4 reached → **optimizer.step() → π₁**
#
#                     1          4    12  12  13  13  14  14  ← **first off-policy slice** (data from π₀, grads on π₁)
#                     1          5    15  15  16  16  17  17  ← still off-policy

@shirinyamani (Member) commented:

Thanks a lot for double-checking, @shirinyamani 🙏 Yes, your timeline nails the issue. I only had to shift the off-policy slice down one row (it shows up after optimizer.step() actually fires). Let me know if I misunderstood anything or if further tweaks are needed, happy to revise!

#                                     |   GPU 0  |   GPU 1  |
#
#                global_step   _step   <───>  num_generations = 2
#                                      <──────> per_device_train_batch_size = 3
# grad_accum   ▲  ▲   0          0     0   0   1   1   2   2   ← generate; train on slice-0  (π₀)
#    = 4          |   0          1     3   3   4   4   5   5   ← train on slice-1  (π₀)
#                 |
# steps_per_gen=3 ▼   1          2     6   6   7   7   8   8   ← train on slice-2  (π₀)
#               ▼     1          3     9   9  10  10  11  11   ← train on slice-3  (π₀)
#                                                         └── GA=4 reached → **optimizer.step() → π₁**
#
#                     2          4    12  12  13  13  14  14  ← **first off-policy slice** (data from π₀, grads on π₁)
#                     2          5    15  15  16  16  17  17  ← still off-policy

Oh yes! I think it totally makes sense now!
I'll add this visual to the description for future reference!

@shirinyamani shirinyamani changed the title feat(grpo): validate gradient_accumulation_steps vs steps_per_generation for on-policy GRPO 🤝 validate gradient_accumulation_steps vs steps_per_generation for on-policy GRPO Jun 25, 2025
@shirinyamani shirinyamani self-requested a review June 25, 2025 14:48
@HarryHsing (Contributor, Author) commented:

Thanks a lot for double-checking, @shirinyamani 🙏 Yes, your timeline nails the issue. I only had to shift the off-policy slice down one row (it shows up after optimizer.step() actually fires). Let me know if I misunderstood anything or if further tweaks are needed, happy to revise!

#                                     |   GPU 0  |   GPU 1  |
#
#                global_step   _step   <───>  num_generations = 2
#                                      <──────> per_device_train_batch_size = 3
# grad_accum   ▲  ▲   0          0     0   0   1   1   2   2   ← generate; train on slice-0  (π₀)
#    = 4          |   0          1     3   3   4   4   5   5   ← train on slice-1  (π₀)
#                 |
# steps_per_gen=3 ▼   0          2     6   6   7   7   8   8   ← train on slice-2  (π₀)
#               ▼     0          3     9   9  10  10  11  11   ← train on slice-3  (π₀)
#                                                         └── GA=4 reached → **optimizer.step() → π₁**
#
#                     1          4    12  12  13  13  14  14  ← **first off-policy slice** (data from π₀, grads on π₁)
#                     1          5    15  15  16  16  17  17  ← still off-policy

Thanks again, @shirinyamani!
I adjusted global_step so it now increments only after the GA = 4 optimizer.step().

@shirinyamani shirinyamani merged commit 32df093 into huggingface:main Jun 25, 2025
9 of 10 checks passed
@ankur6ue commented:

I'm finding the interplay between steps_per_generation and gradient_accumulation_steps very difficult to follow, though the diagram above helps a lot. On top of that, there is num_iterations as well. Can you please explain the point of steps_per_generation? Why would one want to set it differently from the number of gradient accumulation steps? If there isn't really a use case, it would help greatly to get rid of one of these variables and simplify this logic a bit.

@CM-BF commented Jul 23, 2025

When I apply my own sampler, I noticed that steps_per_generation and grad_accum do not interplay as shown in the figure.

As shown here, the batch size is first multiplied by steps_per_generation, and then in Transformers here the large batch is fetched grad_accum times, which means one optimizer update consumes train_batch_size * grad_accum * steps_per_generation samples. I believe this is not the desired behavior: steps_per_generation should only reduce how often generation runs, not act as an extra multiplier on top of grad_accum. Please let me know if I am wrong. Thanks.
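
For reference, the arithmetic behind this reading, with illustrative numbers (whether this is the intended behavior is exactly what the comment above is asking):

per_device_train_batch_size = 3
num_processes = 2                 # e.g. 2 GPUs
gradient_accumulation_steps = 4
steps_per_generation = 3

# Samples per optimizer update if steps_per_generation only changed how often generate() runs:
expected = per_device_train_batch_size * num_processes * gradient_accumulation_steps  # 24
# Samples per optimizer update under the behavior described above (custom-sampler observation):
observed = expected * steps_per_generation  # 72
print(expected, observed)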

marcandrelarochelle pushed a commit to marcandrelarochelle/trl that referenced this pull request Jul 29, 2025