Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
589c150
Move loss reduction normalization to trainer-level advantage scaling,…
justinvyu Mar 9, 2026
333f31a
Add token_mean_baseline loss reduction for mean-of-microbatch-means c…
justinvyu Mar 9, 2026
aaaba4c
fix assertion
justinvyu Mar 9, 2026
a121360
Update tests for sum-based reduce_loss and dp_size scaling changes
justinvyu Mar 10, 2026
15de89a
Merge remote-tracking branch 'upstream/main' into token_mean_loss_red…
justinvyu Mar 17, 2026
e3842c3
lint
justinvyu Mar 17, 2026
13bfe80
fix tests
justinvyu Mar 17, 2026
e76bece
Refactor advantage normalization: fix z-score propagation, skip for c…
justinvyu Mar 20, 2026
0192e8e
token_mean_baseline -> token_mean_legacy
justinvyu Mar 20, 2026
4ee0b31
Extract apply_loss_reduction_to_advantages_minibatch to ppo_utils and…
justinvyu Mar 20, 2026
c8f06cc
Fix metric reporting: remove dp_size scaling, separate micro-batch vs…
justinvyu Mar 25, 2026
2c13315
Fix critic metric reporting: explicit sum_loss_metrics flag for reduc…
justinvyu Mar 27, 2026
14ba02e
Remove reduce_metrics_across_minibatches, reuse reduce_metrics
justinvyu Mar 27, 2026
0cfc95b
Merge remote-tracking branch 'upstream/main' into token_mean_loss_red…
justinvyu Mar 27, 2026
717c3a7
add some comments about sum metrics
justinvyu Mar 27, 2026
661f5d8
add clarifying comments and rename loss_scale
justinvyu Mar 27, 2026
5cc95a1
no_grad for safety and make private
justinvyu Mar 27, 2026
ce8f6aa
remove outdated comments about loss reduction type in sapo tests
justinvyu Mar 27, 2026
1a60bb5
fix test
justinvyu Mar 27, 2026
c5feb83
fix test
justinvyu Mar 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions examples/train/async/async_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from skyrl.train.trainer import RayPPOTrainer
from tqdm import tqdm
from skyrl.train.utils import Timer
from skyrl.backends.skyrl_train.utils.ppo_utils import normalize_advantages_dict
from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch
from skyrl.train.generators.base import GeneratorOutput
from skyrl.train.utils.trainer_utils import ResumeMode
Expand Down Expand Up @@ -146,9 +145,6 @@ async def _run_training(self, generation_buffer):
training_input.pop(key)
training_input.metadata.pop("uids")

if self.cfg.trainer.algorithm.advantage_batch_normalize:
training_input = normalize_advantages_dict(training_input)

