[skyrl-train] Implement loss reduction via advantage normalization and fix token_mean reduction strategy#1296
[skyrl-train] Implement loss reduction via advantage normalization and fix token_mean reduction strategy#1296justinvyu wants to merge 20 commits intoNovaSky-AI:mainfrom
token_mean reduction strategy#1296Conversation
… scale loss by dp_size for FSDP/Megatron parity Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…omparison Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…uction # Conflicts: # skyrl/backends/skyrl_train/utils/ppo_utils.py # skyrl/train/fully_async_trainer.py # skyrl/train/trainer.py # tests/backends/skyrl_train/gpu/test_grpo_sp_sanity.py
…ritic, rename token_mean_baseline to token_mean_legacy Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
… add unit tests Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
token_mean reduction strategy
erictang000
left a comment
There was a problem hiding this comment.
this looks almost good to merge, super clean thanks for adding the token_mean_legacy path
just want to check my understanding + 1 question about the metrics code that I think I probably wrote on the old PR...
skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py
Outdated
Show resolved
Hide resolved
… mini-batch reduction - Report unscaled loss metrics (remove * loss_scale / * dp_size) in both FSDP and Megatron workers - Rename reduce_metrics -> reduce_metrics_across_microbatches (sums _loss for gradient accumulation) - Add reduce_metrics_across_minibatches in trainer_utils (averages _loss for logging) - Use sum all-reduce for _loss keys across DP workers to reconstruct full mini-batch loss Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
| "final_loss": unscaled_loss.detach().item(), | ||
| "policy_loss": policy_loss.detach().item(), |
There was a problem hiding this comment.
Metrics fix 1: remove dp_size multiplier in reported metrics, since there's no average that we need to correct for, since reduce_microbatch_metrics and all_reduce_metrics both do sums for *_loss metrics.
skyrl/train/trainer.py
Outdated
| # pop out loss_fn_outputs since it's not a scalar metric and to avoid logging it | ||
| all_metrics.pop("loss_fn_outputs", None) | ||
| reduced_metrics = reduce_metrics(all_metrics) | ||
| reduced_metrics = reduce_metrics_across_minibatches(all_metrics) |
There was a problem hiding this comment.
Metrics fix 2: Take an average across minibatches instead of still summing. This is because the loss reduction normalization happens at the minibatch level. Across different minibatches we should just average, otherwise we'll increase the reported loss scale by ~num_minibatches
…e_metrics Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…uction # Conflicts: # skyrl/train/trainer.py
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
| if resolved_loss_name == "cross_entropy": | ||
| loss = policy_loss | ||
| unscaled_loss = policy_loss | ||
| loss = unscaled_loss * grad_sum_correction_factor |
There was a problem hiding this comment.
Q: should this affect the SFT case? SFT doesn't look at the normalized advantages either, similar to the critic loss case.
Before the PR, the SFT case does a sum across the negative log likelihoods within a microbatch, but still averaged over microbatches and dp workers.
Now, we are summing negative log likelihood across the entire minibatch. What's the desired behavior here?
|
To sanity check the difference in loss metric magnitudes, I dumped the raw advantages on the first step and manually calculated the loss with the different reduction methods on the same dumped data. Using dumped advantage tensors from a real GRPO run to compare old vs. new: With With With The new
|
There was a problem hiding this comment.
🟡 Validation error message missing 'token_mean_legacy' option
The assertion in validate_cfg correctly accepts token_mean_legacy as a valid loss_reduction value (line 278-281), but the error message string on line 283-284 still only lists ['token_mean', 'sequence_mean', 'seq_mean_token_sum_norm'], omitting the newly added token_mean_legacy. If a user provides an invalid value, the error message will be misleading about which options are available.
(Refers to lines 283-284)
Was this helpful? React with 👍 or 👎 to provide feedback.
| unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term | ||
| loss = unscaled_loss * grad_sum_correction_factor |
There was a problem hiding this comment.
🔴 KL loss and entropy loss gradients amplified by num_micro_batches in FSDP policy worker
The old optim_step divided accumulated gradients by micro_batches_accumulated (worker.py:926-931 in old code), which correctly averaged the contribution of all terms (policy, KL, entropy) across micro-batches. This scaling was removed. The policy loss term is unaffected because its reduction is now baked into advantage pre-scaling via apply_loss_reduction_to_advantages_minibatch. However, kl_loss_term and entropy_loss_term (lines 878, 866) are still computed as per-micro-batch masked means and are not pre-scaled. Their gradients now accumulate (sum) across micro-batches without any subsequent division, making their effective contribution num_micro_batches times larger than before. For example, with 4 micro-batches and entropy_loss_coef=0.01, the effective entropy coefficient becomes 0.04, which can destabilize training.
Concrete trace through old vs new gradient math
Old code per worker: grad_accumulated = sum_micro(grad(L_micro)), then optim_step divides by M → effective = mean_micro(grad(L_micro)). For KL/entropy (mean per micro), this gives grad(kl_mean).
New code per worker: grad_accumulated = sum_micro(grad(L_micro * dp_size)), no division → effective after FSDP avg = sum_micro(grad(L_micro)). For KL/entropy, this gives M * grad(kl_mean), i.e. M times the old value.
Prompt for agents
In skyrl/backends/skyrl_train/workers/worker.py, the RL path in _forward_backward_micro (around lines 857-885) computes unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term, then multiplies by grad_sum_correction_factor. The problem is that policy_loss is a sum (pre-scaled by advantage normalization to produce correct gradients when summed across micro-batches), but kl_loss_term and entropy_loss_term are per-micro-batch means that should be averaged across micro-batches, not summed. Since optim_step no longer divides by micro_batches_accumulated, the KL and entropy terms are effectively multiplied by num_micro_batches.
Fix: Scale kl_loss_term and entropy_loss_term by (1 / num_micro_batches) before adding them to the loss. You can pass the number of micro-batches into _forward_backward_micro, or compute it from the data batch size and micro_batch_size. Alternatively, separate the policy loss backward from the auxiliary loss backward, applying grad_sum_correction_factor only to the policy loss and a different factor to the auxiliary terms.
The same fix needs to be applied in the Megatron worker at skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py lines 356-357, where the same pattern exists with grad_sum_correction_factor = num_microbatches * dp_size.
Was this helpful? React with 👍 or 👎 to provide feedback.
| unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term | ||
| loss = unscaled_loss * grad_sum_correction_factor |
There was a problem hiding this comment.
🔴 KL loss and entropy loss gradients amplified by num_microbatches in Megatron model wrapper
Same issue as the FSDP worker but in the Megatron path. grad_sum_correction_factor = num_microbatches * dp_size (megatron_model_wrapper.py:262) is designed to cancel Megatron's internal loss division by num_microbatches and DP averaging, recovering a sum reduction for the policy loss. But kl_loss_term and entropy_loss_term (lines 338-341) are per-micro-batch means. Megatron divides them by num_microbatches (producing correct averaging), but then the grad_sum_correction_factor multiplies back by num_microbatches * dp_size. After DP averaging, the net effect on KL/entropy is a factor of num_microbatches amplification compared to the intended behavior.
Prompt for agents
In skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py, the loss_func closure computes unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term (line 356) and then multiplies by grad_sum_correction_factor (line 357). The grad_sum_correction_factor = num_microbatches * dp_size is correct for the policy_loss (which is a pre-scaled sum), but kl_loss_term and entropy_loss_term are per-micro-batch means. Megatron internally divides by num_microbatches, then the correction factor multiplies by num_microbatches * dp_size, and DP averages by 1/dp_size. The net result is that KL and entropy are num_microbatches times larger than intended.
Fix: Either (a) divide kl_loss_term and entropy_loss_term by num_microbatches before adding them to the loss (so that after Megatron's division they become kl/num_micro^2 and the correction factor brings them back to kl), or (b) apply the correction factor only to policy_loss and use a separate factor (dp_size only) for the auxiliary terms.
Was this helpful? React with 👍 or 👎 to provide feedback.
This is a breaking change for the default
token_meanloss behavior, as well as observedpolicy_lossmetrics! See the "Differences in reported loss metric magnitudes" section below.Summary
reduce_loss()to always returns a simple masked sum ((loss * mask).sum()). To achieve different reduction strategies, we pre-scale the advantages before they enter the loss function, which also aligns with how Tinker's API handles it.backward()to counteract the default data parallel mean gradient all-reduce across workers to do a sum instead.token_meanloss reduction method to take a mean across all tokens in the minibatch rather than averaging across microbatches. Allows running with the old loss reduction with thetoken_mean_legacyconfig.Loss reduction strategies
Option 1: token_mean
Option 1b: token_mean_legacy
token_meanbehavior before this PR.Option 2: sequence_mean
Option 3: seq_mean_token_sum_norm
Mean all-reduce -> sum all-reduce
We need the loss to be summed across microbatches and data parallel workers:
Difference in reported loss metric magnitudes
You will observe that the loss metric reported has a different magnitude compared to your older runs. This is beacuse the old token_mean implementation was somewhere between a true token mean and a sequence mean due to per-micro-batch normalization (ex: micro_batch_size=1 was equivalent to sequence mean).
The new
token_meanis a proper minibatch token mean, whilesequence_meanweights every sequence equally regardless of length. When comparing the loss produced by different reduction methods computed on the same advantages, from a real run:The old token_mean gave each micro-batch equal weight rather than each token, so its scale depended on how advantages were distributed across micro-batches. The new implementation is invariant to micro-batch size.
Note that
token_mean_legacyreports the old metrics still, and thesequence_meanandseq_mean_token_sum_normmodes also match exactly. See this comment for more details.Tinker compatibility
Here was the first attempt at fixing the loss reduction across microbatches: #909
This method was to track total tokens and then do one big normalization at the
optim_stepin order to get an average per-token loss. But, we decided to align with Tinker's way of just summing up the loss at the end, and pushing any loss normalization to the user's advantage calculation.The benefit is that users have full control of customizing their loss reduction strategy, rather than having it happen in our opaque
forward_backward,optim_stepimplementation which would require some configuration argument that diverges from tinker's API. For example, we would need to add a config somewhere to determine how to average/sum the loss:The current PR aligns with Tinker semantics:
Example for
loss_reduction="token_mean":1/num_minibatch_tokensnormalization into the advantage:loss = sum( -advantage_i * ratio_i for i in range(num_minibatch_tokens) ) / num_minibatch_tokenssum( -(advantage_i / num_minibatch_tokens) * ratio_i for i in range(num_minibatch_tokens) )Learning curve comparisons before/after the PR
FSDP (wandb)
Megatron (wandb)
1.7B:
30B lora:
master baseline:

token_mean_legacy+ fixedtoken_mean: