From 589c1501ae09552d6158d8808cd6292b71c4ffbc Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Mon, 9 Mar 2026 11:51:20 -0700 Subject: [PATCH 01/18] Move loss reduction normalization to trainer-level advantage scaling, scale loss by dp_size for FSDP/Megatron parity Co-Authored-By: Claude Opus 4.6 --- examples/train/async/async_trainer.py | 4 - .../integrations/skyrl_train/trainer.py | 4 - .../skyrl_train/distributed/strategy.py | 17 ++-- skyrl/backends/skyrl_train/utils/ppo_utils.py | 81 +++---------------- .../megatron/megatron_model_wrapper.py | 22 +++-- .../workers/megatron/megatron_worker.py | 12 +-- skyrl/backends/skyrl_train/workers/worker.py | 37 ++++----- .../skyrl_train/workers/worker_utils.py | 15 ++-- skyrl/train/fully_async_trainer.py | 4 - skyrl/train/trainer.py | 62 ++++++++++++-- .../skyrl_train/gpu/test_grpo_sp_sanity.py | 4 - 11 files changed, 123 insertions(+), 139 deletions(-) diff --git a/examples/train/async/async_trainer.py b/examples/train/async/async_trainer.py index 9f0dc7c063..a80edd730f 100644 --- a/examples/train/async/async_trainer.py +++ b/examples/train/async/async_trainer.py @@ -5,7 +5,6 @@ from skyrl.train.trainer import RayPPOTrainer from tqdm import tqdm from skyrl.train.utils import Timer -from skyrl.backends.skyrl_train.utils.ppo_utils import normalize_advantages_dict from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch from skyrl.train.generators.base import GeneratorOutput from skyrl.train.utils.trainer_utils import ResumeMode @@ -146,9 +145,6 @@ async def _run_training(self, generation_buffer): training_input.pop(key) training_input.metadata.pop("uids") - if self.cfg.trainer.algorithm.advantage_batch_normalize: - training_input = normalize_advantages_dict(training_input) - if self.cfg.trainer.dump_data_batch: # dump data to file with Timer("dump_data_batch"): diff --git a/skyrl-agent/skyrl_agent/integrations/skyrl_train/trainer.py b/skyrl-agent/skyrl_agent/integrations/skyrl_train/trainer.py index 
43112dcc10..39253ccc0f 100644 --- a/skyrl-agent/skyrl_agent/integrations/skyrl_train/trainer.py +++ b/skyrl-agent/skyrl_agent/integrations/skyrl_train/trainer.py @@ -39,7 +39,6 @@ from skyrl.train.utils import Timer from skyrl.backends.skyrl_train.utils.ppo_utils import ( get_kl_controller, - normalize_advantages_dict, ) from skyrl.train.utils.trainer_utils import ( validate_generator_output, @@ -381,9 +380,6 @@ async def train(self): training_input.pop(key) training_input.metadata.pop("uids") - if self.cfg.trainer.algorithm.advantage_batch_normalize: - training_input = normalize_advantages_dict(training_input) - if self.cfg.trainer.dump_data_batch: # dump data to file with Timer("dump_data_batch"): diff --git a/skyrl/backends/skyrl_train/distributed/strategy.py b/skyrl/backends/skyrl_train/distributed/strategy.py index ce41c113a6..a8bc97cab7 100644 --- a/skyrl/backends/skyrl_train/distributed/strategy.py +++ b/skyrl/backends/skyrl_train/distributed/strategy.py @@ -67,11 +67,11 @@ def get_rank(self) -> int: """Get current process rank""" return dist.get_rank() - def all_reduce(self, data: DataT, op="mean") -> DataT: - """Perform all_reduce across all processes""" + def all_reduce(self, data: DataT, op="mean", group=None) -> DataT: + """Perform all_reduce across all processes (or within a process group).""" assert op in ("mean", "max", "sum", "min") if isinstance(data, dict): - return {k: self.all_reduce(v, op) for k, v in data.items()} + return {k: self.all_reduce(v, op, group=group) for k, v in data.items()} else: is_tensor = True if not isinstance(data, torch.Tensor): @@ -82,14 +82,15 @@ def all_reduce(self, data: DataT, op="mean") -> DataT: if is_cpu_tensor: data = data.to(torch.cuda.current_device()) if op == "mean": - data /= self.world_size - dist.all_reduce(data, op=dist.ReduceOp.SUM) + group_size = dist.get_world_size(group) if group is not None else self.world_size + data /= group_size + dist.all_reduce(data, op=dist.ReduceOp.SUM, group=group) elif op == 
"max": - dist.all_reduce(data, op=dist.ReduceOp.MAX) + dist.all_reduce(data, op=dist.ReduceOp.MAX, group=group) elif op == "min": - dist.all_reduce(data, op=dist.ReduceOp.MIN) + dist.all_reduce(data, op=dist.ReduceOp.MIN, group=group) elif op == "sum": - dist.all_reduce(data, op=dist.ReduceOp.SUM) + dist.all_reduce(data, op=dist.ReduceOp.SUM, group=group) if is_cpu_tensor: data = data.cpu() return data.item() if not is_tensor else data diff --git a/skyrl/backends/skyrl_train/utils/ppo_utils.py b/skyrl/backends/skyrl_train/utils/ppo_utils.py index 1f9fe4e469..db82cbe4f5 100644 --- a/skyrl/backends/skyrl_train/utils/ppo_utils.py +++ b/skyrl/backends/skyrl_train/utils/ppo_utils.py @@ -19,7 +19,7 @@ from collections import defaultdict from enum import StrEnum from functools import wraps -from typing import Callable, List, Literal, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import ray @@ -28,7 +28,6 @@ from loguru import logger from skyrl.train.config import AlgorithmConfig -from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch from skyrl.backends.skyrl_train.utils.off_policy_correction_utils import apply_off_policy_correction from skyrl.backends.skyrl_train.utils.torch_utils import masked_mean, safe_exp_delta @@ -123,27 +122,6 @@ def compute_approx_kl( return kld -@torch.no_grad() -def normalize_advantages_dict(data: TrainingInputBatch) -> TrainingInputBatch: - """Normalizes the advantages in the data batch. 
- - Expects: - - `["advantages"]`: Float[torch.Tensor, "batch_size seqlen"] - - `["response_mask"]`: Float[torch.Tensor, "batch_size seqlen"] - """ - advantages: Float[torch.Tensor, "batch_size seqlen"] = data["advantages"] - response_masks: Float[torch.Tensor, "batch_size seqlen"] = data["response_mask"] - num_actions: float = response_masks.sum() - # mean - mean: float = advantages.mean() - # std - std: float = ((advantages - mean).pow(2) * response_masks).sum() - rstd: float = (std / num_actions).clamp(min=1e-8).rsqrt() - - data["advantages"] = (advantages - mean) * rstd - return data - - def masked_var(values, mask, unbiased=True): """Compute variance of tensor with masked values.""" mean = masked_mean(values, mask) @@ -555,12 +533,6 @@ def ppo_policy_loss( rollout_logprobs: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, dict[str, float]]: assert config.policy_loss_type in ["regular", "dual_clip"], "loss_type must be either 'regular' or 'dual_clip'" - loss_reduction = config.loss_reduction - assert loss_reduction in [ - "token_mean", - "sequence_mean", - "seq_mean_token_sum_norm", - ], "loss_reduction must be either 'token_mean', 'sequence_mean', or 'seq_mean_token_sum_norm'" ratio = safe_exp_delta(log_probs - old_log_probs, clip=20.0, out_dtype=log_probs.dtype) surr1 = ratio * advantages @@ -581,7 +553,7 @@ def ppo_policy_loss( ) loss_metrics.update(off_policy_metrics) - loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) + loss = reduce_loss(loss, loss_mask) return loss, loss_metrics @@ -652,7 +624,7 @@ def gate_function(x, tau): loss_metrics.update(off_policy_metrics) # for SAPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean) - loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) + loss = reduce_loss(loss, loss_mask) return loss, loss_metrics @@ -719,7 +691,7 @@ def gspo_policy_loss( ) loss_metrics.update(off_policy_metrics) - loss = reduce_loss(loss, loss_mask, loss_reduction, 
config.max_seq_len) + loss = reduce_loss(loss, loss_mask) return loss, loss_metrics @@ -756,7 +728,7 @@ def compute_policy_loss_cispo( ) loss_metrics.update(off_policy_metrics) - loss = reduce_loss(loss, loss_mask, config.loss_reduction, config.max_seq_len) + loss = reduce_loss(loss, loss_mask) return loss, loss_metrics @@ -818,12 +790,7 @@ def compute_policy_loss_clip_cov( # Apply correction mask to losses pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr - pg_loss = reduce_loss( - loss=pg_losses, - loss_mask=loss_mask, - loss_reduction=config.loss_reduction, - max_seq_len=config.max_seq_len, - ) + pg_loss = reduce_loss(loss=pg_losses, loss_mask=loss_mask) return pg_loss, {"clip_ratio": clip_frac.item()} @@ -877,12 +844,7 @@ def compute_policy_loss_kl_cov( large_cov_idxs % advantages.shape[1], ] - pg_loss = reduce_loss( - loss=pg_losses, - loss_mask=loss_mask, - loss_reduction=config.loss_reduction, - max_seq_len=config.max_seq_len, - ) + pg_loss = reduce_loss(loss=pg_losses, loss_mask=loss_mask) # NOTE (sumanthrh): Since the pg clip ratio is not applicable for KL-COV so we just use 0.0 return pg_loss, {"clip_ratio": 0.0} @@ -921,10 +883,7 @@ def cross_entropy_loss( elementwise_loss = -log_probs # Apply loss mask and sum (matching Tinker's SUM reduction semantics) - if loss_mask is not None: - loss = (elementwise_loss * loss_mask).sum() - else: - loss = elementwise_loss.sum() + loss = reduce_loss(elementwise_loss, loss_mask) # No clipping in cross-entropy loss return loss, {"clip_ratio": 0.0} @@ -983,30 +942,8 @@ def importance_sampling_loss( def reduce_loss( loss: torch.Tensor, loss_mask: Optional[torch.Tensor], - loss_reduction: Literal["token_mean", "sequence_mean", "seq_mean_token_sum_norm"], - max_seq_len: Optional[int] = None, ) -> torch.Tensor: - if loss_reduction == "token_mean": - # sum over *all* valid tokens, divide by total valid-token count - loss = masked_mean(loss, loss_mask) - elif loss_reduction == "sequence_mean": - # per-sequence 
token-mean (dim=-1), then batch-mean - loss = masked_mean(loss, loss_mask, dim=-1).mean() - elif loss_reduction == "seq_mean_token_sum_norm": - # per-sequence token-sum, normalized by the max sequence length, then batch mean - # this is the Dr. GRPO loss reduction to avoid length bias by normalizing by a constant - assert max_seq_len is not None, "max_seq_len must be provided for seq_mean_token_sum_norm loss reduction" - # NOTE: max_seq_len can be set explicitly via algorithm.max_seq_len, otherwise defaults to - # cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length - if loss_mask is not None: - seq_losses = torch.sum(loss * loss_mask, dim=-1) / max_seq_len - else: - # If no mask, assume all tokens are valid - seq_losses = torch.sum(loss, dim=-1) / max_seq_len - loss = torch.mean(seq_losses) - else: - raise ValueError(f"Invalid loss reduction type: {loss_reduction}") - return loss + return (loss * loss_mask).sum() if loss_mask is not None else loss.sum() # NOTE (erictang000): below ported from verl diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index 5ab565f989..40004dc427 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -229,10 +229,19 @@ def loss_func(logits, data): loss_mask = data["loss_mask"] rollout_action_logprobs = data["rollout_action_logprobs"] action_mask = data.get("action_mask") + num_microbatches = data.get("num_microbatches") + dp_size = mpu.get_data_parallel_world_size() tp_grp = mpu.get_tensor_model_parallel_group() tp_rank = mpu.get_tensor_model_parallel_rank() + # Megatron's pipeline parallel forward_backward_func internally divides loss by num_microbatches + # https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/pipeline_parallel/schedules.py#L248 + # we want to maintain a sum of 
losses across all micro batches, so we reverse this division. + # we additionally multiply by the data parallelism size to undo the DDP all-reduce mean + # https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/distributed/distributed_data_parallel.py#L285 + loss_scale = num_microbatches * dp_size + # temperature normalization if temperature != 1.0: logits.div_(temperature) @@ -262,13 +271,15 @@ def loss_func(logits, data): # SFT path: cross_entropy loss (negative log likelihood) if resolved_loss_name == "cross_entropy": - loss = policy_loss + unscaled_loss = policy_loss + loss = unscaled_loss * loss_scale # Compute elementwise loss for Tinker API (per-token NLL) with torch.no_grad(): elementwise_loss = -action_log_probs if loss_mask is not None: elementwise_loss = elementwise_loss * loss_mask + elementwise_loss = elementwise_loss * loss_scale # Build per-sequence loss_fn_outputs batch_size = action_log_probs.shape[0] @@ -289,7 +300,7 @@ def loss_func(logits, data): ) metrics = { - "loss": loss.detach().item(), + "loss": unscaled_loss.detach().item(), "response_length": num_actions, "loss_fn_outputs": loss_fn_outputs, } @@ -319,7 +330,8 @@ def loss_func(logits, data): kl_loss = torch.tensor(0.0) kl_loss_term = kl_loss * loss_config.kl_loss_coef - loss = policy_loss + kl_loss_term - entropy_loss_term + unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term + loss = unscaled_loss * loss_scale # Build per-sequence loss_fn_outputs with logprobs. 
batch_size = action_log_probs.shape[0] @@ -342,8 +354,8 @@ def loss_func(logits, data): ) metrics = { - "final_loss": loss.detach().item(), - "policy_loss": policy_loss.detach().item(), + "final_loss": unscaled_loss.detach().item() * dp_size, + "policy_loss": policy_loss.detach().item() * dp_size, "policy_entropy": entropy.detach().item(), "policy_kl": kl_loss.detach().item(), "loss_fn_outputs": loss_fn_outputs, diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 97caa4b60f..25bfb217d2 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -655,6 +655,9 @@ def forward_backward( } ) + for m_batch in micro_buffer: + m_batch["num_microbatches"] = len(micro_buffer) + if not micro_buffer: return {} @@ -673,9 +676,6 @@ def forward_backward( if self.empty_cuda_cache: torch.cuda.empty_cache() - # Track number of micro-batches for metrics - self._micro_batches_accumulated += len(micro_buffer) - # Aggregate metrics across micro-batches all_loss_fn_outputs = [] # Handle separately from scalar metrics for metrics in metrics_list: @@ -685,10 +685,12 @@ def forward_backward( for k, v in metrics.items(): all_metrics[k].append(v) - # Reduce and all-reduce metrics + # Reduce and all-reduce metrics across DP ranks only + # (metrics should be identical within DP groups, i.e., across TP/PP/SP ranks) status = reduce_metrics(dict(all_metrics)) status["policy_lr"] = self.optimizer.param_groups[0]["lr"] - status = all_reduce_metrics(status, self.strategy) + group = mpu.get_data_parallel_group(with_context_parallel=True) + status = all_reduce_metrics(status, self.strategy, group=group) # Add loss_fn_outputs back (not reduced, kept as list) if all_loss_fn_outputs: diff --git a/skyrl/backends/skyrl_train/workers/worker.py b/skyrl/backends/skyrl_train/workers/worker.py index 2b11983cdf..8ca884711f 100644 --- 
a/skyrl/backends/skyrl_train/workers/worker.py +++ b/skyrl/backends/skyrl_train/workers/worker.py @@ -709,8 +709,13 @@ def forward_backward( for k, v in metrics.items(): all_metrics[k].append(v) + # reduce metrics across micro batches (sum, mean, min, max) result = reduce_metrics(dict(all_metrics)) + # all reduce metrics across DP workers + dp_group = self.device_mesh.get_group("dp") + result = all_reduce_metrics(result, self.strategy, group=dp_group) + # Add back loss_fn_outputs (concatenated across micro-batches) if all_loss_fn_outputs: result["loss_fn_outputs"] = all_loss_fn_outputs @@ -824,9 +829,12 @@ def _forward_backward_micro( rollout_logprobs=rollout_action_logprobs, ) + loss_scale = self.mesh_rank.dp_size + # SFT path: skip KL/entropy terms, return per-token outputs for Tinker API if resolved_loss_name == "cross_entropy": - loss = policy_loss + unscaled_loss = policy_loss + loss = unscaled_loss * loss_scale self.strategy.backward(loss, self.model, self.optimizer) # Compute elementwise loss for Tinker API (per-token NLL) @@ -834,6 +842,7 @@ def _forward_backward_micro( elementwise_loss = -action_log_probs if loss_mask is not None: elementwise_loss = elementwise_loss * loss_mask + elementwise_loss = elementwise_loss * loss_scale # Build per-sequence loss_fn_outputs (matches Tinker's ForwardBackwardOutput structure) # Trim to actual response length per sample (Tinker expects variable-length arrays @@ -889,7 +898,8 @@ def _forward_backward_micro( kl_loss = torch.tensor(0.0) kl_loss_term = kl_loss * self.cfg.algorithm.kl_loss_coef - loss = policy_loss + kl_loss_term - entropy_loss_term + unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term + loss = unscaled_loss * loss_scale self.strategy.backward(loss, self.model, self.optimizer) # Build per-sequence loss_fn_outputs with logprobs. 
@@ -914,7 +924,7 @@ def _forward_backward_micro( status = { "final_loss": loss.item(), - "policy_loss": policy_loss.item(), + "policy_loss": policy_loss.item() * loss_scale, "policy_entropy": entropy.item(), "response_length": num_actions, "policy_lr": self.scheduler.get_last_lr()[0], @@ -925,37 +935,18 @@ def _forward_backward_micro( if self.cfg.algorithm.use_kl_loss: status["policy_kl"] = kl_loss.item() - loss_fn_outputs = status.pop("loss_fn_outputs", None) - - # All-reduce metrics across DP workers - status = all_reduce_metrics(status, self.strategy) - - # Add back loss_fn_outputs after all_reduce - if loss_fn_outputs is not None: - status["loss_fn_outputs"] = loss_fn_outputs - return status def optim_step(self) -> float: """ - Scale gradients by 1/micro_batches_accumulated, perform optimizer step, and reset counter. + Perform optimizer step. Returns: The gradient norm (before scaling, after clipping) """ - # Scale accumulated gradients by 1/N to get correct average - if self._micro_batches_accumulated > 0: - scale = 1.0 / self._micro_batches_accumulated - for param in self.model.parameters(): - if param.grad is not None: - param.grad.mul_(scale) - # Perform optimizer step (includes gradient clipping) grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="actor") - # Reset counter for next accumulation cycle - self._micro_batches_accumulated = 0 - if grad_norm is not None: grad_norm = grad_norm.detach().cpu().item() return grad_norm diff --git a/skyrl/backends/skyrl_train/workers/worker_utils.py b/skyrl/backends/skyrl_train/workers/worker_utils.py index eb76c5ee7d..bd36148665 100644 --- a/skyrl/backends/skyrl_train/workers/worker_utils.py +++ b/skyrl/backends/skyrl_train/workers/worker_utils.py @@ -24,16 +24,21 @@ def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]: return reduced_metrics -def all_reduce_metrics(metrics: Dict[str, List[float]], strategy: DistributedStrategy) -> Dict[str, float]: +def 
all_reduce_metrics(metrics: Dict[str, List[float]], strategy: DistributedStrategy, group=None) -> Dict[str, float]: """All reduce metrics across all processes.""" min_metrics = {k: v for k, v in metrics.items() if k.endswith("_min")} max_metrics = {k: v for k, v in metrics.items() if k.endswith("_max")} - mean_metrics = {k: v for k, v in metrics.items() if k not in min_metrics and k not in max_metrics} - status_mean = strategy.all_reduce(mean_metrics, op="mean") - status_min = strategy.all_reduce(min_metrics, op="min") - status_max = strategy.all_reduce(max_metrics, op="max") + sum_metrics = {k: v for k, v in metrics.items() if k.endswith("_loss")} + mean_metrics = { + k: v for k, v in metrics.items() if k not in min_metrics and k not in max_metrics and k not in sum_metrics + } + status_mean = strategy.all_reduce(mean_metrics, op="mean", group=group) + status_min = strategy.all_reduce(min_metrics, op="min", group=group) + status_max = strategy.all_reduce(max_metrics, op="max", group=group) + status_sum = strategy.all_reduce(sum_metrics, op="sum", group=group) status_mean.update(status_min) status_mean.update(status_max) + status_mean.update(status_sum) return status_mean diff --git a/skyrl/train/fully_async_trainer.py b/skyrl/train/fully_async_trainer.py index aaed1f01c3..0507c9f639 100644 --- a/skyrl/train/fully_async_trainer.py +++ b/skyrl/train/fully_async_trainer.py @@ -20,7 +20,6 @@ from skyrl.train.trainer import RayPPOTrainer from tqdm import tqdm from skyrl.train.utils import Timer -from skyrl.backends.skyrl_train.utils.ppo_utils import normalize_advantages_dict from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch from skyrl.train.generators.base import GeneratorOutput from skyrl.train.utils.trainer_utils import ResumeMode, build_dataloader @@ -512,9 +511,6 @@ async def _run_training(self, training_input: TrainingInputBatch): training_input.pop(key) training_input.metadata.pop("uids") - if 
self.cfg.trainer.algorithm.advantage_batch_normalize: - training_input = normalize_advantages_dict(training_input) - if self.cfg.trainer.dump_data_batch: # dump data to file with Timer("dump_data_batch"): diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index 51541bfc77..3acf0dee60 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -53,7 +53,6 @@ FixedKLController, compute_approx_kl, get_kl_controller, - normalize_advantages_dict, ) from skyrl.backends.skyrl_train.utils.torch_utils import masked_mean from skyrl.train.utils.tracking import Tracking @@ -275,9 +274,6 @@ async def train(self): training_input.pop(key) training_input.metadata.pop("uids") - if self.cfg.trainer.algorithm.advantage_batch_normalize: - training_input = normalize_advantages_dict(training_input) - if self.cfg.trainer.dump_data_batch: # dump data to file with Timer("dump_data_batch"): @@ -1037,6 +1033,53 @@ def apply_reward_kl_penalty( return data + def normalize_minibatch_advantages(self, data: TrainingInputBatch) -> TrainingInputBatch: + """Normalize the advantages in the mini-batch. + + This function handles two types of normalization: + 1. Batch normalization (z-score): if advantage_batch_normalize is True, + normalizes advantages to have zero mean and unit variance. + 2. Loss reduction normalization: scales advantages based on the loss_reduction + type to calculate the correct minibatch loss when reducing with a sum. + """ + advantages = data["advantages"] + loss_mask = data["loss_mask"] + response_mask = data["response_mask"] + + # NOTE: Do not modify the tensor in place! + # Otherwise subsequent epochs will keep dividing the same tensor. 
+ + # Step 1: Z-score normalization (if enabled) + if self.cfg.trainer.algorithm.advantage_batch_normalize: + num_actions = response_mask.sum() + mean = advantages.mean() + std = ((advantages - mean).pow(2) * response_mask).sum() + rstd = (std / num_actions).clamp(min=1e-8).rsqrt() + advantages = (advantages - mean) * rstd + + # Step 2: Loss reduction normalization + # Option 1: token mean + if self.cfg.trainer.algorithm.loss_reduction == "token_mean": + data["advantages"] = advantages / loss_mask.sum().clamp(min=1) + + # Option 2: sequence mean + elif self.cfg.trainer.algorithm.loss_reduction == "sequence_mean": + batch_size = len(data) + data["advantages"] = advantages / (batch_size * loss_mask.sum(dim=-1, keepdim=True).clamp(min=1)) + + # Option 3: Dr. GRPO style loss reduction to avoid length bias by normalizing by a constant + elif self.cfg.trainer.algorithm.loss_reduction == "seq_mean_token_sum_norm": + batch_size = len(data) + max_seq_len = self.cfg.trainer.algorithm.max_seq_len + data["advantages"] = advantages / (batch_size * max_seq_len) + + else: + # No loss reduction normalization, but still apply batch normalization if it was done + if self.cfg.trainer.algorithm.advantage_batch_normalize: + data["advantages"] = advantages + + return data + def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[str, float]: """ Execute training step for FSDP strategy using forward_backward + optim_step. 
@@ -1062,13 +1105,22 @@ def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[s mini_batch_size = self.cfg.trainer.critic_mini_batch_size * n_samples all_metrics: Dict[str, List[float]] = defaultdict(list) + num_mini_batches = len(data) // mini_batch_size + + # iterate over mini-batches to do mini batch level normalization + for local_step in range(num_mini_batches): + start_idx = local_step * mini_batch_size + end_idx = (local_step + 1) * mini_batch_size + mini_batch = data[start_idx:end_idx] + mini_batch = self.normalize_minibatch_advantages(mini_batch) + # Copy normalized advantages back to original batch + data["advantages"][start_idx:end_idx] = mini_batch["advantages"] # Stage full batch in object store ONCE to avoid repeated serialization data_ref = self.dispatch.stage_data(data) # Training loop over epochs and mini-batches for _epoch in range(self.cfg.trainer.update_epochs_per_batch): - num_mini_batches = len(data) // mini_batch_size for local_step in range(num_mini_batches): start_idx = local_step * mini_batch_size end_idx = (local_step + 1) * mini_batch_size diff --git a/tests/backends/skyrl_train/gpu/test_grpo_sp_sanity.py b/tests/backends/skyrl_train/gpu/test_grpo_sp_sanity.py index 2aee84d768..654a989560 100644 --- a/tests/backends/skyrl_train/gpu/test_grpo_sp_sanity.py +++ b/tests/backends/skyrl_train/gpu/test_grpo_sp_sanity.py @@ -11,7 +11,6 @@ from skyrl.train.config import SkyRLTrainConfig from skyrl.train.utils import Timer -from skyrl.backends.skyrl_train.utils.ppo_utils import normalize_advantages_dict import asyncio @@ -118,9 +117,6 @@ def train(self): # remove some unwanted keys data.pop(batch_keys=["rewards"]) - if self.cfg.trainer.algorithm.advantage_batch_normalize: - data = normalize_advantages_dict(data) - # 4. 
train policy/critic model with Timer("train_critic_and_policy", self.all_timings): status = self.train_critic_and_policy(data) From 333f31a0d411fd891f368bf0dbd3fd629418340a Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Mon, 9 Mar 2026 12:05:19 -0700 Subject: [PATCH 02/18] Add token_mean_baseline loss reduction for mean-of-microbatch-means comparison Co-Authored-By: Claude Opus 4.6 --- skyrl/train/trainer.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index 3acf0dee60..87cad372e3 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -1046,9 +1046,6 @@ def normalize_minibatch_advantages(self, data: TrainingInputBatch) -> TrainingIn loss_mask = data["loss_mask"] response_mask = data["response_mask"] - # NOTE: Do not modify the tensor in place! - # Otherwise subsequent epochs will keep dividing the same tensor. - # Step 1: Z-score normalization (if enabled) if self.cfg.trainer.algorithm.advantage_batch_normalize: num_actions = response_mask.sum() @@ -1062,6 +1059,21 @@ def normalize_minibatch_advantages(self, data: TrainingInputBatch) -> TrainingIn if self.cfg.trainer.algorithm.loss_reduction == "token_mean": data["advantages"] = advantages / loss_mask.sum().clamp(min=1) + # Option 1b: token-mean within each microbatch, then mean across microbatches + elif self.cfg.trainer.algorithm.loss_reduction == "token_mean_baseline": + micro_batch_size = self.cfg.trainer.micro_train_batch_size_per_gpu + num_micro_batches = len(data) // micro_batch_size + for i in range(num_micro_batches): + start_idx = i * micro_batch_size + end_idx = (i + 1) * micro_batch_size + microbatch_advantages = advantages[start_idx:end_idx] + microbatch_loss_mask = loss_mask[start_idx:end_idx] + # Compute token-mean within each microbatch + microbatch_advantages = microbatch_advantages / microbatch_loss_mask.sum().clamp(min=1) + # Average across microbatches + microbatch_advantages /= 
num_micro_batches + data["advantages"][start_idx:end_idx] = microbatch_advantages + # Option 2: sequence mean elif self.cfg.trainer.algorithm.loss_reduction == "sequence_mean": batch_size = len(data) From aaaba4c50b54dcdf813198ddf0f6bd5d5c3c9ba3 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Mon, 9 Mar 2026 12:05:36 -0700 Subject: [PATCH 03/18] fix assertion Signed-off-by: Justin Yu --- skyrl/train/utils/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py index bea1d6bb0c..1d3a5891f1 100644 --- a/skyrl/train/utils/utils.py +++ b/skyrl/train/utils/utils.py @@ -270,6 +270,7 @@ def validate_cfg(cfg: SkyRLTrainConfig): assert cfg.trainer.algorithm.loss_reduction in ( "token_mean", + "token_mean_baseline", "sequence_mean", "seq_mean_token_sum_norm", ), ( From a1213604ca49192f8975fad523846930c13926d2 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Mon, 9 Mar 2026 18:27:14 -0700 Subject: [PATCH 04/18] Update tests for sum-based reduce_loss and dp_size scaling changes Co-Authored-By: Claude Opus 4.6 --- .../gpu/gpu_ci/test_megatron_worker.py | 20 +- .../skyrl_train/utils/test_ppo_utils.py | 32 +--- .../skyrl_train/workers/test_worker_utils.py | 60 ++++-- tests/train/algorithms/test_losses.py | 180 +----------------- tests/train/test_trainer.py | 6 +- 5 files changed, 80 insertions(+), 218 deletions(-) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py index 5ec47a93cd..76672c0c58 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py @@ -78,13 +78,14 @@ def get_test_training_batch(batch_size=4) -> TrainingInputBatch: sequences = [tokenizer.encode(sentence) for sentence in sentences] attention_masks = [[1] * len(seq) for seq in sequences] - num_actions = 10 + num_actions = 15 # max seq len 1 longer than the longest sequence so we always 
have some padding max_seq_length = max([len(seq) for seq in sequences]) + 7 pad_token_id = tokenizer.pad_token_id pad_before = [4, 0, 1, 6] * num_repeats pad_after = [max_seq_length - len(seq) - pad_before[i] for i, seq in enumerate(sequences)] + loss_masks = torch.stack([torch.cat([torch.ones(num_actions - pad_after[i]), torch.zeros(pad_after[i])]) for i in range(batch_size)]) for i, (pad_before, pad_after) in enumerate(zip(pad_before, pad_after)): sequences[i] = [pad_token_id] * pad_before + sequences[i] + [pad_token_id] * pad_after @@ -103,8 +104,8 @@ def get_test_training_batch(batch_size=4) -> TrainingInputBatch: "values": torch.tensor([[0.1] * num_actions] * batch_size), "returns": torch.tensor([[0.1] * num_actions] * batch_size), "advantages": torch.tensor([[0.5] * num_actions] * batch_size), - "loss_mask": torch.tensor([[1] * num_actions] * batch_size), - "response_mask": torch.tensor([[1] * num_actions] * batch_size), + "loss_mask": loss_masks, + "response_mask": loss_masks, } ) data.metadata = {"response_length": num_actions} @@ -439,11 +440,11 @@ async def test_megatron_lora_forward(ray_init_fixture, tp, pp, cp, ep, etp, gpus ("policy", 4, 1, 1, 4, 1, 4, True, False, True), ], ids=[ - "tp2_pp2_policy_seq_packing", + "x", "tp2_pp2_policy_seq_packing_with_entropy_loss", "tp2_pp2_policy_lora", "tp2_pp2_policy_unpacked", - "tp2_cp2_policy_seq_packing", + "tp2_cp2_policy_seq_packing_no_entropy_loss", "tp4_pp1_cp1_ep4_etp1_policy_seq_packing", "tp4_pp1_cp1_ep4_etp1_policy_seq_packing_lora", ], @@ -456,7 +457,8 @@ async def test_megatron_train( Full test: initialize actor group, send dummy experience to training_step, validate output. 
""" cfg = get_test_actor_config(model_name=MODEL_NAME if ep == 1 else MOE_MODEL_NAME) - batch = get_test_training_batch(batch_size=gpus_per_node) + batch_size = gpus_per_node * 8 + batch = get_test_training_batch(batch_size=batch_size) cfg.trainer.strategy = "megatron" cfg.trainer.placement.policy_num_gpus_per_node = gpus_per_node @@ -466,6 +468,7 @@ async def test_megatron_train( cfg.trainer.policy.megatron_config.expert_model_parallel_size = ep cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = etp cfg.trainer.use_sample_packing = use_sample_packing + cfg.trainer.algorithm.use_kl_loss = False if use_entropy_loss: cfg.trainer.algorithm.use_entropy_loss = True cfg.trainer.algorithm.entropy_loss_coef = 0.01 @@ -483,7 +486,7 @@ async def test_megatron_train( cfg.trainer.algorithm.off_policy_correction.geo_mask_low = 0.98 # set batch sizes correctly - cfg.trainer.train_batch_size = gpus_per_node + cfg.trainer.train_batch_size = batch_size cfg.trainer.policy_mini_batch_size = gpus_per_node cfg.generator.n_samples_per_prompt = 1 cfg.trainer.micro_train_batch_size_per_gpu = 1 @@ -546,7 +549,7 @@ async def test_megatron_train( # Both FSDP and Megatron use forward_backward + optim_step (unified interface) batch.metadata["global_step"] = 0 - results_fsdp = ray.get(actor_group.async_run_ray_method("pass_through", "forward_backward", batch)) + results_fsdp = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", batch)) ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) # Get learning rate from worker lr_results = ray.get(actor_group.async_run_ray_method("pass_through", "get_lr")) @@ -562,7 +565,6 @@ async def test_megatron_train( "policy_lr", "loss_metrics/clip_ratio", "policy_entropy", - "policy_kl", "final_loss", ] if ep > 1: diff --git a/tests/backends/skyrl_train/utils/test_ppo_utils.py b/tests/backends/skyrl_train/utils/test_ppo_utils.py index 2e3070fa09..5ab0875146 100644 --- 
a/tests/backends/skyrl_train/utils/test_ppo_utils.py +++ b/tests/backends/skyrl_train/utils/test_ppo_utils.py @@ -243,29 +243,17 @@ def test_compute_gae_advantage_return_lam(advantage_test_data): def test_reduce_loss(): - """Test the reduce_loss function with different reduction types.""" - # Test data: 2x3 loss tensor with different valid token counts per sequence + """Test that reduce_loss computes the masked sum correctly.""" loss = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - loss_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]]) # seq0 has 3 tokens, seq1 has 1 token - - # Test token_mean: sum all valid losses / count valid tokens - # Valid losses: [1.0, 2.0, 3.0, 4.0], mean = 10.0/4 = 2.5 - result_token = reduce_loss(loss, loss_mask, "token_mean") - expected_token = torch.tensor(2.5) - assert torch.allclose(result_token, expected_token), f"Expected {expected_token}, got {result_token}" - - # Test sequence_mean: mean of per-sequence means - # Seq 0: (1.0 + 2.0 + 3.0) / 3 = 2.0, Seq 1: 4.0 / 1 = 4.0, batch mean = (2.0 + 4.0) / 2 = 3.0 - result_seq = reduce_loss(loss, loss_mask, "sequence_mean") - expected_seq = torch.tensor(3.0) - assert torch.allclose(result_seq, expected_seq), f"Expected {expected_seq}, got {result_seq}" - - # Test seq_mean_token_sum_norm: sum per sequence / max_len, then batch mean - # Seq 0: (1.0 + 2.0 + 3.0) / 4 = 1.5, Seq 1: 4.0 / 4 = 1.0, batch mean = (1.5 + 1.0) / 2 = 1.25 - max_seq_len = 4 - result_max = reduce_loss(loss, loss_mask, "seq_mean_token_sum_norm", max_seq_len) - expected_max = torch.tensor(1.25) - assert torch.allclose(result_max, expected_max), f"Expected {expected_max}, got {result_max}" + loss_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]]) + + # With mask: sum of valid losses = 1.0 + 2.0 + 3.0 + 4.0 = 10.0 + result = reduce_loss(loss, loss_mask) + assert torch.allclose(result, torch.tensor(10.0)) + + # Without mask: sum of all losses = 1+2+3+4+5+6 = 21.0 + result_no_mask = reduce_loss(loss, None) + 
assert torch.allclose(result_no_mask, torch.tensor(21.0)) def test_adaptive_kl_controller_update(): diff --git a/tests/backends/skyrl_train/workers/test_worker_utils.py b/tests/backends/skyrl_train/workers/test_worker_utils.py index fee8fb6ffa..fe7f637dfe 100644 --- a/tests/backends/skyrl_train/workers/test_worker_utils.py +++ b/tests/backends/skyrl_train/workers/test_worker_utils.py @@ -23,10 +23,16 @@ def test_reduce_metrics_min_suffix(self): assert result["is_ratio_min"] == 1.0 def test_reduce_metrics_mean_default(self): - """Keys without _max/_min suffix should use mean reduction.""" + """Keys without _max/_min/_loss suffix should use mean reduction.""" + metrics = {"entropy": [1.0, 2.0, 3.0]} + result = reduce_metrics(metrics) + assert result["entropy"] == 2.0 + + def test_reduce_metrics_loss_sum(self): + """Keys ending in _loss should use sum reduction.""" metrics = {"policy_loss": [1.0, 2.0, 3.0]} result = reduce_metrics(metrics) - assert result["policy_loss"] == 2.0 # mean of [1, 2, 3] + assert result["policy_loss"] == 6.0 # sum of [1, 2, 3] def test_reduce_metrics_mixed(self): """Test mixed metric types are reduced correctly.""" @@ -34,11 +40,13 @@ def test_reduce_metrics_mixed(self): "is_ratio_max": [1.0, 10.0], "is_ratio_min": [0.5, 2.0], "policy_loss": [1.0, 3.0], + "entropy": [1.0, 3.0], } result = reduce_metrics(metrics) assert result["is_ratio_max"] == 10.0 assert result["is_ratio_min"] == 0.5 - assert result["policy_loss"] == 2.0 + assert result["policy_loss"] == 4.0 # sum + assert result["entropy"] == 2.0 # mean def test_reduce_metrics_single_value(self): """Test reduction with single value lists.""" @@ -65,7 +73,7 @@ def test_all_reduce_metrics_separates_by_suffix(self): strategy = MagicMock() # Mock all_reduce to return the input dict unchanged but track calls - def mock_all_reduce(d, op): + def mock_all_reduce(d, op, group=None): return {k: v for k, v in d.items()} strategy.all_reduce.side_effect = mock_all_reduce @@ -79,8 +87,8 @@ def 
mock_all_reduce(d, op): _ = all_reduce_metrics(metrics, strategy) - # Verify all_reduce was called 3 times - assert strategy.all_reduce.call_count == 3 + # Verify all_reduce was called 4 times + assert strategy.all_reduce.call_count == 4 # Check that the correct ops were used calls = strategy.all_reduce.call_args_list @@ -93,9 +101,13 @@ def mock_all_reduce(d, op): op = kwargs.get("op") if kwargs else args[1] ops_and_keys.append((op, set(data_dict.keys()))) - # Verify mean metrics (policy_loss, entropy) + # Verify mean metrics (entropy) mean_call = [c for c in ops_and_keys if c[0] == "mean"][0] - assert mean_call[1] == {"policy_loss", "entropy"} + assert mean_call[1] == {"entropy"} + + # Verify sum metrics (_loss suffix) + sum_call = [c for c in ops_and_keys if c[0] == "sum"][0] + assert sum_call[1] == {"policy_loss"} # Verify min metrics min_call = [c for c in ops_and_keys if c[0] == "min"][0] @@ -110,9 +122,11 @@ def test_all_reduce_metrics_returns_merged_results(self): strategy = MagicMock() # Mock all_reduce to modify values based on op - def mock_all_reduce(d, op): + def mock_all_reduce(d, op, group=None): if op == "mean": return {k: v * 2 for k, v in d.items()} # Double for mean + elif op == "sum": + return {k: v * 4 for k, v in d.items()} # Quadruple for sum elif op == "min": return {k: v / 2 for k, v in d.items()} # Halve for min elif op == "max": @@ -125,6 +139,7 @@ def mock_all_reduce(d, op): "is_ratio_max": 10.0, "is_ratio_min": 0.1, "policy_loss": 1.5, + "entropy": 0.5, } result = all_reduce_metrics(metrics, strategy) @@ -133,16 +148,18 @@ def mock_all_reduce(d, op): assert "is_ratio_max" in result assert "is_ratio_min" in result assert "policy_loss" in result + assert "entropy" in result # Check values were transformed correctly assert result["is_ratio_max"] == 30.0 # 10.0 * 3 (max op) assert result["is_ratio_min"] == 0.05 # 0.1 / 2 (min op) - assert result["policy_loss"] == 3.0 # 1.5 * 2 (mean op) + assert result["policy_loss"] == 6.0 # sum op + 
assert result["entropy"] == 1.0 # 0.5 * 2 (mean op) def test_all_reduce_metrics_only_max(self): """Test with only _max metrics.""" strategy = MagicMock() - strategy.all_reduce.side_effect = lambda d, op: d + strategy.all_reduce.side_effect = lambda d, op, group=None: d metrics = {"loss_max": 5.0, "ratio_max": 10.0} @@ -153,7 +170,7 @@ def test_all_reduce_metrics_only_max(self): def test_all_reduce_metrics_only_min(self): """Test with only _min metrics.""" strategy = MagicMock() - strategy.all_reduce.side_effect = lambda d, op: d + strategy.all_reduce.side_effect = lambda d, op, group=None: d metrics = {"loss_min": 0.1, "ratio_min": 0.01} @@ -162,12 +179,23 @@ def test_all_reduce_metrics_only_min(self): assert result == {"loss_min": 0.1, "ratio_min": 0.01} def test_all_reduce_metrics_only_mean(self): - """Test with only mean metrics (no _max/_min suffix).""" + """Test with only mean metrics (no _max/_min/_loss suffix).""" + strategy = MagicMock() + strategy.all_reduce.side_effect = lambda d, op, group=None: d + + metrics = {"entropy": 0.5, "kl_div": 1.5} + + result = all_reduce_metrics(metrics, strategy) + + assert result == {"entropy": 0.5, "kl_div": 1.5} + + def test_all_reduce_metrics_only_sum(self): + """Test with only _loss metrics (sum reduction).""" strategy = MagicMock() - strategy.all_reduce.side_effect = lambda d, op: d + strategy.all_reduce.side_effect = lambda d, op, group=None: d - metrics = {"policy_loss": 1.5, "entropy": 0.5} + metrics = {"policy_loss": 1.5, "value_loss": 0.5} result = all_reduce_metrics(metrics, strategy) - assert result == {"policy_loss": 1.5, "entropy": 0.5} + assert result == {"policy_loss": 1.5, "value_loss": 0.5} diff --git a/tests/train/algorithms/test_losses.py b/tests/train/algorithms/test_losses.py index 23fac03b0d..65ed5c7fc3 100644 --- a/tests/train/algorithms/test_losses.py +++ b/tests/train/algorithms/test_losses.py @@ -47,7 +47,6 @@ def test_policy_loss_dual_clip(): eps_clip_high=0.2, clip_ratio_c=3.0, 
policy_loss_type="dual_clip", - loss_reduction="token_mean", max_seq_len=4, off_policy_correction=NULL_OFF_POLICY_CORR, ) @@ -71,7 +70,7 @@ def test_policy_loss_dual_clip(): # For negative advantages, use dual clipped loss final_loss = torch.where(advantages < 0, min_loss, max_loss) # [-0.5, 1.0, 12.0] assert torch.allclose(final_loss, torch.tensor([[-0.5, 1.0, 12.0]], device=device), rtol=1e-3) - expected_loss = final_loss.mean() # -(-12.5/3) = 4.1667 + expected_loss = final_loss.sum() # Calculate actual loss actual_loss, _ = loss_fn(log_probs=log_probs, old_log_probs=old_log_probs, advantages=advantages, config=config) @@ -79,7 +78,7 @@ def test_policy_loss_dual_clip(): # Verify results torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-3, atol=1e-8) # close to hand calculated value - assert actual_loss.item() == pytest.approx(4.1667, abs=1e-4) + assert actual_loss.item() == pytest.approx(12.5, abs=1e-4) def test_policy_loss_cispo(): @@ -98,7 +97,6 @@ def test_policy_loss_cispo(): config = AlgorithmConfig( cispo=CISPOConfig(cispo_eps_clip_low=0.2, cispo_eps_clip_high=0.2), policy_loss_type="cispo", - loss_reduction="token_mean", max_seq_len=4, off_policy_correction=NULL_OFF_POLICY_CORR, ) @@ -119,9 +117,9 @@ def test_policy_loss_cispo(): # loss_per_token[0] = -(1.0 * 0.8 * -1.69315) = 1.35452 # loss_per_token[1] = -(-1.0 * 1.0 * -1.0) = -1.0 # loss_per_token[2] = -(-4.0 * 1.2 * -0.69741) = -3.347568 - # mean(loss) = (1.35452 - 1.0 - 3.347568) / 3 = -0.99768266666 + # sum(loss) = (1.35452 - 1.0 - 3.347568) = -2.9930 loss = -ratio.clamp(1 - 0.2, 1 + 0.2) * advantages * log_probs - expected_loss = loss.mean() + expected_loss = loss.sum() # Calculate actual loss actual_loss, _ = loss_fn( @@ -134,158 +132,7 @@ def test_policy_loss_cispo(): # Verify results torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-3, atol=1e-8) # close to hand calculated value - assert actual_loss.item() == pytest.approx(-0.99768266666, abs=1e-4) - - -def 
test_policy_loss_reduction_modes(): - """Tests different loss_reduction modes in PolicyLoss function. - - Note: token_mean and sequence_mean give the same result when all sequences - have the same length and no mask is applied, but differ when masking creates - different effective sequence lengths. - """ - - device = "cpu" - - clip_eps_low = 0.2 - clip_eps_high = 0.2 - - advantages = torch.tensor( - [ - [2.0, 2.0, 2.0], # sequence 1: consistently higher advantages - [1.0, 1.0, 1.0], # sequence 2: consistently lower advantages - ], - device=device, - ) - - old_log_probs = torch.tensor([[-1.0, -1.0, -1.0], [-1.0, -1.0, -1.0]], device=device) - - log_probs = torch.tensor( - [[-1.5, -0.5, -1.2], [-0.8, -1.3, -0.9]], # ratios ≈ [[0.61, 1.65, 0.83],[1.22, 0.74, 1.11]] - device=device, - ) - - # Create masks to test sequences with different numbers of valid tokens - loss_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]], device=device) - - # Create configs for different reduction modes - config_token = AlgorithmConfig( - eps_clip_low=clip_eps_low, - eps_clip_high=clip_eps_high, - clip_ratio_c=3.0, - policy_loss_type="regular", - loss_reduction="token_mean", - max_seq_len=4, - off_policy_correction=NULL_OFF_POLICY_CORR, - ) - - config_seq = AlgorithmConfig( - eps_clip_low=clip_eps_low, - eps_clip_high=clip_eps_high, - clip_ratio_c=3.0, - policy_loss_type="regular", - loss_reduction="sequence_mean", - max_seq_len=4, - off_policy_correction=NULL_OFF_POLICY_CORR, - ) - - # Get loss function - loss_fn = PolicyLossRegistry.get("regular") - - # Test token_mean without mask - loss_token_no_mask, _ = loss_fn(log_probs, old_log_probs, advantages, config_token) - - # Test token_mean with mask - loss_token_with_mask, _ = loss_fn(log_probs, old_log_probs, advantages, config_token, loss_mask) - - # Test sequence_mean without mask - loss_seq_no_mask, _ = loss_fn(log_probs, old_log_probs, advantages, config_seq) - - # Test sequence_mean with mask - loss_seq_with_mask, _ = 
loss_fn(log_probs, old_log_probs, advantages, config_seq, loss_mask) - - # Manual calculations to verify (using default PolicyLoss parameters) - ratio = torch.exp(log_probs - old_log_probs) - surr1 = ratio * advantages - surr2 = ratio.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages # clip_eps_low=0.2, clip_eps_high=0.2 - loss_per_token = -torch.min(surr1, surr2) - - # Expected token_mean without mask: mean of all tokens - expected_token_no_mask = loss_per_token.mean() - - # Expected token_mean with mask: masked mean of all tokens - expected_token_with_mask = (loss_per_token * loss_mask).sum() / (loss_mask.sum() + 1e-8) - - # Expected sequence_mean without mask: mean of sequence means - expected_seq_no_mask = loss_per_token.mean(dim=1).mean() - - # Expected sequence_mean with mask: mean of masked sequence means - seq_means_masked = (loss_per_token * loss_mask).sum(dim=1) / (loss_mask.sum(dim=1) + 1e-8) - expected_seq_with_mask = seq_means_masked.mean() - - # Verify results - torch.testing.assert_close(loss_token_no_mask, expected_token_no_mask, rtol=1e-5, atol=1e-8) - torch.testing.assert_close(loss_token_with_mask, expected_token_with_mask, rtol=1e-5, atol=1e-8) - torch.testing.assert_close(loss_seq_no_mask, expected_seq_no_mask, rtol=1e-5, atol=1e-8) - torch.testing.assert_close(loss_seq_with_mask, expected_seq_with_mask, rtol=1e-5, atol=1e-8) - - # Verify that the two reduction modes give the same results when sequences have equal length and no mask - assert torch.allclose( - loss_token_no_mask, loss_seq_no_mask, rtol=1e-5 - ), "token_mean and sequence_mean should give same results when sequences have equal length and no mask" - # But they should give different results when mask creates different effective sequence lengths - assert not torch.allclose( - loss_token_with_mask, loss_seq_with_mask, rtol=1e-3 - ), "token_mean and sequence_mean with mask should give different results" - - -def test_policy_loss_reduction_edge_cases(): - """Tests edge cases for 
loss_reduction modes.""" - - device = "cpu" - - # Test with single sequence (should give same result for both modes) - advantages = torch.tensor([[1.0, -1.0, 2.0]], device=device) - old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) - log_probs = torch.tensor([[-1.5, -0.5, -1.2]], device=device) - - # Create configs for different reduction modes - config_token = AlgorithmConfig( - eps_clip_low=0.2, - eps_clip_high=0.2, - clip_ratio_c=3.0, - policy_loss_type="regular", - loss_reduction="token_mean", - max_seq_len=4, - off_policy_correction=NULL_OFF_POLICY_CORR, - ) - - config_seq = AlgorithmConfig( - eps_clip_low=0.2, - eps_clip_high=0.2, - clip_ratio_c=3.0, - policy_loss_type="regular", - loss_reduction="sequence_mean", - max_seq_len=4, - off_policy_correction=NULL_OFF_POLICY_CORR, - ) - # Get loss function - loss_fn = PolicyLossRegistry.get("regular") - - loss_token, _ = loss_fn(log_probs, old_log_probs, advantages, config_token) - loss_seq, _ = loss_fn(log_probs, old_log_probs, advantages, config_seq) - - # With single sequence, both modes should give same result - torch.testing.assert_close(loss_token, loss_seq, rtol=1e-6, atol=1e-8) - - # Test with completely masked sequence - loss_mask = torch.tensor([[0.0, 0.0, 0.0]], device=device) - loss_token_masked, _ = loss_fn(log_probs, old_log_probs, advantages, config_token, loss_mask) - loss_seq_masked, _ = loss_fn(log_probs, old_log_probs, advantages, config_seq, loss_mask) - - # Should handle zero mask gracefully (due to +1e-8 in denominator) - assert torch.isfinite(loss_token_masked) - assert torch.isfinite(loss_seq_masked) + assert actual_loss.item() == pytest.approx(-2.9930, abs=1e-4) def test_gspo_importance_sampling_levels(): @@ -348,7 +195,6 @@ def test_gspo_importance_sampling_levels(): eps_clip_high=clip_eps_high, clip_ratio_c=3.0, policy_loss_type="regular", - loss_reduction="token_mean", max_seq_len=4, off_policy_correction=NULL_OFF_POLICY_CORR, ) @@ -361,7 +207,6 @@ def 
test_gspo_importance_sampling_levels(): eps_clip_high=clip_eps_high, clip_ratio_c=3.0, policy_loss_type="gspo", - loss_reduction="sequence_mean", # GSPO recommended reduction max_seq_len=4, off_policy_correction=NULL_OFF_POLICY_CORR, ) @@ -374,7 +219,7 @@ def test_gspo_importance_sampling_levels(): surr1_token = ratio_token * advantages surr2_token = ratio_token.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages loss_per_token_token = -torch.min(surr1_token, surr2_token) - expected_token = (loss_per_token_token * loss_mask).sum() / (loss_mask.sum() + 1e-8) + expected_token = (loss_per_token_token * loss_mask).sum() # Calculate token-level clipping ratio is_clipped_token = (-surr2_token > -surr1_token) & (loss_mask.bool()) @@ -390,8 +235,8 @@ def test_gspo_importance_sampling_levels(): surr1_sequence = ratio_sequence * advantages surr2_sequence = ratio_sequence.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages loss_per_token_sequence = -torch.min(surr1_sequence, surr2_sequence) - # GSPO uses sequence_mean reduction - expected_sequence = masked_mean(loss_per_token_sequence, loss_mask, dim=-1).mean() + # GSPO uses sum reduction + expected_sequence = loss_per_token_sequence.sum() # Calculate sequence-level clipping ratio is_clipped_sequence = (-surr2_sequence > -surr1_sequence) & (loss_mask.bool()) @@ -465,7 +310,6 @@ def test_clip_cov_policy_loss(): eps_clip_low=0.2, eps_clip_high=0.2, policy_loss_type="clip_cov", - loss_reduction="token_mean", max_seq_len=4, clip_cov=ClipCovConfig(clip_ratio=0.5, clip_cov_lb=-5.0, clip_cov_ub=5.0), # Large ratio for testing off_policy_correction=NULL_OFF_POLICY_CORR, @@ -487,7 +331,6 @@ def test_clip_cov_policy_loss(): eps_clip_low=0.2, eps_clip_high=0.2, policy_loss_type="regular", - loss_reduction="token_mean", max_seq_len=4, off_policy_correction=NULL_OFF_POLICY_CORR, ) @@ -525,7 +368,6 @@ def test_kl_cov_policy_loss(): # Create KL-Cov config config = AlgorithmConfig( policy_loss_type="kl_cov", - 
loss_reduction="token_mean", max_seq_len=4, kl_cov=KLCovConfig(kl_cov_frac=0.5, ppo_kl_coef=1.0), # Apply KL to 50% of tokens off_policy_correction=NULL_OFF_POLICY_CORR, @@ -546,7 +388,6 @@ def test_kl_cov_policy_loss(): eps_clip_low=0.2, eps_clip_high=0.2, policy_loss_type="regular", - loss_reduction="token_mean", max_seq_len=4, use_tis=False, off_policy_correction=NULL_OFF_POLICY_CORR, @@ -577,7 +418,6 @@ def test_sapo_policy_loss_basic(): # SAPO config: uses sequence_mean reduction and distinct tau_pos / tau_neg config = AlgorithmConfig( policy_loss_type="sapo", - loss_reduction="sequence_mean", max_seq_len=4, sapo=SAPOConfig(tau_pos=1.0, tau_neg=2.0), off_policy_correction=NULL_OFF_POLICY_CORR, @@ -609,8 +449,8 @@ def gate_function(x, tau): gates = gate_function(ratio, taus) loss_per_token = -gates * advantages - # sequence_mean reduction: per-sequence token mean, then batch mean - expected_loss = loss_per_token.mean(dim=-1).mean() + # sum reduction + expected_loss = loss_per_token.sum() torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-5, atol=1e-8) diff --git a/tests/train/test_trainer.py b/tests/train/test_trainer.py index 04666ba19d..8ccffc44e2 100644 --- a/tests/train/test_trainer.py +++ b/tests/train/test_trainer.py @@ -448,7 +448,11 @@ def create_test_worker(worker_class): # Mock dependencies worker.strategy = MagicMock() worker.strategy.is_rank_0.return_value = False # Disable progress bars - worker.strategy.all_reduce.return_value = {"loss": 0.5, "lr": 1e-4} + worker.strategy.all_reduce.side_effect = lambda d, op, group=None: d # Return input dict unchanged + + # Mock device_mesh for DP group access + worker.device_mesh = MagicMock() + worker.device_mesh.get_group.return_value = None # No actual process group in tests # Always set model for all worker types worker.model = MagicMock() From e3842c3e798a8b294368a36ccc43cfec0f708921 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Tue, 17 Mar 2026 10:53:18 -0700 Subject: [PATCH 05/18] lint 
Signed-off-by: Justin Yu --- tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py index 947489259d..27936c1810 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py @@ -91,7 +91,9 @@ def get_test_training_batch(batch_size=4) -> TrainingInputBatch: pad_token_id = tokenizer.pad_token_id pad_before = [4, 0, 1, 6] * num_repeats pad_after = [max_seq_length - len(seq) - pad_before[i] for i, seq in enumerate(sequences)] - loss_masks = torch.stack([torch.cat([torch.ones(num_actions - pad_after[i]), torch.zeros(pad_after[i])]) for i in range(batch_size)]) + loss_masks = torch.stack( + [torch.cat([torch.ones(num_actions - pad_after[i]), torch.zeros(pad_after[i])]) for i in range(batch_size)] + ) for i, (pad_before, pad_after) in enumerate(zip(pad_before, pad_after)): sequences[i] = [pad_token_id] * pad_before + sequences[i] + [pad_token_id] * pad_after From 13bfe802386d059556d6546f11898b8264d5fed5 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Tue, 17 Mar 2026 11:49:42 -0700 Subject: [PATCH 06/18] fix tests Signed-off-by: Justin Yu --- skyrl/backends/skyrl_train/workers/worker_utils.py | 2 ++ tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/skyrl/backends/skyrl_train/workers/worker_utils.py b/skyrl/backends/skyrl_train/workers/worker_utils.py index fa53af809e..ba02929da3 100644 --- a/skyrl/backends/skyrl_train/workers/worker_utils.py +++ b/skyrl/backends/skyrl_train/workers/worker_utils.py @@ -20,6 +20,8 @@ def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]: reduced_metrics[k] = max(v) elif k.endswith("_min"): reduced_metrics[k] = min(v) + elif k.endswith("_loss"): + reduced_metrics[k] = 
sum(v) else: reduced_metrics[k] = sum(v) / len(v) return reduced_metrics diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py index 27936c1810..1ebbba45a0 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py @@ -448,7 +448,7 @@ async def test_megatron_lora_forward(ray_init_fixture, tp, pp, cp, ep, etp, gpus ("policy", 4, 1, 1, 4, 1, 4, True, False, True), ], ids=[ - "x", + "tp2_pp2_policy_seq_packing", "tp2_pp2_policy_seq_packing_with_entropy_loss", "tp2_pp2_policy_lora", "tp2_pp2_policy_unpacked", From e76bece89388ffbc8bec05c1b685b946c8805aea Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Fri, 20 Mar 2026 14:10:23 -0700 Subject: [PATCH 07/18] Refactor advantage normalization: fix z-score propagation, skip for critic, rename token_mean_baseline to token_mean_legacy Co-Authored-By: Claude Opus 4.6 --- skyrl/train/trainer.py | 62 ++++++++++++++++++++---------------------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index 7c4c08c428..6aa514c2d2 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -1047,17 +1047,8 @@ def apply_reward_kl_penalty( return data - def normalize_minibatch_advantages(self, data: TrainingInputBatch) -> TrainingInputBatch: - """Normalize the advantages in the mini-batch. - - This function handles two types of normalization: - 1. Batch normalization (z-score): if advantage_batch_normalize is True, - normalizes advantages to have zero mean and unit variance. - 2. Loss reduction normalization: scales advantages based on the loss_reduction - type to calculate the correct minibatch loss when reducing with a sum. 
- """ + def normalize_advantages(self, data: TrainingInputBatch, mini_batch_size: int) -> TrainingInputBatch: advantages = data["advantages"] - loss_mask = data["loss_mask"] response_mask = data["response_mask"] # Step 1: Z-score normalization (if enabled) @@ -1066,15 +1057,33 @@ def normalize_minibatch_advantages(self, data: TrainingInputBatch) -> TrainingIn mean = advantages.mean() std = ((advantages - mean).pow(2) * response_mask).sum() rstd = (std / num_actions).clamp(min=1e-8).rsqrt() - advantages = (advantages - mean) * rstd + data["advantages"] = (advantages - mean) * rstd # Step 2: Loss reduction normalization + num_mini_batches = len(data) // mini_batch_size + normalized_advantages = torch.zeros_like(advantages) + for local_step in range(num_mini_batches): + start_idx = local_step * mini_batch_size + end_idx = (local_step + 1) * mini_batch_size + normalized_advantages[start_idx:end_idx] = self._normalize_minibatch_advantages(data[start_idx:end_idx]) + + data["advantages"] = normalized_advantages + return data + + def _normalize_minibatch_advantages(self, data: TrainingInputBatch) -> torch.Tensor: + """Normalize the advantages in the mini-batch.""" + advantages = data["advantages"] + loss_mask = data["loss_mask"] + + normalized_advantages = torch.zeros_like(advantages) + batch_size = len(data) + # Option 1: token mean if self.cfg.trainer.algorithm.loss_reduction == "token_mean": - data["advantages"] = advantages / loss_mask.sum().clamp(min=1) + normalized_advantages = advantages / loss_mask.sum().clamp(min=1) - # Option 1b: token-mean within each microbatch, then mean across microbatches - elif self.cfg.trainer.algorithm.loss_reduction == "token_mean_baseline": + # Option 1b: legacy token-mean implementation which averages token mean across microbatches + elif self.cfg.trainer.algorithm.loss_reduction == "token_mean_legacy": micro_batch_size = self.cfg.trainer.micro_train_batch_size_per_gpu num_micro_batches = len(data) // micro_batch_size for i in 
range(num_micro_batches): @@ -1086,25 +1095,21 @@ def normalize_minibatch_advantages(self, data: TrainingInputBatch) -> TrainingIn microbatch_advantages = microbatch_advantages / microbatch_loss_mask.sum().clamp(min=1) # Average across microbatches microbatch_advantages /= num_micro_batches - data["advantages"][start_idx:end_idx] = microbatch_advantages + normalized_advantages[start_idx:end_idx] = microbatch_advantages # Option 2: sequence mean elif self.cfg.trainer.algorithm.loss_reduction == "sequence_mean": - batch_size = len(data) - data["advantages"] = advantages / (batch_size * loss_mask.sum(dim=-1, keepdim=True).clamp(min=1)) + normalized_advantages = advantages / (batch_size * loss_mask.sum(dim=-1, keepdim=True).clamp(min=1)) # Option 3: Dr. GRPO style loss reduction to avoid length bias by normalizing by a constant elif self.cfg.trainer.algorithm.loss_reduction == "seq_mean_token_sum_norm": - batch_size = len(data) max_seq_len = self.cfg.trainer.algorithm.max_seq_len - data["advantages"] = advantages / (batch_size * max_seq_len) + normalized_advantages = advantages / (batch_size * max_seq_len) else: - # No loss reduction normalization, but still apply batch normalization if it was done - if self.cfg.trainer.algorithm.advantage_batch_normalize: - data["advantages"] = advantages + raise ValueError(f"Invalid loss reduction type: {self.cfg.trainer.algorithm.loss_reduction}") - return data + return normalized_advantages def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[str, float]: """ @@ -1127,21 +1132,14 @@ def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[s n_samples = self.cfg.generator.n_samples_per_prompt if model == "policy": mini_batch_size = self.cfg.trainer.policy_mini_batch_size * n_samples + # Normalize advantages for policy training; critic training does not need this + data = self.normalize_advantages(data, mini_batch_size) else: mini_batch_size = self.cfg.trainer.critic_mini_batch_size * 
n_samples all_metrics: Dict[str, List[float]] = defaultdict(list) num_mini_batches = len(data) // mini_batch_size - # iterate over mini-batches to do mini batch level normalization - for local_step in range(num_mini_batches): - start_idx = local_step * mini_batch_size - end_idx = (local_step + 1) * mini_batch_size - mini_batch = data[start_idx:end_idx] - mini_batch = self.normalize_minibatch_advantages(mini_batch) - # Copy normalized advantages back to original batch - data["advantages"][start_idx:end_idx] = mini_batch["advantages"] - # Stage full batch in object store ONCE to avoid repeated serialization data_ref = self.dispatch.stage_data(data) From 0192e8ef82c6030138915a87d902ba3d195aff1f Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Fri, 20 Mar 2026 14:11:29 -0700 Subject: [PATCH 08/18] token_mean_baseline -> token_mean_legacy Signed-off-by: Justin Yu --- skyrl/train/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py index 278d577316..445e1b9e00 100644 --- a/skyrl/train/utils/utils.py +++ b/skyrl/train/utils/utils.py @@ -282,7 +282,7 @@ def validate_cfg(cfg: SkyRLTrainConfig): assert cfg.trainer.algorithm.loss_reduction in ( "token_mean", - "token_mean_baseline", + "token_mean_legacy", "sequence_mean", "seq_mean_token_sum_norm", ), ( From 4ee0b311b94936d7b633a701b01ea2467f89421e Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Fri, 20 Mar 2026 14:37:55 -0700 Subject: [PATCH 09/18] Extract apply_loss_reduction_to_advantages_minibatch to ppo_utils and add unit tests Co-Authored-By: Claude Opus 4.6 --- skyrl/backends/skyrl_train/utils/ppo_utils.py | 52 ++++++++++ skyrl/train/trainer.py | 53 ++-------- .../skyrl_train/utils/test_ppo_utils.py | 96 +++++++++++++++++++ 3 files changed, 158 insertions(+), 43 deletions(-) diff --git a/skyrl/backends/skyrl_train/utils/ppo_utils.py b/skyrl/backends/skyrl_train/utils/ppo_utils.py index a12ce45852..905c26c7af 100644 --- 
a/skyrl/backends/skyrl_train/utils/ppo_utils.py +++ b/skyrl/backends/skyrl_train/utils/ppo_utils.py @@ -996,6 +996,58 @@ def reduce_loss( return (loss * loss_mask).sum() if loss_mask is not None else loss.sum() +def apply_loss_reduction_to_advantages_minibatch( + advantages: torch.Tensor, + loss_mask: torch.Tensor, + loss_reduction: str, + micro_batch_size: int, + max_seq_len: int, +) -> torch.Tensor: + """Scale advantages so that reduce_loss (a simple sum) produces the desired loss reduction. + + Args: + advantages: Advantage tensor of shape (minibatch_size, seq_len). + loss_mask: Mask of shape (minibatch_size, seq_len) indicating valid loss tokens. + loss_reduction: One of "token_mean", "token_mean_legacy", "sequence_mean", "seq_mean_token_sum_norm". + micro_batch_size: Number of sequences per micro-batch + max_seq_len: Maximum sequence length. + + Returns: + Scaled advantages tensor. + """ + batch_size = advantages.shape[0] + normalized_advantages = torch.zeros_like(advantages) + + # Option 1: token mean + if loss_reduction == "token_mean": + normalized_advantages = advantages / loss_mask.sum().clamp(min=1) + + # Option 1b: legacy token-mean that normalizes per-microbatch then averages across microbatches. + elif loss_reduction == "token_mean_legacy": + num_micro_batches = batch_size // micro_batch_size + for i in range(num_micro_batches): + start_idx = i * micro_batch_size + end_idx = (i + 1) * micro_batch_size + mb_advantages = advantages[start_idx:end_idx] + mb_loss_mask = loss_mask[start_idx:end_idx] + mb_advantages = mb_advantages / mb_loss_mask.sum().clamp(min=1) + mb_advantages /= num_micro_batches + normalized_advantages[start_idx:end_idx] = mb_advantages + + # Option 2: sequence mean + elif loss_reduction == "sequence_mean": + normalized_advantages = advantages / (batch_size * loss_mask.sum(dim=-1, keepdim=True).clamp(min=1)) + + # Option 3: Dr. 
GRPO style loss reduction to avoid length bias by normalizing by a constant + elif loss_reduction == "seq_mean_token_sum_norm": + normalized_advantages = advantages / (batch_size * max_seq_len) + + else: + raise ValueError(f"Invalid loss reduction type: {loss_reduction}") + + return normalized_advantages + + # NOTE (erictang000): below ported from verl @register_advantage_estimator(AdvantageEstimator.REINFORCE_PP) def compute_reinforce_plus_plus_outcome_advantage( diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index 6aa514c2d2..642e9f0876 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -33,6 +33,7 @@ from skyrl.backends.skyrl_train.utils.ppo_utils import ( AdaptiveKLController, FixedKLController, + apply_loss_reduction_to_advantages_minibatch, compute_approx_kl, get_kl_controller, ) @@ -1059,58 +1060,24 @@ def normalize_advantages(self, data: TrainingInputBatch, mini_batch_size: int) - rstd = (std / num_actions).clamp(min=1e-8).rsqrt() data["advantages"] = (advantages - mean) * rstd - # Step 2: Loss reduction normalization + # Step 2: Loss reduction normalization per mini-batch num_mini_batches = len(data) // mini_batch_size normalized_advantages = torch.zeros_like(advantages) for local_step in range(num_mini_batches): start_idx = local_step * mini_batch_size end_idx = (local_step + 1) * mini_batch_size - normalized_advantages[start_idx:end_idx] = self._normalize_minibatch_advantages(data[start_idx:end_idx]) + mini_batch = data[start_idx:end_idx] + normalized_advantages[start_idx:end_idx] = apply_loss_reduction_to_advantages_minibatch( + advantages=mini_batch["advantages"], + loss_mask=mini_batch["loss_mask"], + loss_reduction=self.cfg.trainer.algorithm.loss_reduction, + micro_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, + max_seq_len=self.cfg.trainer.algorithm.max_seq_len, + ) data["advantages"] = normalized_advantages return data - def _normalize_minibatch_advantages(self, data: TrainingInputBatch) -> torch.Tensor: 
- """Normalize the advantages in the mini-batch.""" - advantages = data["advantages"] - loss_mask = data["loss_mask"] - - normalized_advantages = torch.zeros_like(advantages) - batch_size = len(data) - - # Option 1: token mean - if self.cfg.trainer.algorithm.loss_reduction == "token_mean": - normalized_advantages = advantages / loss_mask.sum().clamp(min=1) - - # Option 1b: legacy token-mean implementation which averages token mean across microbatches - elif self.cfg.trainer.algorithm.loss_reduction == "token_mean_legacy": - micro_batch_size = self.cfg.trainer.micro_train_batch_size_per_gpu - num_micro_batches = len(data) // micro_batch_size - for i in range(num_micro_batches): - start_idx = i * micro_batch_size - end_idx = (i + 1) * micro_batch_size - microbatch_advantages = advantages[start_idx:end_idx] - microbatch_loss_mask = loss_mask[start_idx:end_idx] - # Compute token-mean within each microbatch - microbatch_advantages = microbatch_advantages / microbatch_loss_mask.sum().clamp(min=1) - # Average across microbatches - microbatch_advantages /= num_micro_batches - normalized_advantages[start_idx:end_idx] = microbatch_advantages - - # Option 2: sequence mean - elif self.cfg.trainer.algorithm.loss_reduction == "sequence_mean": - normalized_advantages = advantages / (batch_size * loss_mask.sum(dim=-1, keepdim=True).clamp(min=1)) - - # Option 3: Dr. GRPO style loss reduction to avoid length bias by normalizing by a constant - elif self.cfg.trainer.algorithm.loss_reduction == "seq_mean_token_sum_norm": - max_seq_len = self.cfg.trainer.algorithm.max_seq_len - normalized_advantages = advantages / (batch_size * max_seq_len) - - else: - raise ValueError(f"Invalid loss reduction type: {self.cfg.trainer.algorithm.loss_reduction}") - - return normalized_advantages - def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[str, float]: """ Execute training step for FSDP strategy using forward_backward + optim_step. 
diff --git a/tests/backends/skyrl_train/utils/test_ppo_utils.py b/tests/backends/skyrl_train/utils/test_ppo_utils.py index 2062e8d77e..bd15e089e0 100644 --- a/tests/backends/skyrl_train/utils/test_ppo_utils.py +++ b/tests/backends/skyrl_train/utils/test_ppo_utils.py @@ -14,6 +14,7 @@ AdvantageEstimatorRegistry, FixedKLController, PolicyLossRegistry, + apply_loss_reduction_to_advantages_minibatch, compute_advantages_and_returns, compute_approx_kl, compute_gae_advantage_return, @@ -546,3 +547,98 @@ def test_func(**kwargs): finally: ray.shutdown() + + +class TestApplyLossReductionToAdvantagesMinibatch: + """Tests for apply_loss_reduction_to_advantages_minibatch.""" + + def test_token_mean(self): + advantages = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + loss_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]) + # valid tokens: 1+2+4+5+6 = 18, count = 5, mean = 3.6 + scaled = apply_loss_reduction_to_advantages_minibatch( + advantages=advantages, + loss_mask=loss_mask, + loss_reduction="token_mean", + micro_batch_size=1, + max_seq_len=3, + ) + loss = reduce_loss(scaled, loss_mask) + assert torch.allclose(loss, torch.tensor(3.6)) + + def test_token_mean_all_masked(self): + """Token mean with all-zero mask should produce zero loss, not NaN.""" + advantages = torch.tensor([[1.0, 2.0]]) + loss_mask = torch.tensor([[0.0, 0.0]]) + scaled = apply_loss_reduction_to_advantages_minibatch( + advantages=advantages, + loss_mask=loss_mask, + loss_reduction="token_mean", + micro_batch_size=1, + max_seq_len=2, + ) + loss = reduce_loss(scaled, loss_mask) + assert torch.allclose(loss, torch.tensor(0.0)) + + def test_sequence_mean(self): + advantages = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + loss_mask = torch.tensor([[1.0, 1.0], [1.0, 0.0]]) + # seq 0: token mean = (1+2)/2 = 1.5 + # seq 1: token mean = 3/1 = 3.0 + # sequence mean = (1.5 + 3.0) / 2 = 2.25 + scaled = apply_loss_reduction_to_advantages_minibatch( + advantages=advantages, + loss_mask=loss_mask, + 
loss_reduction="sequence_mean", + micro_batch_size=1, + max_seq_len=2, + ) + loss = reduce_loss(scaled, loss_mask) + assert torch.allclose(loss, torch.tensor(2.25)) + + def test_seq_mean_token_sum_norm(self): + advantages = torch.tensor([[2.0, 3.0], [9.0, 12.0]]) + loss_mask = torch.tensor([[1.0, 1.0], [1.0, 0.0]]) + # seq 0: (2+3)/10 = 0.5 + # seq 1: 9/10 = 0.9 + # mean across batch = (0.5 + 0.9) / 2 = 0.7 + scaled = apply_loss_reduction_to_advantages_minibatch( + advantages=advantages, + loss_mask=loss_mask, + loss_reduction="seq_mean_token_sum_norm", + micro_batch_size=1, + max_seq_len=10, + ) + loss = reduce_loss(scaled, loss_mask) + assert torch.allclose(loss, torch.tensor(0.7)) + + def test_token_mean_legacy(self): + """Legacy token mean: per-microbatch token mean, then averaged across microbatches.""" + # 4 sequences, micro_batch_size=2 -> 2 microbatches + advantages = torch.tensor([[2.0, 4.0], [6.0, 8.0], [10.0, 12.0], [14.0, 16.0]]) + loss_mask = torch.tensor([[1.0, 1.0], [1.0, 0.0], [1.0, 1.0], [1.0, 1.0]]) + # microbatch 0: token mean = (2+4+6)/3 = 4 + # microbatch 1: token mean = (10+12+14+16)/4 = 13 + # average of microbatch means = (4.0 + 13.0) / 2 = 8.5 + scaled = apply_loss_reduction_to_advantages_minibatch( + advantages=advantages, + loss_mask=loss_mask, + loss_reduction="token_mean_legacy", + micro_batch_size=2, + max_seq_len=2, + ) + loss = reduce_loss(scaled, loss_mask) + assert torch.allclose(loss, torch.tensor(8.5)) + + def test_invalid_loss_reduction_raises(self): + """Invalid loss_reduction should raise ValueError.""" + advantages = torch.tensor([[1.0]]) + loss_mask = torch.tensor([[1.0]]) + with pytest.raises(ValueError, match="Invalid loss reduction type"): + apply_loss_reduction_to_advantages_minibatch( + advantages=advantages, + loss_mask=loss_mask, + loss_reduction="invalid", + micro_batch_size=1, + max_seq_len=1, + ) From c8f06cce87be18b58c35a7b195c038e2b9a4d195 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Wed, 25 Mar 2026 15:54:42 
-0700 Subject: [PATCH 10/18] Fix metric reporting: remove dp_size scaling, separate micro-batch vs mini-batch reduction - Report unscaled loss metrics (remove * loss_scale / * dp_size) in both FSDP and Megatron workers - Rename reduce_metrics -> reduce_metrics_across_microbatches (sums _loss for gradient accumulation) - Add reduce_metrics_across_minibatches in trainer_utils (averages _loss for logging) - Use sum all-reduce for _loss keys across DP workers to reconstruct full mini-batch loss Co-Authored-By: Claude Opus 4.6 --- skyrl/backends/skyrl_train/utils/ppo_utils.py | 2 +- .../megatron/megatron_model_wrapper.py | 4 ++-- .../workers/megatron/megatron_worker.py | 4 ++-- skyrl/backends/skyrl_train/workers/worker.py | 12 +++++------ .../skyrl_train/workers/worker_utils.py | 14 ++++++++++--- skyrl/train/trainer.py | 4 ++-- skyrl/train/utils/trainer_utils.py | 21 +++++++++++++++++++ .../skyrl_train/workers/test_worker_utils.py | 2 +- 8 files changed, 46 insertions(+), 17 deletions(-) diff --git a/skyrl/backends/skyrl_train/utils/ppo_utils.py b/skyrl/backends/skyrl_train/utils/ppo_utils.py index 905c26c7af..1987aa7650 100644 --- a/skyrl/backends/skyrl_train/utils/ppo_utils.py +++ b/skyrl/backends/skyrl_train/utils/ppo_utils.py @@ -1003,7 +1003,7 @@ def apply_loss_reduction_to_advantages_minibatch( micro_batch_size: int, max_seq_len: int, ) -> torch.Tensor: - """Scale advantages so that reduce_loss (a simple sum) produces the desired loss reduction. + """Scale advantages so that summing produces the desired loss reduction. Args: advantages: Advantage tensor of shape (minibatch_size, seq_len). 
diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index 72922029fd..79ab0e15c5 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -368,8 +368,8 @@ def loss_func(logits, data): ) metrics = { - "final_loss": unscaled_loss.detach().item() * dp_size, - "policy_loss": policy_loss.detach().item() * dp_size, + "final_loss": unscaled_loss.detach().item(), + "policy_loss": policy_loss.detach().item(), "policy_entropy": entropy.detach().item(), "policy_kl": kl_loss.detach().item(), "loss_fn_outputs": loss_fn_outputs, diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 8675d0a711..cdc7a24cb9 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -47,7 +47,7 @@ from skyrl.backends.skyrl_train.workers.worker_utils import ( BatchIterator, all_reduce_metrics, - reduce_metrics, + reduce_metrics_across_microbatches, ) from skyrl.env_vars import SKYRL_WORKER_NCCL_TIMEOUT_IN_S from skyrl.train.config.config import MegatronDDPConfig, get_config_as_dict @@ -730,7 +730,7 @@ def forward_backward( # Reduce and all-reduce metrics across DP ranks only # (metrics should be identical within DP groups, i.e., across TP/PP/SP ranks) - status = reduce_metrics(dict(all_metrics)) + status = reduce_metrics_across_microbatches(dict(all_metrics)) status["policy_lr"] = self.optimizer.param_groups[0]["lr"] group = mpu.get_data_parallel_group(with_context_parallel=True) status = all_reduce_metrics(status, self.strategy, group=group) diff --git a/skyrl/backends/skyrl_train/workers/worker.py b/skyrl/backends/skyrl_train/workers/worker.py index 4356964d8c..fd1462ce89 100644 --- 
a/skyrl/backends/skyrl_train/workers/worker.py +++ b/skyrl/backends/skyrl_train/workers/worker.py @@ -52,7 +52,7 @@ from skyrl.backends.skyrl_train.workers.worker_utils import ( BatchIterator, all_reduce_metrics, - reduce_metrics, + reduce_metrics_across_microbatches, ) from skyrl.env_vars import ( _SKYRL_USE_NEW_INFERENCE, @@ -715,7 +715,7 @@ def forward_backward( all_metrics[k].append(v) # reduce metrics across micro batches (sum, mean, min, max) - result = reduce_metrics(dict(all_metrics)) + result = reduce_metrics_across_microbatches(dict(all_metrics)) # all reduce metrics across DP workers dp_group = self.device_mesh.get_group("dp") @@ -871,7 +871,7 @@ def _forward_backward_micro( ) status = { - "loss": loss.item(), + "loss": unscaled_loss.item(), "response_length": num_actions, "lr": self.scheduler.get_last_lr()[0], "loss_fn_outputs": loss_fn_outputs, @@ -928,8 +928,8 @@ def _forward_backward_micro( ) status = { - "final_loss": loss.item(), - "policy_loss": policy_loss.item() * loss_scale, + "final_loss": unscaled_loss.item(), + "policy_loss": policy_loss.item(), "policy_entropy": entropy.item(), "response_length": num_actions, "policy_lr": self.scheduler.get_last_lr()[0], @@ -1070,7 +1070,7 @@ def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]: for k, v in metrics.items(): all_metrics[k].append(v) - return reduce_metrics(dict(all_metrics)) + return reduce_metrics_across_microbatches(dict(all_metrics)) def forward_backward_from_staged(self, data: TrainingInputBatch, start_idx: int, end_idx: int) -> Dict[str, float]: """ diff --git a/skyrl/backends/skyrl_train/workers/worker_utils.py b/skyrl/backends/skyrl_train/workers/worker_utils.py index ba02929da3..7322d967d2 100644 --- a/skyrl/backends/skyrl_train/workers/worker_utils.py +++ b/skyrl/backends/skyrl_train/workers/worker_utils.py @@ -6,9 +6,17 @@ from skyrl.train.dataset.replay_buffer import Experience -def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]: - """ - 
Reduce scalar metrics from a list of entries per key by averaging. +def reduce_metrics_across_microbatches(metrics: Dict[str, List[float]]) -> Dict[str, float]: + """Reduce metrics across micro-batches within a single mini-batch. + + NOTE: Metrics ending in `_loss` are summed, not averaged, because the scaling + is already done at the advantage level. + See `apply_loss_reduction_to_advantages_minibatch` for more details. + + + Args: + metrics: Dictionary of metrics with keys as metric names and values as lists of metric values. + The list of values corresponds to micro-batches within a single mini-batch. """ reduced_metrics = dict() for k, v in metrics.items(): diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index 642e9f0876..95bf9d33b8 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -40,7 +40,7 @@ from skyrl.backends.skyrl_train.utils.torch_utils import masked_mean from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch -from skyrl.backends.skyrl_train.workers.worker_utils import reduce_metrics +from skyrl.train.utils.trainer_utils import reduce_metrics_across_minibatches from skyrl.env_vars import SKYRL_RAY_PG_TIMEOUT_IN_S from skyrl.train.config import SkyRLTrainConfig from skyrl.train.dataset import PromptDataset @@ -1129,7 +1129,7 @@ def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[s # Reduce metrics across all mini-batches and epochs # pop out loss_fn_outputs since it's not a scalar metric and to avoid logging it all_metrics.pop("loss_fn_outputs", None) - reduced_metrics = reduce_metrics(all_metrics) + reduced_metrics = reduce_metrics_across_minibatches(all_metrics) return reduced_metrics def train_critic_and_policy(self, data: TrainingInputBatch): diff --git a/skyrl/train/utils/trainer_utils.py b/skyrl/train/utils/trainer_utils.py index 2d4e6e24b1..5328e0b73b 100644 --- 
a/skyrl/train/utils/trainer_utils.py +++ b/skyrl/train/utils/trainer_utils.py @@ -708,3 +708,24 @@ def get_rope_scaling_config(trainer_cfg: TrainerConfig) -> dict[str, Any]: def get_rope_theta_config(trainer_cfg: TrainerConfig) -> int | None: return trainer_cfg.rope_theta + + +def reduce_metrics_across_minibatches(metrics: Dict[str, List[float]]) -> Dict[str, float]: + """ + Reduce metrics across mini-batches and epochs for logging. + All metrics (including _loss) are averaged, since the worker-level reduction + already handles summing _loss keys across micro-batches. + """ + reduced_metrics = dict() + for k, v in metrics.items(): + assert len(v) > 0, f"No metrics for key {k}" + if not all(isinstance(x, (int, float)) for x in v): + print(f"Metrics for key {k} are not all numbers: {v}") + continue + if k.endswith("_max"): + reduced_metrics[k] = max(v) + elif k.endswith("_min"): + reduced_metrics[k] = min(v) + else: + reduced_metrics[k] = sum(v) / len(v) + return reduced_metrics diff --git a/tests/backends/skyrl_train/workers/test_worker_utils.py b/tests/backends/skyrl_train/workers/test_worker_utils.py index f897f65b80..fd3a3ce564 100644 --- a/tests/backends/skyrl_train/workers/test_worker_utils.py +++ b/tests/backends/skyrl_train/workers/test_worker_utils.py @@ -10,7 +10,7 @@ from skyrl.backends.skyrl_train.workers.worker_utils import ( all_reduce_metrics, - reduce_metrics, + reduce_metrics_across_microbatches as reduce_metrics, ) From 2c133150503c0f04e68f392426765c9982689273 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Fri, 27 Mar 2026 11:53:35 -0700 Subject: [PATCH 11/18] Fix critic metric reporting: explicit sum_loss_metrics flag for reduce_metrics Co-Authored-By: Claude Opus 4.6 (1M context) --- .../workers/megatron/megatron_worker.py | 6 +- skyrl/backends/skyrl_train/workers/worker.py | 19 +++-- .../skyrl_train/workers/worker_utils.py | 35 ++++++-- .../skyrl_train/workers/test_worker_utils.py | 82 +++++++------------ 4 files changed, 70 insertions(+), 72 
deletions(-) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index cdc7a24cb9..e591b236f5 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -47,7 +47,7 @@ from skyrl.backends.skyrl_train.workers.worker_utils import ( BatchIterator, all_reduce_metrics, - reduce_metrics_across_microbatches, + reduce_metrics, ) from skyrl.env_vars import SKYRL_WORKER_NCCL_TIMEOUT_IN_S from skyrl.train.config.config import MegatronDDPConfig, get_config_as_dict @@ -730,10 +730,10 @@ def forward_backward( # Reduce and all-reduce metrics across DP ranks only # (metrics should be identical within DP groups, i.e., across TP/PP/SP ranks) - status = reduce_metrics_across_microbatches(dict(all_metrics)) + status = reduce_metrics(all_metrics, sum_loss_metrics=True) status["policy_lr"] = self.optimizer.param_groups[0]["lr"] group = mpu.get_data_parallel_group(with_context_parallel=True) - status = all_reduce_metrics(status, self.strategy, group=group) + status = all_reduce_metrics(status, self.strategy, group=group, sum_loss_metrics=True) # Add loss_fn_outputs back (not reduced, kept as list) if all_loss_fn_outputs: diff --git a/skyrl/backends/skyrl_train/workers/worker.py b/skyrl/backends/skyrl_train/workers/worker.py index fd1462ce89..400639d799 100644 --- a/skyrl/backends/skyrl_train/workers/worker.py +++ b/skyrl/backends/skyrl_train/workers/worker.py @@ -52,7 +52,7 @@ from skyrl.backends.skyrl_train.workers.worker_utils import ( BatchIterator, all_reduce_metrics, - reduce_metrics_across_microbatches, + reduce_metrics, ) from skyrl.env_vars import ( _SKYRL_USE_NEW_INFERENCE, @@ -714,12 +714,12 @@ def forward_backward( for k, v in metrics.items(): all_metrics[k].append(v) - # reduce metrics across micro batches (sum, mean, min, max) - result = reduce_metrics_across_microbatches(dict(all_metrics)) + # 
reduce metrics across micro batches + result = reduce_metrics(all_metrics, sum_loss_metrics=True) # all reduce metrics across DP workers dp_group = self.device_mesh.get_group("dp") - result = all_reduce_metrics(result, self.strategy, group=dp_group) + result = all_reduce_metrics(result, self.strategy, group=dp_group, sum_loss_metrics=True) # Add back loss_fn_outputs (concatenated across micro-batches) if all_loss_fn_outputs: @@ -1070,7 +1070,13 @@ def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]: for k, v in metrics.items(): all_metrics[k].append(v) - return reduce_metrics_across_microbatches(dict(all_metrics)) + # reduce metrics across micro batches + result = reduce_metrics(all_metrics) + + # all reduce metrics across DP workers + result = all_reduce_metrics(result, self.strategy) + + return result def forward_backward_from_staged(self, data: TrainingInputBatch, start_idx: int, end_idx: int) -> Dict[str, float]: """ @@ -1141,9 +1147,6 @@ def _forward_backward_micro(self, experience: Experience) -> Dict[str, float]: "critic_lr": self.scheduler.get_last_lr()[0], } - # All-reduce metrics across DP workers - status = all_reduce_metrics(status, self.strategy) - return status def optim_step(self) -> float: diff --git a/skyrl/backends/skyrl_train/workers/worker_utils.py b/skyrl/backends/skyrl_train/workers/worker_utils.py index 7322d967d2..f399c32a5c 100644 --- a/skyrl/backends/skyrl_train/workers/worker_utils.py +++ b/skyrl/backends/skyrl_train/workers/worker_utils.py @@ -6,17 +6,19 @@ from skyrl.train.dataset.replay_buffer import Experience -def reduce_metrics_across_microbatches(metrics: Dict[str, List[float]]) -> Dict[str, float]: - """Reduce metrics across micro-batches within a single mini-batch. +def reduce_metrics(metrics: Dict[str, List[float]], sum_loss_metrics: bool = False) -> Dict[str, float]: + """Reduce scalar metrics from a list of entries per key with the appropriate reduction. 
- NOTE: Metrics ending in `_loss` are summed, not averaged, because the scaling - is already done at the advantage level. - See `apply_loss_reduction_to_advantages_minibatch` for more details. + Default reduction is mean. Metrics ending in `_min` or `_max` use min/max respectively. + If sum_loss_metrics is True, metrics ending in `_loss` are summed instead of averaged. + This should be used if the scaling is already done at the advantage level. + See `apply_loss_reduction_to_advantages_minibatch` for more details. Args: metrics: Dictionary of metrics with keys as metric names and values as lists of metric values. The list of values corresponds to micro-batches within a single mini-batch. + sum_loss_metrics: If True, metrics ending in `_loss` are summed (for pre-scaled policy losses). """ reduced_metrics = dict() for k, v in metrics.items(): @@ -28,18 +30,33 @@ def reduce_metrics_across_microbatches(metrics: Dict[str, List[float]]) -> Dict[ reduced_metrics[k] = max(v) elif k.endswith("_min"): reduced_metrics[k] = min(v) - elif k.endswith("_loss"): + elif sum_loss_metrics and k.endswith("_loss"): reduced_metrics[k] = sum(v) else: reduced_metrics[k] = sum(v) / len(v) return reduced_metrics -def all_reduce_metrics(metrics: Dict[str, List[float]], strategy: DistributedStrategy, group=None) -> Dict[str, float]: - """All reduce metrics across all processes.""" +def all_reduce_metrics( + metrics: Dict[str, float], + strategy: DistributedStrategy, + group=None, + sum_loss_metrics: bool = False, +) -> Dict[str, float]: + """All reduce metrics across all processes. + + Default reduction is mean. Metrics ending in `_min` or `_max` use min/max respectively. + If sum_loss_metrics is True, metrics ending in `_loss` are summed instead of averaged. + + Args: + metrics: Dictionary of metric name to scalar value. + strategy: Distributed strategy for all-reduce. + group: Process group for all-reduce. 
+ sum_loss_metrics: If True, metrics ending in `_loss` are summed (for pre-scaled policy losses). + """ min_metrics = {k: v for k, v in metrics.items() if k.endswith("_min")} max_metrics = {k: v for k, v in metrics.items() if k.endswith("_max")} - sum_metrics = {k: v for k, v in metrics.items() if k.endswith("_loss")} + sum_metrics = {k: v for k, v in metrics.items() if sum_loss_metrics and k.endswith("_loss")} mean_metrics = { k: v for k, v in metrics.items() if k not in min_metrics and k not in max_metrics and k not in sum_metrics } diff --git a/tests/backends/skyrl_train/workers/test_worker_utils.py b/tests/backends/skyrl_train/workers/test_worker_utils.py index fd3a3ce564..f40dfc6709 100644 --- a/tests/backends/skyrl_train/workers/test_worker_utils.py +++ b/tests/backends/skyrl_train/workers/test_worker_utils.py @@ -10,7 +10,7 @@ from skyrl.backends.skyrl_train.workers.worker_utils import ( all_reduce_metrics, - reduce_metrics_across_microbatches as reduce_metrics, + reduce_metrics, ) @@ -33,11 +33,17 @@ def test_reduce_metrics_mean_default(self): result = reduce_metrics(metrics) assert result["entropy"] == 2.0 - def test_reduce_metrics_loss_sum(self): - """Keys ending in _loss should use sum reduction.""" + def test_reduce_metrics_loss_default_mean(self): + """_loss keys default to mean when sum_loss_metrics=False.""" metrics = {"policy_loss": [1.0, 2.0, 3.0]} result = reduce_metrics(metrics) - assert result["policy_loss"] == 6.0 # sum of [1, 2, 3] + assert result["policy_loss"] == 2.0 + + def test_reduce_metrics_sum_loss_metrics(self): + """_loss keys are summed when sum_loss_metrics=True.""" + metrics = {"policy_loss": [1.0, 2.0, 3.0]} + result = reduce_metrics(metrics, sum_loss_metrics=True) + assert result["policy_loss"] == 6.0 def test_reduce_metrics_mixed(self): """Test mixed metric types are reduced correctly.""" @@ -47,7 +53,7 @@ def test_reduce_metrics_mixed(self): "policy_loss": [1.0, 3.0], "entropy": [1.0, 3.0], } - result = reduce_metrics(metrics) 
+ result = reduce_metrics(metrics, sum_loss_metrics=True) assert result["is_ratio_max"] == 10.0 assert result["is_ratio_min"] == 0.5 assert result["policy_loss"] == 4.0 # sum @@ -90,7 +96,7 @@ def mock_all_reduce(d, op, group=None): "entropy": 0.5, } - _ = all_reduce_metrics(metrics, strategy) + _ = all_reduce_metrics(metrics, strategy, sum_loss_metrics=True) # Verify all_reduce was called 4 times assert strategy.all_reduce.call_count == 4 @@ -110,7 +116,7 @@ def mock_all_reduce(d, op, group=None): mean_call = [c for c in ops_and_keys if c[0] == "mean"][0] assert mean_call[1] == {"entropy"} - # Verify sum metrics (_loss suffix) + # Verify sum metrics (`_loss` keys with sum_loss_metrics=True) sum_call = [c for c in ops_and_keys if c[0] == "sum"][0] assert sum_call[1] == {"policy_loss"} @@ -122,6 +128,22 @@ def mock_all_reduce(d, op, group=None): max_call = [c for c in ops_and_keys if c[0] == "max"][0] assert max_call[1] == {"is_ratio_max"} + def test_all_reduce_metrics_average_loss_metrics(self): + """Verify _loss keys are averaged when sum_loss_metrics=False.""" + strategy = MagicMock() + + # Mock all_reduce to return the input dict unchanged but track calls + def mock_all_reduce(d, op, group=None): + return {k: v for k, v in d.items()} + + strategy.all_reduce.side_effect = mock_all_reduce + + metrics = {"policy_loss": 1.5} + result = all_reduce_metrics(metrics, strategy, sum_loss_metrics=False) + assert result["policy_loss"] == 1.5 + assert strategy.all_reduce.call_count == 1 + assert strategy.all_reduce.call_args[0][1] == "mean" + def test_all_reduce_metrics_returns_merged_results(self): """Verify results from all reductions are merged correctly.""" strategy = MagicMock() @@ -147,7 +169,7 @@ def mock_all_reduce(d, op, group=None): "entropy": 0.5, } - result = all_reduce_metrics(metrics, strategy) + result = all_reduce_metrics(metrics, strategy, sum_loss_metrics=True) # Check all keys are present assert "is_ratio_max" in result @@ -160,47 +182,3 @@ def mock_all_reduce(d, op, 
group=None): assert result["is_ratio_min"] == 0.05 # 0.1 / 2 (min op) assert result["policy_loss"] == 6.0 # sum op assert result["entropy"] == 1.0 # 0.5 * 2 (mean op) - - def test_all_reduce_metrics_only_max(self): - """Test with only _max metrics.""" - strategy = MagicMock() - strategy.all_reduce.side_effect = lambda d, op, group=None: d - - metrics = {"loss_max": 5.0, "ratio_max": 10.0} - - result = all_reduce_metrics(metrics, strategy) - - assert result == {"loss_max": 5.0, "ratio_max": 10.0} - - def test_all_reduce_metrics_only_min(self): - """Test with only _min metrics.""" - strategy = MagicMock() - strategy.all_reduce.side_effect = lambda d, op, group=None: d - - metrics = {"loss_min": 0.1, "ratio_min": 0.01} - - result = all_reduce_metrics(metrics, strategy) - - assert result == {"loss_min": 0.1, "ratio_min": 0.01} - - def test_all_reduce_metrics_only_mean(self): - """Test with only mean metrics (no _max/_min/_loss suffix).""" - strategy = MagicMock() - strategy.all_reduce.side_effect = lambda d, op, group=None: d - - metrics = {"entropy": 0.5, "kl_div": 1.5} - - result = all_reduce_metrics(metrics, strategy) - - assert result == {"entropy": 0.5, "kl_div": 1.5} - - def test_all_reduce_metrics_only_sum(self): - """Test with only _loss metrics (sum reduction).""" - strategy = MagicMock() - strategy.all_reduce.side_effect = lambda d, op, group=None: d - - metrics = {"policy_loss": 1.5, "value_loss": 0.5} - - result = all_reduce_metrics(metrics, strategy) - - assert result == {"policy_loss": 1.5, "value_loss": 0.5} From 14ba02e3a32cdc179e02317f6ae15fcb06661a7a Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Fri, 27 Mar 2026 11:57:13 -0700 Subject: [PATCH 12/18] Remove reduce_metrics_across_minibatches, reuse reduce_metrics Co-Authored-By: Claude Opus 4.6 (1M context) --- skyrl/train/trainer.py | 4 ++-- skyrl/train/utils/trainer_utils.py | 19 ------------------- 2 files changed, 2 insertions(+), 21 deletions(-) diff --git a/skyrl/train/trainer.py 
b/skyrl/train/trainer.py index 95bf9d33b8..8fc26552b4 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -40,7 +40,7 @@ from skyrl.backends.skyrl_train.utils.torch_utils import masked_mean from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch -from skyrl.train.utils.trainer_utils import reduce_metrics_across_minibatches +from skyrl.backends.skyrl_train.workers.worker_utils import reduce_metrics from skyrl.env_vars import SKYRL_RAY_PG_TIMEOUT_IN_S from skyrl.train.config import SkyRLTrainConfig from skyrl.train.dataset import PromptDataset @@ -1129,7 +1129,7 @@ def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[s # Reduce metrics across all mini-batches and epochs # pop out loss_fn_outputs since it's not a scalar metric and to avoid logging it all_metrics.pop("loss_fn_outputs", None) - reduced_metrics = reduce_metrics_across_minibatches(all_metrics) + reduced_metrics = reduce_metrics(all_metrics, sum_loss_metrics=False) return reduced_metrics def train_critic_and_policy(self, data: TrainingInputBatch): diff --git a/skyrl/train/utils/trainer_utils.py b/skyrl/train/utils/trainer_utils.py index 5328e0b73b..1375e6517a 100644 --- a/skyrl/train/utils/trainer_utils.py +++ b/skyrl/train/utils/trainer_utils.py @@ -710,22 +710,3 @@ def get_rope_theta_config(trainer_cfg: TrainerConfig) -> int | None: return trainer_cfg.rope_theta -def reduce_metrics_across_minibatches(metrics: Dict[str, List[float]]) -> Dict[str, float]: - """ - Reduce metrics across mini-batches and epochs for logging. - All metrics (including _loss) are averaged, since the worker-level reduction - already handles summing _loss keys across micro-batches. 
- """ - reduced_metrics = dict() - for k, v in metrics.items(): - assert len(v) > 0, f"No metrics for key {k}" - if not all(isinstance(x, (int, float)) for x in v): - print(f"Metrics for key {k} are not all numbers: {v}") - continue - if k.endswith("_max"): - reduced_metrics[k] = max(v) - elif k.endswith("_min"): - reduced_metrics[k] = min(v) - else: - reduced_metrics[k] = sum(v) / len(v) - return reduced_metrics From 717c3a7ef2be5f7f8d8b9597bb0fc4396cba6347 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Fri, 27 Mar 2026 12:28:27 -0700 Subject: [PATCH 13/18] add some comments about sum metrics Signed-off-by: Justin Yu --- .../backends/skyrl_train/workers/megatron/megatron_worker.py | 3 ++- skyrl/backends/skyrl_train/workers/worker.py | 5 ++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 2710604f2b..3b1ec841ce 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -730,8 +730,9 @@ def forward_backward( for k, v in metrics.items(): all_metrics[k].append(v) - # Reduce and all-reduce metrics across DP ranks only + # Reduce across microbatches and all-reduce metrics across DP ranks # (metrics should be identical within DP groups, i.e., across TP/PP/SP ranks) + # NOTE: Sum loss metrics because scaling is already applied at the advantage level status = reduce_metrics(all_metrics, sum_loss_metrics=True) status["policy_lr"] = self.optimizer.param_groups[0]["lr"] group = mpu.get_data_parallel_group(with_context_parallel=True) diff --git a/skyrl/backends/skyrl_train/workers/worker.py b/skyrl/backends/skyrl_train/workers/worker.py index 9a7fd5bcb0..2196559797 100644 --- a/skyrl/backends/skyrl_train/workers/worker.py +++ b/skyrl/backends/skyrl_train/workers/worker.py @@ -714,10 +714,9 @@ def forward_backward( for k, v in metrics.items(): 
all_metrics[k].append(v) - # reduce metrics across micro batches + # Reduce across microbatches and all-reduce metrics across DP ranks + # NOTE: Sum loss metrics because scaling is already applied at the advantage level result = reduce_metrics(all_metrics, sum_loss_metrics=True) - - # all reduce metrics across DP workers dp_group = self.device_mesh.get_group("dp") result = all_reduce_metrics(result, self.strategy, group=dp_group, sum_loss_metrics=True) From 661f5d8d599326a9dc75fe8589be76494a2b63f1 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Fri, 27 Mar 2026 12:46:46 -0700 Subject: [PATCH 14/18] add clarifying comments and rename loss_scale Signed-off-by: Justin Yu --- .../megatron/megatron_model_wrapper.py | 20 ++++++++++--------- skyrl/backends/skyrl_train/workers/worker.py | 11 ++++++---- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index 87b1b27fbc..73581ecb20 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -252,12 +252,14 @@ def loss_func(logits, data): tp_grp = mpu.get_tensor_model_parallel_group() tp_rank = mpu.get_tensor_model_parallel_rank() - # Megatron's pipeline parallel forward_backward_func internally divides loss by num_microbatches - # https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/pipeline_parallel/schedules.py#L248 - # we want to maintain a sum of losses across all micro batches, so we reverse this division. 
- # we additionally multiply by the data parallelism size to undo the DDP all-reduce mean - # https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/distributed/distributed_data_parallel.py#L285 - loss_scale = num_microbatches * dp_size + # Policy losses are pre-scaled to achieve the correct loss_reduction when summing across the entire minibatch + # (see `apply_loss_reduction_to_advantages_minibatch`). + # Megatron divides loss by num_microbatches + # (https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/pipeline_parallel/schedules.py#L248) + # and the data parallel all-reduce averages gradients across dp_size + # (https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/distributed/distributed_data_parallel.py#L285) + # so we multiply by both factors to recover the correct sum reduction. + grad_sum_correction_factor = num_microbatches * dp_size # temperature normalization if temperature != 1.0: @@ -289,14 +291,14 @@ def loss_func(logits, data): # SFT path: cross_entropy loss (negative log likelihood) if resolved_loss_name == "cross_entropy": unscaled_loss = policy_loss - loss = unscaled_loss * loss_scale + loss = unscaled_loss * grad_sum_correction_factor # Compute elementwise loss for Tinker API (per-token NLL) with torch.no_grad(): elementwise_loss = -action_log_probs if loss_mask is not None: elementwise_loss = elementwise_loss * loss_mask - elementwise_loss = elementwise_loss * loss_scale + elementwise_loss = elementwise_loss * grad_sum_correction_factor # Build per-sequence loss_fn_outputs batch_size = action_log_probs.shape[0] @@ -352,7 +354,7 @@ def loss_func(logits, data): kl_loss_term = kl_loss * loss_config.kl_loss_coef unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term - loss = unscaled_loss * loss_scale + loss = unscaled_loss * grad_sum_correction_factor # Build per-sequence loss_fn_outputs with logprobs. 
batch_size = action_log_probs.shape[0] diff --git a/skyrl/backends/skyrl_train/workers/worker.py b/skyrl/backends/skyrl_train/workers/worker.py index 2196559797..ae6e6fd012 100644 --- a/skyrl/backends/skyrl_train/workers/worker.py +++ b/skyrl/backends/skyrl_train/workers/worker.py @@ -806,12 +806,15 @@ def _forward_backward_micro( rollout_logprobs=rollout_action_logprobs, ) - loss_scale = self.mesh_rank.dp_size + # DP all-reduce averages gradients, but policy losses are pre-scaled sums + # (see `apply_loss_reduction_to_advantages_minibatch`), so we multiply by + # dp_size to recover the correct sum reduction across workers. + grad_sum_correction_factor = self.mesh_rank.dp_size # SFT path: skip KL/entropy terms, return per-token outputs for Tinker API if resolved_loss_name == "cross_entropy": unscaled_loss = policy_loss - loss = unscaled_loss * loss_scale + loss = unscaled_loss * grad_sum_correction_factor self.strategy.backward(loss, self.model, self.optimizer) # Compute elementwise loss for Tinker API (per-token NLL) @@ -819,7 +822,7 @@ def _forward_backward_micro( elementwise_loss = -action_log_probs if loss_mask is not None: elementwise_loss = elementwise_loss * loss_mask - elementwise_loss = elementwise_loss * loss_scale + elementwise_loss = elementwise_loss * grad_sum_correction_factor # Build per-sequence loss_fn_outputs (matches Tinker's ForwardBackwardOutput structure) # Trim to actual response length per sample (Tinker expects variable-length arrays @@ -878,7 +881,7 @@ def _forward_backward_micro( kl_loss_term = kl_loss * self.cfg.algorithm.kl_loss_coef unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term - loss = unscaled_loss * loss_scale + loss = unscaled_loss * grad_sum_correction_factor self.strategy.backward(loss, self.model, self.optimizer) # Build per-sequence loss_fn_outputs with logprobs. 
From 5cc95a1b88be62da8f6b7c6ab44ccccbccd789e7 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Fri, 27 Mar 2026 12:56:31 -0700 Subject: [PATCH 15/18] no_grad for safety and make private Signed-off-by: Justin Yu --- skyrl/train/trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index a765141636..f01eed964c 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -1054,7 +1054,8 @@ def apply_reward_kl_penalty( return data - def normalize_advantages(self, data: TrainingInputBatch, mini_batch_size: int) -> TrainingInputBatch: + @torch.no_grad() + def _normalize_advantages(self, data: TrainingInputBatch, mini_batch_size: int) -> TrainingInputBatch: advantages = data["advantages"] response_mask = data["response_mask"] @@ -1106,7 +1107,7 @@ def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[s if model == "policy": mini_batch_size = self.cfg.trainer.policy_mini_batch_size * n_samples # Normalize advantages for policy training; critic training does not need this - data = self.normalize_advantages(data, mini_batch_size) + data = self._normalize_advantages(data, mini_batch_size) else: mini_batch_size = self.cfg.trainer.critic_mini_batch_size * n_samples From ce8f6aa9474199c120df134d8429bd88ccf99801 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Fri, 27 Mar 2026 13:06:46 -0700 Subject: [PATCH 16/18] remove outdated comments about loss reduction type in sapo tests Signed-off-by: Justin Yu --- skyrl/backends/skyrl_train/utils/ppo_utils.py | 1 - tests/train/algorithms/test_losses.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/skyrl/backends/skyrl_train/utils/ppo_utils.py b/skyrl/backends/skyrl_train/utils/ppo_utils.py index a5a9d7f110..b0b038b20d 100644 --- a/skyrl/backends/skyrl_train/utils/ppo_utils.py +++ b/skyrl/backends/skyrl_train/utils/ppo_utils.py @@ -628,7 +628,6 @@ def gate_function(x, tau): ) loss_metrics.update(off_policy_metrics) 
- # for SAPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean) loss = reduce_loss(loss, loss_mask) return loss, loss_metrics diff --git a/tests/train/algorithms/test_losses.py b/tests/train/algorithms/test_losses.py index af6b41dc5d..dab3434ebc 100644 --- a/tests/train/algorithms/test_losses.py +++ b/tests/train/algorithms/test_losses.py @@ -415,7 +415,7 @@ def test_sapo_policy_loss_basic(): # Ratios ≈ [exp(-0.5), exp(0.2), exp(-0.1)] ≈ [0.6065, 1.2214, 0.9048] log_probs = torch.tensor([[-1.5, -0.8, -1.1]], device=device) - # SAPO config: uses sequence_mean reduction and distinct tau_pos / tau_neg + # SAPO config with distinct tau_pos / tau_neg config = AlgorithmConfig( policy_loss_type="sapo", max_seq_len=4, From 1a60bb5aed1c44494a3f051595391adeb978fb41 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Fri, 27 Mar 2026 13:14:56 -0700 Subject: [PATCH 17/18] fix test Signed-off-by: Justin Yu --- .../skyrl_train/workers/test_worker_utils.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tests/backends/skyrl_train/workers/test_worker_utils.py b/tests/backends/skyrl_train/workers/test_worker_utils.py index f40dfc6709..bcb8ce0a7b 100644 --- a/tests/backends/skyrl_train/workers/test_worker_utils.py +++ b/tests/backends/skyrl_train/workers/test_worker_utils.py @@ -138,11 +138,19 @@ def mock_all_reduce(d, op, group=None): strategy.all_reduce.side_effect = mock_all_reduce - metrics = {"policy_loss": 1.5} - result = all_reduce_metrics(metrics, strategy, sum_loss_metrics=False) - assert result["policy_loss"] == 1.5 - assert strategy.all_reduce.call_count == 1 - assert strategy.all_reduce.call_args[0][1] == "mean" + metrics = {"critic_loss": 1.5, "entropy": 0.5} + _ = all_reduce_metrics(metrics, strategy, sum_loss_metrics=False) + + # Both should be mean-reduced (critic_loss is NOT summed without the flag) + ops_and_keys = [] + for args, kwargs in strategy.all_reduce.call_args_list: + data_dict = args[0] + op = 
kwargs.get("op", args[1]) + if data_dict: + ops_and_keys.append((op, set(data_dict.keys()))) + + mean_call = [c for c in ops_and_keys if c[0] == "mean"][0] + assert mean_call[1] == {"critic_loss", "entropy"} def test_all_reduce_metrics_returns_merged_results(self): """Verify results from all reductions are merged correctly.""" From c5feb83b38f4635c7fc705c2bb192a7d6ad16947 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Fri, 27 Mar 2026 14:02:03 -0700 Subject: [PATCH 18/18] fix test Signed-off-by: Justin Yu --- .../skyrl_train/workers/test_worker_utils.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/backends/skyrl_train/workers/test_worker_utils.py b/tests/backends/skyrl_train/workers/test_worker_utils.py index bcb8ce0a7b..2c01499faa 100644 --- a/tests/backends/skyrl_train/workers/test_worker_utils.py +++ b/tests/backends/skyrl_train/workers/test_worker_utils.py @@ -79,7 +79,8 @@ def test_reduce_metrics_empty_raises(self): class TestAllReduceMetrics: - def test_all_reduce_metrics_separates_by_suffix(self): + @pytest.mark.parametrize("sum_loss_metrics", [True, False]) + def test_all_reduce_metrics_separates_by_suffix(self, sum_loss_metrics): """Verify metrics are correctly separated by suffix and reduced with correct ops.""" strategy = MagicMock() @@ -96,7 +97,7 @@ def mock_all_reduce(d, op, group=None): "entropy": 0.5, } - _ = all_reduce_metrics(metrics, strategy, sum_loss_metrics=True) + _ = all_reduce_metrics(metrics, strategy, sum_loss_metrics=sum_loss_metrics) # Verify all_reduce was called 4 times assert strategy.all_reduce.call_count == 4 @@ -114,11 +115,17 @@ def mock_all_reduce(d, op, group=None): # Verify mean metrics (entropy) mean_call = [c for c in ops_and_keys if c[0] == "mean"][0] - assert mean_call[1] == {"entropy"} + if sum_loss_metrics: + assert mean_call[1] == {"entropy"} + else: + assert mean_call[1] == {"entropy", "policy_loss"} # Verify sum metrics (explicit sum_keys) sum_call = [c for c in 
ops_and_keys if c[0] == "sum"][0] - assert sum_call[1] == {"policy_loss"} + if sum_loss_metrics: + assert sum_call[1] == {"policy_loss"} + else: + assert sum_call[1] == set() # Verify min metrics min_call = [c for c in ops_and_keys if c[0] == "min"][0] @@ -145,7 +152,7 @@ def mock_all_reduce(d, op, group=None): ops_and_keys = [] for args, kwargs in strategy.all_reduce.call_args_list: data_dict = args[0] - op = kwargs.get("op", args[1]) + op = kwargs["op"] if data_dict: ops_and_keys.append((op, set(data_dict.keys())))