if self.cfg.trainer.dump_data_batch:
# dump data to file
with Timer("dump_data_batch"):
Expand Down
4 changes: 0 additions & 4 deletions skyrl-agent/skyrl_agent/integrations/skyrl_train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from skyrl.train.utils import Timer
from skyrl.backends.skyrl_train.utils.ppo_utils import (
get_kl_controller,
normalize_advantages_dict,
)
from skyrl.train.utils.trainer_utils import (
validate_generator_output,
Expand Down Expand Up @@ -382,9 +381,6 @@ async def train(self):
training_input.pop(key)
training_input.metadata.pop("uids")

if self.cfg.trainer.algorithm.advantage_batch_normalize:
training_input = normalize_advantages_dict(training_input)

if self.cfg.trainer.dump_data_batch:
# dump data to file
with Timer("dump_data_batch"):
Expand Down
17 changes: 9 additions & 8 deletions skyrl/backends/skyrl_train/distributed/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ def get_rank(self) -> int:
"""Get current process rank"""
return dist.get_rank()

def all_reduce(self, data: DataT, op="mean") -> DataT:
"""Perform all_reduce across all processes"""
def all_reduce(self, data: DataT, op="mean", group=None) -> DataT:
"""Perform all_reduce across all processes (or within a process group)."""
assert op in ("mean", "max", "sum", "min")
if isinstance(data, dict):
return {k: self.all_reduce(v, op) for k, v in data.items()}
return {k: self.all_reduce(v, op, group=group) for k, v in data.items()}
else:
is_tensor = True
if not isinstance(data, torch.Tensor):
Expand All @@ -82,14 +82,15 @@ def all_reduce(self, data: DataT, op="mean") -> DataT:
if is_cpu_tensor:
data = data.to(torch.cuda.current_device())
if op == "mean":
data /= self.world_size
dist.all_reduce(data, op=dist.ReduceOp.SUM)
group_size = dist.get_world_size(group) if group is not None else self.world_size
data /= group_size
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=group)
elif op == "max":
dist.all_reduce(data, op=dist.ReduceOp.MAX)
dist.all_reduce(data, op=dist.ReduceOp.MAX, group=group)
elif op == "min":
dist.all_reduce(data, op=dist.ReduceOp.MIN)
dist.all_reduce(data, op=dist.ReduceOp.MIN, group=group)
elif op == "sum":
dist.all_reduce(data, op=dist.ReduceOp.SUM)
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=group)
if is_cpu_tensor:
data = data.cpu()
return data.item() if not is_tensor else data
Expand Down
133 changes: 57 additions & 76 deletions skyrl/backends/skyrl_train/utils/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@
from collections import defaultdict
from enum import StrEnum
from functools import wraps
from typing import Callable, List, Literal, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import ray
import torch
from jaxtyping import Float
from loguru import logger

from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch
from skyrl.backends.skyrl_train.utils.off_policy_correction_utils import (
apply_off_policy_correction,
)
Expand Down Expand Up @@ -124,27 +123,6 @@ def compute_approx_kl(
return kld


@torch.no_grad()
def normalize_advantages_dict(data: TrainingInputBatch) -> TrainingInputBatch:
"""Normalizes the advantages in the data batch.

Expects:
- `["advantages"]`: Float[torch.Tensor, "batch_size seqlen"]
- `["response_mask"]`: Float[torch.Tensor, "batch_size seqlen"]
"""
advantages: Float[torch.Tensor, "batch_size seqlen"] = data["advantages"]
response_masks: Float[torch.Tensor, "batch_size seqlen"] = data["response_mask"]
num_actions: float = response_masks.sum()
# mean
mean: float = advantages.mean()
# std
std: float = ((advantages - mean).pow(2) * response_masks).sum()
rstd: float = (std / num_actions).clamp(min=1e-8).rsqrt()

data["advantages"] = (advantages - mean) * rstd
return data


def masked_var(values, mask, unbiased=True):
"""Compute variance of tensor with masked values."""
mean = masked_mean(values, mask)
Expand Down Expand Up @@ -558,12 +536,6 @@ def ppo_policy_loss(
rollout_logprobs: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, dict[str, float]]:
assert config.policy_loss_type in ["regular", "dual_clip"], "loss_type must be either 'regular' or 'dual_clip'"
loss_reduction = config.loss_reduction
assert loss_reduction in [
"token_mean",
"sequence_mean",
"seq_mean_token_sum_norm",
], "loss_reduction must be either 'token_mean', 'sequence_mean', or 'seq_mean_token_sum_norm'"

ratio = safe_exp_delta(log_probs - old_log_probs, clip=20.0, out_dtype=log_probs.dtype)
surr1 = ratio * advantages
Expand All @@ -584,7 +556,7 @@ def ppo_policy_loss(
)
loss_metrics.update(off_policy_metrics)

loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len)
loss = reduce_loss(loss, loss_mask)
return loss, loss_metrics


Expand Down Expand Up @@ -656,8 +628,7 @@ def gate_function(x, tau):
)
loss_metrics.update(off_policy_metrics)

