[skyrl-train] Implement loss reduction via advantage normalization and fix token_mean reduction strategy #1296

Open
justinvyu wants to merge 20 commits into NovaSky-AI:main from justinvyu:token_mean_loss_reduction
@justinvyu justinvyu commented Mar 9, 2026

This is a breaking change for the default token_mean loss behavior, as well as for the observed policy_loss metrics! See the "Difference in reported loss metric magnitudes" section below.

Summary

  • Change reduce_loss() to always return a simple masked sum ((loss * mask).sum()). To achieve different reduction strategies, we pre-scale the advantages before they enter the loss function, which also aligns with how Tinker's API handles it.
    • Scale the loss by the DP size before calling backward(). This counteracts the default data-parallel mean all-reduce of gradients across workers, so that gradients are effectively summed instead.
  • Fix the token_mean loss reduction method to take a mean across all tokens in the mini-batch rather than averaging across micro-batches. The old loss reduction remains available via the token_mean_legacy option.

Loss reduction strategies

  • Option 1: token_mean

    • Average loss per token across the entire mini-batch.
    • This is the fixed version where the denominator is the total token count across the full mini-batch, so the gradient is independent of how the mini-batch is split into micro-batches.
  • Option 1b: token_mean_legacy

    • Compute token-mean loss within each micro-batch, then average across micro-batches.
    • This reproduces the token_mean behavior before this PR.
    • The problem: if micro-batches have different token counts, the effective weighting differs from a true global token mean. This is also less usable since changing micro batch size affects the loss and the training dynamics.
    • Kept as a fallback in case of performance regressions — we should remove this down the line.
  • Option 2: sequence_mean

    • Compute per-token loss within each sequence, average across sequences.
    • This is unchanged and is just implemented via advantage normalization instead.
  • Option 3: seq_mean_token_sum_norm

    • Dr. GRPO style — normalize by a fixed constant to avoid any length-dependent weighting.
    • This is unchanged and is just implemented via advantage normalization instead.
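
As a rough illustration of the strategies above, here is a minimal sketch of folding each reduction's denominator into the advantages so that a single masked sum reproduces it. The helper names (`masked_sum`, `prescale_advantages`) and the `max_seq_len` argument are hypothetical and not taken from the PR. The per-token "loss" is simply the advantage itself, and the fixed constant for seq_mean_token_sum_norm is assumed to be a maximum sequence length.

```python
def masked_sum(values, mask):
    """Sum values over unmasked tokens: the only reduction applied after pre-scaling."""
    return sum(v * m for vals, ms in zip(values, mask) for v, m in zip(vals, ms))


def prescale_advantages(advantages, mask, mode, max_seq_len=None):
    """Hypothetical helper: divide each advantage by the denominator of `mode`."""
    num_seqs = len(advantages)
    total_tokens = sum(sum(row) for row in mask)
    scaled = []
    for adv_row, mask_row in zip(advantages, mask):
        if mode == "token_mean":
            # One global denominator: every token in the mini-batch weighs the same.
            denom = total_tokens
        elif mode == "sequence_mean":
            # Per-sequence token mean, then a mean over sequences.
            denom = max(sum(mask_row), 1) * num_seqs
        elif mode == "seq_mean_token_sum_norm":
            # Assumed fixed constant (max_seq_len), then a mean over sequences.
            denom = max_seq_len * num_seqs
        else:
            raise ValueError(f"unknown loss reduction: {mode}")
        scaled.append([a / denom for a in adv_row])
    return scaled


# Toy mini-batch: one 4-token sequence and one 2-token sequence (padded to 4).
advantages = [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 0.0, 0.0]]
mask = [[1, 1, 1, 1], [1, 1, 0, 0]]

results = {
    mode: masked_sum(prescale_advantages(advantages, mask, mode, max_seq_len=4), mask)
    for mode in ("token_mean", "sequence_mean", "seq_mean_token_sum_norm")
}
print(results)
```

On this toy data the three modes give different totals because the longer sequence dominates under token_mean but not under sequence_mean.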

Mean all-reduce -> sum all-reduce

We need the loss to be summed across microbatches and data parallel workers:

  • DDP/FSDP defaults to a mean all-reduce for gradients across workers. This PR counteracts this by multiplying by the DP world size in order to keep the loss sum across data parallel groups.
  • Megatron also does a similar mean reduction across microbatches and workers, so we counteract this by multiplying by num microbatches and DP size to achieve the sum.
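
The arithmetic behind these two correction factors can be sketched with scalar stand-ins for gradients. This is a toy model only: the function names are hypothetical, gradients are plain floats, and the Megatron behavior is assumed to be exactly "divide each micro-batch loss by the number of micro-batches, then mean-reduce across DP workers" as described above.

```python
def mean_all_reduce(per_worker_values):
    """Stand-in for the default DP gradient all-reduce, which averages across workers."""
    return sum(per_worker_values) / len(per_worker_values)


def fsdp_effective_grad(grads, correct=True):
    """grads[w][m] is the gradient of micro-batch m on worker w (toy scalars)."""
    dp_size = len(grads)
    scale = dp_size if correct else 1
    # Each worker accumulates (sums) gradients over its micro-batches.
    local = [sum(scale * g for g in worker) for worker in grads]
    return mean_all_reduce(local)


def megatron_effective_grad(grads, correct=True):
    dp_size = len(grads)
    num_microbatches = len(grads[0])
    scale = num_microbatches * dp_size if correct else 1
    # Assumed Megatron behavior: each micro-batch loss is divided by num_microbatches.
    local = [sum(scale * g / num_microbatches for g in worker) for worker in grads]
    return mean_all_reduce(local)


# Two DP workers, two micro-batches each.
grads = [[0.5, 0.25], [1.0, 0.25]]
true_sum = sum(g for worker in grads for g in worker)

print(true_sum)                                       # plain sum over everything
print(fsdp_effective_grad(grads))                     # matches the sum
print(megatron_effective_grad(grads))                 # matches the sum
print(fsdp_effective_grad(grads, correct=False))      # mean over workers only
print(megatron_effective_grad(grads, correct=False))  # mean over workers and micro-batches
```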

Difference in reported loss metric magnitudes

You will observe that the reported loss metric has a different magnitude compared to your older runs. This is because the old token_mean implementation was somewhere between a true token mean and a sequence mean due to per-micro-batch normalization (e.g. micro_batch_size=1 was equivalent to sequence mean).

The new token_mean is a proper minibatch token mean, while sequence_mean weights every sequence equally regardless of length. When comparing the loss produced by different reduction methods computed on the same advantages, from a real run:

  token_mean (new):  0.322   — every token weighted equally across the mini-batch
  token_mean (old):  0.065   — mean of per-micro-batch token means, where micro_batch_size=4
  sequence_mean:     0.00098 — every sequence weighted equally

The old token_mean gave each micro-batch equal weight rather than each token, so its scale depended on how advantages were distributed across micro-batches. The new implementation is invariant to micro-batch size.
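
To make the micro-batch dependence concrete, here is a small sketch on made-up per-token losses (not the numbers from the run above). The three functions are illustrative stand-ins, not the PR's implementation: the legacy variant averages per-micro-batch token means, while the new variant uses one denominator for the whole mini-batch.

```python
def token_mean(seq_losses):
    """New behavior: one mean over every token in the mini-batch."""
    tokens = [t for seq in seq_losses for t in seq]
    return sum(tokens) / len(tokens)


def sequence_mean(seq_losses):
    """Mean of per-sequence token means: every sequence weighted equally."""
    return sum(sum(seq) / len(seq) for seq in seq_losses) / len(seq_losses)


def token_mean_legacy(seq_losses, micro_batch_size):
    """Old behavior: token mean inside each micro-batch, then a mean over micro-batches."""
    micro_means = [
        token_mean(seq_losses[start:start + micro_batch_size])
        for start in range(0, len(seq_losses), micro_batch_size)
    ]
    return sum(micro_means) / len(micro_means)


# Four sequences with very different lengths (toy per-token losses).
seq_losses = [[1.0] * 4, [4.0], [2.0, 2.0], [0.0] * 6]

for mbs in (1, 2, 4):
    print(mbs, token_mean_legacy(seq_losses, mbs))
print("token_mean:", token_mean(seq_losses))
print("sequence_mean:", sequence_mean(seq_losses))
```

With a micro-batch size of 1 the legacy value coincides with sequence_mean, and with a single micro-batch covering the whole mini-batch it coincides with the new token_mean; anything in between lands somewhere else.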

Note that token_mean_legacy still reports the old metrics, and the sequence_mean and seq_mean_token_sum_norm modes also match their pre-PR values exactly. See this comment for more details.

Tinker compatibility

Here was the first attempt at fixing the loss reduction across microbatches: #909

The approach there was to track total tokens and then do one big normalization at optim_step in order to get an average per-token loss. However, we decided to align with Tinker's approach of just summing up the loss at the end and pushing any loss normalization into the user's advantage calculation.

The benefit is that users have full control of customizing their loss reduction strategy, rather than having it happen in our opaque forward_backward, optim_step implementation which would require some configuration argument that diverges from tinker's API. For example, we would need to add a config somewhere to determine how to average/sum the loss:

client.forward_backward(...)
client.optim_step(..., loss_reduction="token_mean")  # no longer tinker compatible

The current PR aligns with Tinker semantics:

Notice that for all objectives we sum the token-level losses over the sequence length unlike some other loss implementations. If you would like to explore different aggregation schemes, you can include that in the advantage tensor computation.

Example for loss_reduction="token_mean":

  • Move the 1/num_minibatch_tokens normalization into the advantage: loss = sum( -advantage_i * ratio_i for i in range(num_minibatch_tokens) ) / num_minibatch_tokens
  • -> sum( -(advantage_i / num_minibatch_tokens) * ratio_i for i in range(num_minibatch_tokens) )
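
A quick numeric check of this rewrite, as a sketch with made-up advantages and ratios: once the 1/num_minibatch_tokens factor lives in the advantage, the loss can be summed micro-batch by micro-batch and still equal the global token mean. The function names are hypothetical.

```python
def token_mean_loss(advantages, ratios):
    """Reference: normalize once, after summing over the whole mini-batch."""
    n = len(advantages)
    return sum(-a * r for a, r in zip(advantages, ratios)) / n


def prescaled_sum_loss(advantages, ratios, micro_batch_size):
    """Pre-scale advantages by 1/n, then plain-sum the loss across micro-batches."""
    n = len(advantages)
    scaled = [a / n for a in advantages]
    total = 0.0
    for start in range(0, n, micro_batch_size):
        stop = start + micro_batch_size
        total += sum(-a * r for a, r in zip(scaled[start:stop], ratios[start:stop]))
    return total


# Flattened per-token values for one mini-batch (illustrative numbers).
advantages = [0.5, -1.0, 2.0, 0.25, -0.75]
ratios = [1.0, 0.9, 1.1, 1.05, 0.95]

reference = token_mean_loss(advantages, ratios)
for mbs in (1, 2, 5):
    print(mbs, prescaled_sum_loss(advantages, ratios, mbs), reference)
```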

Learning curve comparisons before/after the PR

FSDP (wandb)

Screenshot 2026-03-20 at 3 29 10 PM

Megatron (wandb)

1.7B:

Screenshot 2026-03-20 at 3 29 40 PM

30B lora:

master baseline:
Screenshot 2026-03-20 at 3 33 29 PM

token_mean_legacy + fixed token_mean:

Screenshot 2026-03-24 at 11 16 56 AM

justinvyu and others added 3 commits March 9, 2026 11:51
… scale loss by dp_size for FSDP/Megatron parity

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…omparison

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
justinvyu and others added 7 commits March 9, 2026 18:27
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…uction

# Conflicts:
#	skyrl/backends/skyrl_train/utils/ppo_utils.py
#	skyrl/train/fully_async_trainer.py
#	skyrl/train/trainer.py
#	tests/backends/skyrl_train/gpu/test_grpo_sp_sanity.py
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
…ritic, rename token_mean_baseline to token_mean_legacy

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
… add unit tests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@justinvyu justinvyu marked this pull request as ready for review March 20, 2026 22:34
gemini-code-assist[bot]

This comment was marked as resolved.

@devin-ai-integration devin-ai-integration bot left a comment

✅ Devin Review: No Issues Found

Devin Review analyzed this PR and found no bugs or issues to report.

@justinvyu justinvyu changed the title [wip] loss reduction [skyrl-train] Implement loss reduction via advantage normalization and fix token_mean reduction strategy Mar 20, 2026

@erictang000 erictang000 left a comment

this looks almost good to merge, super clean thanks for adding the token_mean_legacy path

just want to check my understanding + 1 question about the metrics code that I think I probably wrote on the old PR...

… mini-batch reduction

- Report unscaled loss metrics (remove * loss_scale / * dp_size) in both FSDP and Megatron workers
- Rename reduce_metrics -> reduce_metrics_across_microbatches (sums _loss for gradient accumulation)
- Add reduce_metrics_across_minibatches in trainer_utils (averages _loss for logging)
- Use sum all-reduce for _loss keys across DP workers to reconstruct full mini-batch loss

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Comment on lines +371 to +372
"final_loss": unscaled_loss.detach().item(),
"policy_loss": policy_loss.detach().item(),

@justinvyu justinvyu Mar 25, 2026

Metrics fix 1: remove the dp_size multiplier from the reported metrics. There is no average to correct for, since reduce_microbatch_metrics and all_reduce_metrics both sum the *_loss metrics.

  # pop out loss_fn_outputs since it's not a scalar metric and to avoid logging it
  all_metrics.pop("loss_fn_outputs", None)
- reduced_metrics = reduce_metrics(all_metrics)
+ reduced_metrics = reduce_metrics_across_minibatches(all_metrics)

Metrics fix 2: Take an average across minibatches instead of still summing. This is because the loss reduction normalization happens at the minibatch level. Across different minibatches we should just average, otherwise we'll increase the reported loss scale by ~num_minibatches
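
The two-level reduction described in these metrics fixes can be sketched as follows. The function names below are hypothetical and only mirror the intent stated in the commit message (sum *_loss keys across micro-batches, average across mini-batches); averaging the non-loss keys across micro-batches is an assumption made for this example.

```python
def sum_losses_across_microbatches(micro_metrics):
    """Sum keys ending in _loss (gradient accumulation); average everything else (assumed)."""
    reduced = {}
    for key in micro_metrics[0]:
        values = [m[key] for m in micro_metrics]
        reduced[key] = sum(values) if key.endswith("_loss") else sum(values) / len(values)
    return reduced


def average_across_minibatches(mini_metrics):
    """Average every key: each mini-batch loss is already normalized at mini-batch level."""
    return {
        key: sum(m[key] for m in mini_metrics) / len(mini_metrics)
        for key in mini_metrics[0]
    }


# Two mini-batches, each split into two micro-batches (toy values).
minibatch_a = [{"policy_loss": 0.125, "clip_ratio": 0.25}, {"policy_loss": 0.375, "clip_ratio": 0.75}]
minibatch_b = [{"policy_loss": 0.25, "clip_ratio": 0.5}, {"policy_loss": 0.25, "clip_ratio": 0.5}]

per_minibatch = [sum_losses_across_microbatches(mb) for mb in (minibatch_a, minibatch_b)]
logged = average_across_minibatches(per_minibatch)
summed_instead = sum(m["policy_loss"] for m in per_minibatch)

print(logged)          # per-mini-batch scale is preserved
print(summed_instead)  # grows with the number of mini-batches
```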

justinvyu and others added 8 commits March 27, 2026 11:53
…e_metrics

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…uction

# Conflicts:
#	skyrl/train/trainer.py
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Comment on lines 292 to +294
  if resolved_loss_name == "cross_entropy":
-     loss = policy_loss
+     unscaled_loss = policy_loss
+     loss = unscaled_loss * grad_sum_correction_factor

Q: should this affect the SFT case? SFT doesn't look at the normalized advantages either, similar to the critic loss case.

Before the PR, the SFT case summed the negative log-likelihoods within a micro-batch but still averaged over micro-batches and DP workers.
Now we sum the negative log-likelihood across the entire mini-batch. What's the desired behavior here?

@justinvyu

To sanity check the difference in loss metric magnitudes, I dumped the raw advantages on the first step and manually calculated the loss with the different reduction methods on the same dumped data.

Using dumped advantage tensors from a real GRPO run to compare old vs. new:

With micro_batch_size=4 (128 micro-batches), old and new differ by ~5x — matching what's observed in the actual run:

Average: old=0.065  new=0.322  ratio=4.92

With micro_batch_size=1, the old token_mean reduces to sequence_mean (each sequence weighted equally). The old values match sequence_mean exactly:

token_mean old:  Average=-0.024
sequence_mean:   Average=-0.024  ratio=1.0000

With micro_batch_size=512 (1 micro-batch = full mini-batch), old and new converge:

Average: old=0.322  new=0.322  ratio=1.0000

The new token_mean value (0.322) is the same regardless of micro_batch_size — which is the correct behavior. The old value varied between -0.024 (at micro_batch_size=1, i.e. sequence_mean) and 0.322 (at micro_batch_size=512) depending on how micro-batches were formed.

token_mean_legacy reproduces the old behavior. Runs using token_mean won't be directly comparable to before, but the difference is analogous to comparing token_mean vs. sequence_mean — a different weighting, not a bug.

Signed-off-by: Justin Yu <justinvyu@anyscale.com>

@devin-ai-integration devin-ai-integration bot left a comment

Devin Review found 3 new potential issues.

View 7 additional findings in Devin Review.

🟡 Validation error message missing 'token_mean_legacy' option

The assertion in validate_cfg correctly accepts token_mean_legacy as a valid loss_reduction value (line 278-281), but the error message string on line 283-284 still only lists ['token_mean', 'sequence_mean', 'seq_mean_token_sum_norm'], omitting the newly added token_mean_legacy. If a user provides an invalid value, the error message will be misleading about which options are available.

(Refers to lines 283-284)

Comment on lines +883 to +884
unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term
loss = unscaled_loss * grad_sum_correction_factor

🔴 KL loss and entropy loss gradients amplified by num_micro_batches in FSDP policy worker

The old optim_step divided accumulated gradients by micro_batches_accumulated (worker.py:926-931 in old code), which correctly averaged the contribution of all terms (policy, KL, entropy) across micro-batches. This scaling was removed. The policy loss term is unaffected because its reduction is now baked into advantage pre-scaling via apply_loss_reduction_to_advantages_minibatch. However, kl_loss_term and entropy_loss_term (lines 878, 866) are still computed as per-micro-batch masked means and are not pre-scaled. Their gradients now accumulate (sum) across micro-batches without any subsequent division, making their effective contribution num_micro_batches times larger than before. For example, with 4 micro-batches and entropy_loss_coef=0.01, the effective entropy coefficient becomes 0.04, which can destabilize training.

Concrete trace through old vs new gradient math

Old code per worker: grad_accumulated = sum_micro(grad(L_micro)), then optim_step divides by M → effective = mean_micro(grad(L_micro)). For KL/entropy (mean per micro), this gives grad(kl_mean).

New code per worker: grad_accumulated = sum_micro(grad(L_micro * dp_size)), no division → effective after FSDP avg = sum_micro(grad(L_micro)). For KL/entropy, this gives M * grad(kl_mean), i.e. M times the old value.
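
The per-worker part of this trace can be reproduced with toy scalars. The sketch below assumes a single data-parallel worker (dp_size = 1), so the DP correction factor drops out and only the micro-batch accumulation matters. The helper name and the `aux_scale` argument are hypothetical.

```python
def accumulated_aux_grad(per_micro_grads, aux_scale=1.0):
    """Sum the auxiliary (KL or entropy) gradient over micro-batches, as backward() would."""
    return sum(aux_scale * g for g in per_micro_grads)


# Gradient of a per-micro-batch mean KL term, identical in each of 4 micro-batches.
kl_grads = [0.5, 0.5, 0.5, 0.5]
num_micro = len(kl_grads)

# Old path: accumulate, then optim_step divides by the number of micro-batches.
old_effective = accumulated_aux_grad(kl_grads) / num_micro

# New path as described in the trace: accumulate with no later division.
new_effective = accumulated_aux_grad(kl_grads)

# Scaling the auxiliary terms by 1/num_micro restores the old magnitude.
rescaled_effective = accumulated_aux_grad(kl_grads, aux_scale=1.0 / num_micro)

print(old_effective, new_effective, rescaled_effective)
```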

Prompt for agents
In skyrl/backends/skyrl_train/workers/worker.py, the RL path in _forward_backward_micro (around lines 857-885) computes unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term, then multiplies by grad_sum_correction_factor. The problem is that policy_loss is a sum (pre-scaled by advantage normalization to produce correct gradients when summed across micro-batches), but kl_loss_term and entropy_loss_term are per-micro-batch means that should be averaged across micro-batches, not summed. Since optim_step no longer divides by micro_batches_accumulated, the KL and entropy terms are effectively multiplied by num_micro_batches.

Fix: Scale kl_loss_term and entropy_loss_term by (1 / num_micro_batches) before adding them to the loss. You can pass the number of micro-batches into _forward_backward_micro, or compute it from the data batch size and micro_batch_size. Alternatively, separate the policy loss backward from the auxiliary loss backward, applying grad_sum_correction_factor only to the policy loss and a different factor to the auxiliary terms.

The same fix needs to be applied in the Megatron worker at skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py lines 356-357, where the same pattern exists with grad_sum_correction_factor = num_microbatches * dp_size.

Comment on lines +356 to +357
unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term
loss = unscaled_loss * grad_sum_correction_factor

🔴 KL loss and entropy loss gradients amplified by num_microbatches in Megatron model wrapper

Same issue as the FSDP worker but in the Megatron path. grad_sum_correction_factor = num_microbatches * dp_size (megatron_model_wrapper.py:262) is designed to cancel Megatron's internal loss division by num_microbatches and DP averaging, recovering a sum reduction for the policy loss. But kl_loss_term and entropy_loss_term (lines 338-341) are per-micro-batch means. Megatron divides them by num_microbatches (producing correct averaging), but then the grad_sum_correction_factor multiplies back by num_microbatches * dp_size. After DP averaging, the net effect on KL/entropy is a factor of num_microbatches amplification compared to the intended behavior.

Prompt for agents
In skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py, the loss_func closure computes unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term (line 356) and then multiplies by grad_sum_correction_factor (line 357). The grad_sum_correction_factor = num_microbatches * dp_size is correct for the policy_loss (which is a pre-scaled sum), but kl_loss_term and entropy_loss_term are per-micro-batch means. Megatron internally divides by num_microbatches, then the correction factor multiplies by num_microbatches * dp_size, and DP averages by 1/dp_size. The net result is that KL and entropy are num_microbatches times larger than intended.

Fix: Either (a) divide kl_loss_term and entropy_loss_term by num_microbatches before adding them to the loss (so that after Megatron's division they become kl/num_micro^2 and the correction factor brings them back to kl), or (b) apply the correction factor only to policy_loss and use a separate factor (dp_size only) for the auxiliary terms.