# for SAPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean)
loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len)
loss = reduce_loss(loss, loss_mask)

return loss, loss_metrics

Expand Down Expand Up @@ -726,7 +697,7 @@ def gspo_policy_loss(
)
loss_metrics.update(off_policy_metrics)

loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len)
loss = reduce_loss(loss, loss_mask)

return loss, loss_metrics

Expand Down Expand Up @@ -763,7 +734,7 @@ def compute_policy_loss_cispo(
)
loss_metrics.update(off_policy_metrics)

loss = reduce_loss(loss, loss_mask, config.loss_reduction, config.max_seq_len)
loss = reduce_loss(loss, loss_mask)
return loss, loss_metrics


Expand Down Expand Up @@ -791,13 +762,6 @@ def rollout_is_policy_loss(
"""
assert rollout_logprobs is not None, "rollout_logprobs are required for rollout_is"

loss_reduction = config.loss_reduction
assert loss_reduction in [
"token_mean",
"sequence_mean",
"seq_mean_token_sum_norm",
], "loss_reduction must be either 'token_mean', 'sequence_mean', or 'seq_mean_token_sum_norm'"

ratio = safe_exp_delta(log_probs - rollout_logprobs, clip=20.0, out_dtype=log_probs.dtype)

in_range = (ratio > 1 - config.eps_clip_low) & (ratio < 1 + config.eps_clip_high)
Expand All @@ -812,7 +776,7 @@ def rollout_is_policy_loss(
)
loss_metrics.update(off_policy_metrics)

loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len)
loss = reduce_loss(loss, loss_mask)
return loss, loss_metrics


Expand Down Expand Up @@ -874,12 +838,7 @@ def compute_policy_loss_clip_cov(
# Apply correction mask to losses
pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr

pg_loss = reduce_loss(
loss=pg_losses,
loss_mask=loss_mask,
loss_reduction=config.loss_reduction,
max_seq_len=config.max_seq_len,
)
pg_loss = reduce_loss(loss=pg_losses, loss_mask=loss_mask)

return pg_loss, {"clip_ratio": clip_frac.item()}

Expand Down Expand Up @@ -933,12 +892,7 @@ def compute_policy_loss_kl_cov(
large_cov_idxs % advantages.shape[1],
]

pg_loss = reduce_loss(
loss=pg_losses,
loss_mask=loss_mask,
loss_reduction=config.loss_reduction,
max_seq_len=config.max_seq_len,
)
pg_loss = reduce_loss(loss=pg_losses, loss_mask=loss_mask)

# NOTE (sumanthrh): Since the pg clip ratio is not applicable for KL-COV so we just use 0.0
return pg_loss, {"clip_ratio": 0.0}
Expand Down Expand Up @@ -977,10 +931,7 @@ def cross_entropy_loss(
elementwise_loss = -log_probs

# Apply loss mask and sum (matching Tinker's SUM reduction semantics)
if loss_mask is not None:
loss = (elementwise_loss * loss_mask).sum()
else:
loss = elementwise_loss.sum()
loss = reduce_loss(elementwise_loss, loss_mask)

# No clipping in cross-entropy loss
return loss, {"clip_ratio": 0.0}
Expand Down Expand Up @@ -1039,30 +990,60 @@ def importance_sampling_loss(
def reduce_loss(
loss: torch.Tensor,
loss_mask: Optional[torch.Tensor],
loss_reduction: Literal["token_mean", "sequence_mean", "seq_mean_token_sum_norm"],
max_seq_len: Optional[int] = None,
) -> torch.Tensor:
return (loss * loss_mask).sum() if loss_mask is not None else loss.sum()


def apply_loss_reduction_to_advantages_minibatch(
advantages: torch.Tensor,
loss_mask: torch.Tensor,
loss_reduction: str,
micro_batch_size: int,
max_seq_len: int,
) -> torch.Tensor:
"""Scale advantages so that summing produces the desired loss reduction.

Args:
advantages: Advantage tensor of shape (minibatch_size, seq_len).
loss_mask: Mask of shape (minibatch_size, seq_len) indicating valid loss tokens.
loss_reduction: One of "token_mean", "token_mean_legacy", "sequence_mean", "seq_mean_token_sum_norm".
micro_batch_size: Number of sequences per micro-batch
max_seq_len: Maximum sequence length.

Returns:
Scaled advantages tensor.
"""
batch_size = advantages.shape[0]
normalized_advantages = torch.zeros_like(advantages)

# Option 1: token mean
if loss_reduction == "token_mean":
# sum over *all* valid tokens, divide by total valid-token count
loss = masked_mean(loss, loss_mask)
normalized_advantages = advantages / loss_mask.sum().clamp(min=1)

# Option 1b: legacy token-mean that normalizes per-microbatch then averages across microbatches.
elif loss_reduction == "token_mean_legacy":
num_micro_batches = batch_size // micro_batch_size
for i in range(num_micro_batches):
start_idx = i * micro_batch_size
end_idx = (i + 1) * micro_batch_size
mb_advantages = advantages[start_idx:end_idx]
mb_loss_mask = loss_mask[start_idx:end_idx]
mb_advantages = mb_advantages / mb_loss_mask.sum().clamp(min=1)
mb_advantages /= num_micro_batches
normalized_advantages[start_idx:end_idx] = mb_advantages

# Option 2: sequence mean
elif loss_reduction == "sequence_mean":
# per-sequence token-mean (dim=-1), then batch-mean
loss = masked_mean(loss, loss_mask, dim=-1).mean()
normalized_advantages = advantages / (batch_size * loss_mask.sum(dim=-1, keepdim=True).clamp(min=1))

# Option 3: Dr. GRPO style loss reduction to avoid length bias by normalizing by a constant
elif loss_reduction == "seq_mean_token_sum_norm":
# per-sequence token-sum, normalized by the max sequence length, then batch mean
# this is the Dr. GRPO loss reduction to avoid length bias by normalizing by a constant
assert max_seq_len is not None, "max_seq_len must be provided for seq_mean_token_sum_norm loss reduction"
# NOTE: max_seq_len can be set explicitly via algorithm.max_seq_len, otherwise defaults to
# cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length
if loss_mask is not None:
seq_losses = torch.sum(loss * loss_mask, dim=-1) / max_seq_len
else:
# If no mask, assume all tokens are valid
seq_losses = torch.sum(loss, dim=-1) / max_seq_len
loss = torch.mean(seq_losses)
normalized_advantages = advantages / (batch_size * max_seq_len)

else:
raise ValueError(f"Invalid loss reduction type: {loss_reduction}")
return loss

return normalized_advantages


# NOTE (erictang000): below ported from verl
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,21 @@ def loss_func(logits, data):
loss_mask = data["loss_mask"]
rollout_action_logprobs = data["rollout_action_logprobs"]
action_mask = data.get("action_mask")
num_microbatches = data.get("num_microbatches")

dp_size = mpu.get_data_parallel_world_size()
tp_grp = mpu.get_tensor_model_parallel_group()
tp_rank = mpu.get_tensor_model_parallel_rank()

# Policy losses are pre-scaled to achieve the correct loss_reduction when summing across the entire minibatch
# (see `apply_loss_reduction_to_advantages_minibatch`).
# Megatron divides loss by num_microbatches
# (https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/pipeline_parallel/schedules.py#L248)
# and the data parallel all-reduce averages gradients across dp_size
# (https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/distributed/distributed_data_parallel.py#L285)
# so we multiply by both factors to recover the correct sum reduction.
grad_sum_correction_factor = num_microbatches * dp_size

# temperature normalization
if temperature != 1.0:
logits.div_(temperature)
Expand Down Expand Up @@ -279,13 +290,15 @@ def loss_func(logits, data):

# SFT path: cross_entropy loss (negative log likelihood)
if resolved_loss_name == "cross_entropy":
loss = policy_loss
unscaled_loss = policy_loss
loss = unscaled_loss * grad_sum_correction_factor
Comment on lines 292 to +294
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: should this affect the SFT case? Like the critic loss case, SFT does not read the normalized advantages.

Before this PR, the SFT case summed the negative log likelihoods within each microbatch, but still averaged over microbatches and DP workers. Now we sum the negative log likelihoods across the entire minibatch. What's the desired behavior here?


# Compute elementwise loss for Tinker API (per-token NLL)
with torch.no_grad():
elementwise_loss = -action_log_probs
if loss_mask is not None:
elementwise_loss = elementwise_loss * loss_mask
elementwise_loss = elementwise_loss * grad_sum_correction_factor

# Build per-sequence loss_fn_outputs
batch_size = action_log_probs.shape[0]
Expand All @@ -310,7 +323,7 @@ def loss_func(logits, data):
)

metrics = {
"loss": loss.detach().item(),
"loss": unscaled_loss.detach().item(),
"response_length": num_actions,
"loss_fn_outputs": loss_fn_outputs,
}
Expand Down Expand Up @@ -340,7 +353,8 @@ def loss_func(logits, data):
kl_loss = torch.tensor(0.0)
kl_loss_term = kl_loss * loss_config.kl_loss_coef

loss = policy_loss + kl_loss_term - entropy_loss_term
unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term
loss = unscaled_loss * grad_sum_correction_factor
Comment on lines +356 to +357
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 KL loss and entropy loss gradients amplified by num_microbatches in Megatron model wrapper

Same issue as the FSDP worker but in the Megatron path. grad_sum_correction_factor = num_microbatches * dp_size (megatron_model_wrapper.py:262) is designed to cancel Megatron's internal loss division by num_microbatches and DP averaging, recovering a sum reduction for the policy loss. But kl_loss_term and entropy_loss_term (lines 338-341) are per-micro-batch means. Megatron divides them by num_microbatches (producing correct averaging), but then the grad_sum_correction_factor multiplies back by num_microbatches * dp_size. After DP averaging, the net effect on KL/entropy is a factor of num_microbatches amplification compared to the intended behavior.

Prompt for agents
In skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py, the loss_func closure computes unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term (line 356) and then multiplies by grad_sum_correction_factor (line 357). The grad_sum_correction_factor = num_microbatches * dp_size is correct for the policy_loss (which is a pre-scaled sum), but kl_loss_term and entropy_loss_term are per-micro-batch means. Megatron internally divides by num_microbatches, then the correction factor multiplies by num_microbatches * dp_size, and DP averages by 1/dp_size. The net result is that KL and entropy are num_microbatches times larger than intended.

Fix: Either (a) divide kl_loss_term and entropy_loss_term by num_microbatches before adding them to the loss (so that after Megatron's division they become kl/num_micro^2 and the correction factor brings them back to kl), or (b) apply the correction factor only to policy_loss and use a separate factor (dp_size only) for the auxiliary terms.
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.


# Build per-sequence loss_fn_outputs with logprobs.
batch_size = action_log_probs.shape[0]
Expand All @@ -363,7 +377,7 @@ def loss_func(logits, data):
)

metrics = {
"final_loss": loss.detach().item(),
"final_loss": unscaled_loss.detach().item(),
"policy_loss": policy_loss.detach().item(),
Comment on lines +380 to +381
Copy link
Copy Markdown
Contributor Author

@justinvyu justinvyu Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Metrics fix 1: remove the dp_size multiplier in reported metrics — there is no average to correct for, because reduce_microbatch_metrics and all_reduce_metrics both sum the *_loss metrics.

"policy_entropy": entropy.detach().item(),
"policy_kl": kl_loss.detach().item(),
Expand Down
Loading
Loading