diff --git a/examples/train/async/async_trainer.py b/examples/train/async/async_trainer.py index 9f0dc7c063..a80edd730f 100644 --- a/examples/train/async/async_trainer.py +++ b/examples/train/async/async_trainer.py @@ -5,7 +5,6 @@ from skyrl.train.trainer import RayPPOTrainer from tqdm import tqdm from skyrl.train.utils import Timer -from skyrl.backends.skyrl_train.utils.ppo_utils import normalize_advantages_dict from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch from skyrl.train.generators.base import GeneratorOutput from skyrl.train.utils.trainer_utils import ResumeMode @@ -146,9 +145,6 @@ async def _run_training(self, generation_buffer): training_input.pop(key) training_input.metadata.pop("uids") - if self.cfg.trainer.algorithm.advantage_batch_normalize: - training_input = normalize_advantages_dict(training_input) - if self.cfg.trainer.dump_data_batch: # dump data to file with Timer("dump_data_batch"): diff --git a/skyrl-agent/skyrl_agent/integrations/skyrl_train/trainer.py b/skyrl-agent/skyrl_agent/integrations/skyrl_train/trainer.py index 2135651f79..884f99f701 100644 --- a/skyrl-agent/skyrl_agent/integrations/skyrl_train/trainer.py +++ b/skyrl-agent/skyrl_agent/integrations/skyrl_train/trainer.py @@ -39,7 +39,6 @@ from skyrl.train.utils import Timer from skyrl.backends.skyrl_train.utils.ppo_utils import ( get_kl_controller, - normalize_advantages_dict, ) from skyrl.train.utils.trainer_utils import ( validate_generator_output, @@ -382,9 +381,6 @@ async def train(self): training_input.pop(key) training_input.metadata.pop("uids") - if self.cfg.trainer.algorithm.advantage_batch_normalize: - training_input = normalize_advantages_dict(training_input) - if self.cfg.trainer.dump_data_batch: # dump data to file with Timer("dump_data_batch"): diff --git a/skyrl/backends/skyrl_train/distributed/strategy.py b/skyrl/backends/skyrl_train/distributed/strategy.py index a002171749..d606be66db 100644 --- a/skyrl/backends/skyrl_train/distributed/strategy.py +++ b/skyrl/backends/skyrl_train/distributed/strategy.py @@ -67,11 +67,11 @@ def get_rank(self) -> int: """Get current process rank""" return dist.get_rank() - def all_reduce(self, data: DataT, op="mean") -> DataT: - """Perform all_reduce across all processes""" + def all_reduce(self, data: DataT, op="mean", group=None) -> DataT: + """Perform all_reduce across all processes (or within a process group).""" assert op in ("mean", "max", "sum", "min") if isinstance(data, dict): - return {k: self.all_reduce(v, op) for k, v in data.items()} + return {k: self.all_reduce(v, op, group=group) for k, v in data.items()} else: is_tensor = True if not isinstance(data, torch.Tensor): @@ -82,14 +82,15 @@ def all_reduce(self, data: DataT, op="mean") -> DataT: if is_cpu_tensor: data = data.to(torch.cuda.current_device()) if op == "mean": - data /= self.world_size - dist.all_reduce(data, op=dist.ReduceOp.SUM) + group_size = dist.get_world_size(group) if group is not None else self.world_size + data /= group_size + dist.all_reduce(data, op=dist.ReduceOp.SUM, group=group) elif op == "max": - dist.all_reduce(data, op=dist.ReduceOp.MAX) + dist.all_reduce(data, op=dist.ReduceOp.MAX, group=group) elif op == "min": - dist.all_reduce(data, op=dist.ReduceOp.MIN) + dist.all_reduce(data, op=dist.ReduceOp.MIN, group=group) elif op == "sum": - dist.all_reduce(data, op=dist.ReduceOp.SUM) + dist.all_reduce(data, op=dist.ReduceOp.SUM, group=group) if is_cpu_tensor: data = data.cpu() return data.item() if not is_tensor else data diff --git a/skyrl/backends/skyrl_train/utils/ppo_utils.py b/skyrl/backends/skyrl_train/utils/ppo_utils.py index 189fecccc0..b0b038b20d 100644 --- a/skyrl/backends/skyrl_train/utils/ppo_utils.py +++ b/skyrl/backends/skyrl_train/utils/ppo_utils.py @@ -19,7 +19,7 @@ from collections import defaultdict from enum import StrEnum from functools import wraps -from typing import Callable, List, Literal, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import ray @@ -27,7 +27,6 @@ from jaxtyping import Float from loguru import logger -from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch from skyrl.backends.skyrl_train.utils.off_policy_correction_utils import ( apply_off_policy_correction, ) @@ -124,27 +123,6 @@ def compute_approx_kl( return kld -@torch.no_grad() -def normalize_advantages_dict(data: TrainingInputBatch) -> TrainingInputBatch: - """Normalizes the advantages in the data batch. - - Expects: - - `["advantages"]`: Float[torch.Tensor, "batch_size seqlen"] - - `["response_mask"]`: Float[torch.Tensor, "batch_size seqlen"] - """ - advantages: Float[torch.Tensor, "batch_size seqlen"] = data["advantages"] - response_masks: Float[torch.Tensor, "batch_size seqlen"] = data["response_mask"] - num_actions: float = response_masks.sum() - # mean - mean: float = advantages.mean() - # std - std: float = ((advantages - mean).pow(2) * response_masks).sum() - rstd: float = (std / num_actions).clamp(min=1e-8).rsqrt() - - data["advantages"] = (advantages - mean) * rstd - return data - - def masked_var(values, mask, unbiased=True): """Compute variance of tensor with masked values.""" mean = masked_mean(values, mask) @@ -558,12 +536,6 @@ def ppo_policy_loss( rollout_logprobs: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, dict[str, float]]: assert config.policy_loss_type in ["regular", "dual_clip"], "loss_type must be either 'regular' or 'dual_clip'" - loss_reduction = config.loss_reduction - assert loss_reduction in [ - "token_mean", - "sequence_mean", - "seq_mean_token_sum_norm", - ], "loss_reduction must be either 'token_mean', 'sequence_mean', or 'seq_mean_token_sum_norm'" ratio = safe_exp_delta(log_probs - old_log_probs, clip=20.0, out_dtype=log_probs.dtype) surr1 = ratio * advantages @@ -584,7 +556,7 @@ def ppo_policy_loss( ) loss_metrics.update(off_policy_metrics) - loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) + loss = reduce_loss(loss, loss_mask) return loss, loss_metrics @@ -656,8 +628,7 @@ def gate_function(x, tau): ) loss_metrics.update(off_policy_metrics) - # for SAPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean) - loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) + loss = reduce_loss(loss, loss_mask) return loss, loss_metrics @@ -726,7 +697,7 @@ def gspo_policy_loss( ) loss_metrics.update(off_policy_metrics) - loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) + loss = reduce_loss(loss, loss_mask) return loss, loss_metrics @@ -763,7 +734,7 @@ def compute_policy_loss_cispo( ) loss_metrics.update(off_policy_metrics) - loss = reduce_loss(loss, loss_mask, config.loss_reduction, config.max_seq_len) + loss = reduce_loss(loss, loss_mask) return loss, loss_metrics @@ -791,13 +762,6 @@ def rollout_is_policy_loss( """ assert rollout_logprobs is not None, "rollout_logprobs are required for rollout_is" - loss_reduction = config.loss_reduction - assert loss_reduction in [ - "token_mean", - "sequence_mean", - "seq_mean_token_sum_norm", - ], "loss_reduction must be either 'token_mean', 'sequence_mean', or 'seq_mean_token_sum_norm'" - ratio = safe_exp_delta(log_probs - rollout_logprobs, clip=20.0, out_dtype=log_probs.dtype) in_range = (ratio > 1 - config.eps_clip_low) & (ratio < 1 + config.eps_clip_high) @@ -812,7 +776,7 @@ def rollout_is_policy_loss( ) loss_metrics.update(off_policy_metrics) - loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) + loss = reduce_loss(loss, loss_mask) return loss, loss_metrics @@ -874,12 +838,7 @@ def compute_policy_loss_clip_cov( # Apply correction mask to losses pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr - pg_loss = reduce_loss( - loss=pg_losses, - loss_mask=loss_mask, - loss_reduction=config.loss_reduction, - max_seq_len=config.max_seq_len, - ) + pg_loss = reduce_loss(loss=pg_losses, loss_mask=loss_mask) return pg_loss, {"clip_ratio": clip_frac.item()} @@ -933,12 +892,7 @@ def compute_policy_loss_kl_cov( large_cov_idxs % advantages.shape[1], ] - pg_loss = reduce_loss( - loss=pg_losses, - loss_mask=loss_mask, - loss_reduction=config.loss_reduction, - max_seq_len=config.max_seq_len, - ) + pg_loss = reduce_loss(loss=pg_losses, loss_mask=loss_mask) # NOTE (sumanthrh): Since the pg clip ratio is not applicable for KL-COV so we just use 0.0 return pg_loss, {"clip_ratio": 0.0} @@ -977,10 +931,7 @@ def cross_entropy_loss( elementwise_loss = -log_probs # Apply loss mask and sum (matching Tinker's SUM reduction semantics) - if loss_mask is not None: - loss = (elementwise_loss * loss_mask).sum() - else: - loss = elementwise_loss.sum() + loss = reduce_loss(elementwise_loss, loss_mask) # No clipping in cross-entropy loss return loss, {"clip_ratio": 0.0} @@ -1039,30 +990,60 @@ def importance_sampling_loss( def reduce_loss( loss: torch.Tensor, loss_mask: Optional[torch.Tensor], - loss_reduction: Literal["token_mean", "sequence_mean", "seq_mean_token_sum_norm"], - max_seq_len: Optional[int] = None, ) -> torch.Tensor: + return (loss * loss_mask).sum() if loss_mask is not None else loss.sum() + + +def apply_loss_reduction_to_advantages_minibatch( + advantages: torch.Tensor, + loss_mask: torch.Tensor, + loss_reduction: str, + micro_batch_size: int, + max_seq_len: int, +) -> torch.Tensor: + """Scale advantages so that summing produces the desired loss reduction. + + Args: + advantages: Advantage tensor of shape (minibatch_size, seq_len). + loss_mask: Mask of shape (minibatch_size, seq_len) indicating valid loss tokens. + loss_reduction: One of "token_mean", "token_mean_legacy", "sequence_mean", "seq_mean_token_sum_norm". + micro_batch_size: Number of sequences per micro-batch + max_seq_len: Maximum sequence length. + + Returns: + Scaled advantages tensor. + """ + batch_size = advantages.shape[0] + normalized_advantages = torch.zeros_like(advantages) + + # Option 1: token mean if loss_reduction == "token_mean": - # sum over *all* valid tokens, divide by total valid-token count - loss = masked_mean(loss, loss_mask) + normalized_advantages = advantages / loss_mask.sum().clamp(min=1) + + # Option 1b: legacy token-mean that normalizes per-microbatch then averages across microbatches. + elif loss_reduction == "token_mean_legacy": + num_micro_batches = batch_size // micro_batch_size + for i in range(num_micro_batches): + start_idx = i * micro_batch_size + end_idx = (i + 1) * micro_batch_size + mb_advantages = advantages[start_idx:end_idx] + mb_loss_mask = loss_mask[start_idx:end_idx] + mb_advantages = mb_advantages / mb_loss_mask.sum().clamp(min=1) + mb_advantages /= num_micro_batches + normalized_advantages[start_idx:end_idx] = mb_advantages + + # Option 2: sequence mean elif loss_reduction == "sequence_mean": - # per-sequence token-mean (dim=-1), then batch-mean - loss = masked_mean(loss, loss_mask, dim=-1).mean() + normalized_advantages = advantages / (batch_size * loss_mask.sum(dim=-1, keepdim=True).clamp(min=1)) + + # Option 3: Dr. GRPO style loss reduction to avoid length bias by normalizing by a constant elif loss_reduction == "seq_mean_token_sum_norm": - # per-sequence token-sum, normalized by the max sequence length, then batch mean - # this is the Dr. GRPO loss reduction to avoid length bias by normalizing by a constant - assert max_seq_len is not None, "max_seq_len must be provided for seq_mean_token_sum_norm loss reduction" - # NOTE: max_seq_len can be set explicitly via algorithm.max_seq_len, otherwise defaults to - # cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length - if loss_mask is not None: - seq_losses = torch.sum(loss * loss_mask, dim=-1) / max_seq_len - else: - # If no mask, assume all tokens are valid - seq_losses = torch.sum(loss, dim=-1) / max_seq_len - loss = torch.mean(seq_losses) + normalized_advantages = advantages / (batch_size * max_seq_len) + else: raise ValueError(f"Invalid loss reduction type: {loss_reduction}") - return loss + + return normalized_advantages # NOTE (erictang000): below ported from verl diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index 597d59e446..73581ecb20 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -246,10 +246,21 @@ def loss_func(logits, data): loss_mask = data["loss_mask"] rollout_action_logprobs = data["rollout_action_logprobs"] action_mask = data.get("action_mask") + num_microbatches = data.get("num_microbatches") + dp_size = mpu.get_data_parallel_world_size() tp_grp = mpu.get_tensor_model_parallel_group() tp_rank = mpu.get_tensor_model_parallel_rank() + # Policy losses are pre-scaled to achieve the correct loss_reduction when summing across the entire minibatch + # (see `apply_loss_reduction_to_advantages_minibatch`). + # Megatron divides loss by num_microbatches + # (https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/pipeline_parallel/schedules.py#L248) + # and the data parallel all-reduce averages gradients across dp_size + # (https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/distributed/distributed_data_parallel.py#L285) + # so we multiply by both factors to recover the correct sum reduction. + grad_sum_correction_factor = num_microbatches * dp_size + # temperature normalization if temperature != 1.0: logits.div_(temperature) @@ -279,13 +290,15 @@ def loss_func(logits, data): # SFT path: cross_entropy loss (negative log likelihood) if resolved_loss_name == "cross_entropy": - loss = policy_loss + unscaled_loss = policy_loss + loss = unscaled_loss * grad_sum_correction_factor # Compute elementwise loss for Tinker API (per-token NLL) with torch.no_grad(): elementwise_loss = -action_log_probs if loss_mask is not None: elementwise_loss = elementwise_loss * loss_mask + elementwise_loss = elementwise_loss * grad_sum_correction_factor # Build per-sequence loss_fn_outputs batch_size = action_log_probs.shape[0] @@ -310,7 +323,7 @@ def loss_func(logits, data): ) metrics = { - "loss": loss.detach().item(), + "loss": unscaled_loss.detach().item(), "response_length": num_actions, "loss_fn_outputs": loss_fn_outputs, } @@ -340,7 +353,8 @@ def loss_func(logits, data): kl_loss = torch.tensor(0.0) kl_loss_term = kl_loss * loss_config.kl_loss_coef - loss = policy_loss + kl_loss_term - entropy_loss_term + unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term + loss = unscaled_loss * grad_sum_correction_factor # Build per-sequence loss_fn_outputs with logprobs. batch_size = action_log_probs.shape[0] @@ -363,7 +377,7 @@ def loss_func(logits, data): ) metrics = { - "final_loss": loss.detach().item(), + "final_loss": unscaled_loss.detach().item(), "policy_loss": policy_loss.detach().item(), "policy_entropy": entropy.detach().item(), "policy_kl": kl_loss.detach().item(), diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 38bcfe0061..3b1ec841ce 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -700,6 +700,9 @@ def forward_backward( } ) + for m_batch in micro_buffer: + m_batch["num_microbatches"] = len(micro_buffer) + if not micro_buffer: return {} @@ -718,9 +721,6 @@ def forward_backward( if self.empty_cuda_cache: torch.cuda.empty_cache() - # Track number of micro-batches for metrics - self._micro_batches_accumulated += len(micro_buffer) - # Aggregate metrics across micro-batches all_loss_fn_outputs = [] # Handle separately from scalar metrics for metrics in metrics_list: @@ -730,10 +730,13 @@ def forward_backward( for k, v in metrics.items(): all_metrics[k].append(v) - # Reduce and all-reduce metrics - status = reduce_metrics(dict(all_metrics)) + # Reduce across microbatches and all-reduce metrics across DP ranks + # (metrics should be identical within DP groups, i.e., across TP/PP/SP ranks) + # NOTE: Sum loss metrics because scaling is already applied at the advantage level + status = reduce_metrics(all_metrics, sum_loss_metrics=True) status["policy_lr"] = self.optimizer.param_groups[0]["lr"] - status = all_reduce_metrics(status, self.strategy) + group = mpu.get_data_parallel_group(with_context_parallel=True) + status = all_reduce_metrics(status, self.strategy, group=group, sum_loss_metrics=True) # Add loss_fn_outputs back (not reduced, kept as list) if all_loss_fn_outputs: diff --git a/skyrl/backends/skyrl_train/workers/worker.py b/skyrl/backends/skyrl_train/workers/worker.py index 5bc2752858..ae6e6fd012 100644 --- a/skyrl/backends/skyrl_train/workers/worker.py +++ b/skyrl/backends/skyrl_train/workers/worker.py @@ -714,7 +714,11 @@ def forward_backward( for k, v in metrics.items(): all_metrics[k].append(v) - result = reduce_metrics(dict(all_metrics)) + # Reduce across microbatches and all-reduce metrics across DP ranks + # NOTE: Sum loss metrics because scaling is already applied at the advantage level + result = reduce_metrics(all_metrics, sum_loss_metrics=True) + dp_group = self.device_mesh.get_group("dp") + result = all_reduce_metrics(result, self.strategy, group=dp_group, sum_loss_metrics=True) # Add back loss_fn_outputs (concatenated across micro-batches) if all_loss_fn_outputs: @@ -802,9 +806,15 @@ def _forward_backward_micro( rollout_logprobs=rollout_action_logprobs, ) + # DP all-reduce averages gradients, but policy losses are pre-scaled sums + # (see `apply_loss_reduction_to_advantages_minibatch`), so we multiply by + # dp_size to recover the correct sum reduction across workers. + grad_sum_correction_factor = self.mesh_rank.dp_size + # SFT path: skip KL/entropy terms, return per-token outputs for Tinker API if resolved_loss_name == "cross_entropy": - loss = policy_loss + unscaled_loss = policy_loss + loss = unscaled_loss * grad_sum_correction_factor self.strategy.backward(loss, self.model, self.optimizer) # Compute elementwise loss for Tinker API (per-token NLL) @@ -812,6 +822,7 @@ def _forward_backward_micro( elementwise_loss = -action_log_probs if loss_mask is not None: elementwise_loss = elementwise_loss * loss_mask + elementwise_loss = elementwise_loss * grad_sum_correction_factor # Build per-sequence loss_fn_outputs (matches Tinker's ForwardBackwardOutput structure) # Trim to actual response length per sample (Tinker expects variable-length arrays @@ -837,7 +848,7 @@ def _forward_backward_micro( ) status = { - "loss": loss.item(), + "loss": unscaled_loss.item(), "response_length": num_actions, "lr": self.scheduler.get_last_lr()[0], "loss_fn_outputs": loss_fn_outputs, @@ -869,7 +880,8 @@ def _forward_backward_micro( kl_loss = torch.tensor(0.0) kl_loss_term = kl_loss * self.cfg.algorithm.kl_loss_coef - loss = policy_loss + kl_loss_term - entropy_loss_term + unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term + loss = unscaled_loss * grad_sum_correction_factor self.strategy.backward(loss, self.model, self.optimizer) # Build per-sequence loss_fn_outputs with logprobs. @@ -893,7 +905,7 @@ def _forward_backward_micro( ) status = { - "final_loss": loss.item(), + "final_loss": unscaled_loss.item(), "policy_loss": policy_loss.item(), "policy_entropy": entropy.item(), "response_length": num_actions, @@ -905,37 +917,18 @@ def _forward_backward_micro( if self.cfg.algorithm.use_kl_loss: status["policy_kl"] = kl_loss.item() - loss_fn_outputs = status.pop("loss_fn_outputs", None) - - # All-reduce metrics across DP workers - status = all_reduce_metrics(status, self.strategy) - - # Add back loss_fn_outputs after all_reduce - if loss_fn_outputs is not None: - status["loss_fn_outputs"] = loss_fn_outputs - return status def optim_step(self) -> float: """ - Scale gradients by 1/micro_batches_accumulated, perform optimizer step, and reset counter. + Perform optimizer step. Returns: The gradient norm (before scaling, after clipping) """ - # Scale accumulated gradients by 1/N to get correct average - if self._micro_batches_accumulated > 0: - scale = 1.0 / self._micro_batches_accumulated - for param in self.model.parameters(): - if param.grad is not None: - param.grad.mul_(scale) - # Perform optimizer step (includes gradient clipping) grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="actor") - # Reset counter for next accumulation cycle - self._micro_batches_accumulated = 0 - if grad_norm is not None: grad_norm = grad_norm.detach().cpu().item() return grad_norm @@ -1058,7 +1051,13 @@ def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]: for k, v in metrics.items(): all_metrics[k].append(v) - return reduce_metrics(dict(all_metrics)) + # reduce metrics across micro batches + result = reduce_metrics(all_metrics) + + # all reduce metrics across DP workers + result = all_reduce_metrics(result, self.strategy) + + return result def _forward_backward_micro(self, experience: Experience) -> Dict[str, float]: """ @@ -1109,9 +1108,6 @@ def _forward_backward_micro(self, experience: Experience) -> Dict[str, float]: "critic_lr": self.scheduler.get_last_lr()[0], } - # All-reduce metrics across DP workers - status = all_reduce_metrics(status, self.strategy) - return status def optim_step(self) -> float: diff --git a/skyrl/backends/skyrl_train/workers/worker_utils.py b/skyrl/backends/skyrl_train/workers/worker_utils.py index ca47f60e58..cecaccc17e 100644 --- a/skyrl/backends/skyrl_train/workers/worker_utils.py +++ b/skyrl/backends/skyrl_train/workers/worker_utils.py @@ -6,9 +6,19 @@ from skyrl.train.dataset.replay_buffer import Experience -def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]: - """ - Reduce scalar metrics from a list of entries per key by averaging. +def reduce_metrics(metrics: Dict[str, List[float]], sum_loss_metrics: bool = False) -> Dict[str, float]: + """Reduce scalar metrics from a list of entries per key with the appropriate reduction. + + Default reduction is mean. Metrics ending in `_min` or `_max` use min/max respectively. + + If sum_loss_metrics is True, metrics ending in `_loss` are summed instead of averaged. + This should be used if the scaling is already done at the advantage level. + See `apply_loss_reduction_to_advantages_minibatch` for more details. + + Args: + metrics: Dictionary of metrics with keys as metric names and values as lists of metric values. + The list of values corresponds to micro-batches within a single mini-batch. + sum_loss_metrics: If True, metrics ending in `_loss` are summed (for pre-scaled policy losses). """ reduced_metrics = dict() for k, v in metrics.items(): @@ -20,21 +30,43 @@ def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]: reduced_metrics[k] = max(v) elif k.endswith("_min"): reduced_metrics[k] = min(v) + elif sum_loss_metrics and k.endswith("_loss"): + reduced_metrics[k] = sum(v) else: reduced_metrics[k] = sum(v) / len(v) return reduced_metrics -def all_reduce_metrics(metrics: Dict[str, List[float]], strategy: DistributedStrategy) -> Dict[str, float]: - """All reduce metrics across all processes.""" +def all_reduce_metrics( + metrics: Dict[str, float], + strategy: DistributedStrategy, + group=None, + sum_loss_metrics: bool = False, +) -> Dict[str, float]: + """All reduce metrics across all processes. + + Default reduction is mean. Metrics ending in `_min` or `_max` use min/max respectively. + If sum_loss_metrics is True, metrics ending in `_loss` are summed instead of averaged. + + Args: + metrics: Dictionary of metric name to scalar value. + strategy: Distributed strategy for all-reduce. + group: Process group for all-reduce. + sum_loss_metrics: If True, metrics ending in `_loss` are summed (for pre-scaled policy losses). + """ min_metrics = {k: v for k, v in metrics.items() if k.endswith("_min")} max_metrics = {k: v for k, v in metrics.items() if k.endswith("_max")} - mean_metrics = {k: v for k, v in metrics.items() if k not in min_metrics and k not in max_metrics} - status_mean = strategy.all_reduce(mean_metrics, op="mean") - status_min = strategy.all_reduce(min_metrics, op="min") - status_max = strategy.all_reduce(max_metrics, op="max") + sum_metrics = {k: v for k, v in metrics.items() if sum_loss_metrics and k.endswith("_loss")} + mean_metrics = { + k: v for k, v in metrics.items() if k not in min_metrics and k not in max_metrics and k not in sum_metrics + } + status_mean = strategy.all_reduce(mean_metrics, op="mean", group=group) + status_min = strategy.all_reduce(min_metrics, op="min", group=group) + status_max = strategy.all_reduce(max_metrics, op="max", group=group) + status_sum = strategy.all_reduce(sum_metrics, op="sum", group=group) status_mean.update(status_min) status_mean.update(status_max) + status_mean.update(status_sum) return status_mean diff --git a/skyrl/train/fully_async_trainer.py b/skyrl/train/fully_async_trainer.py index 89075e14bd..3c5ecacccb 100644 --- a/skyrl/train/fully_async_trainer.py +++ b/skyrl/train/fully_async_trainer.py @@ -29,7 +29,6 @@ ) from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch from skyrl.backends.skyrl_train.utils.io import io -from skyrl.backends.skyrl_train.utils.ppo_utils import normalize_advantages_dict from skyrl.train.generators.base import GeneratorOutput from skyrl.train.generators.utils import ( concatenate_generator_outputs, @@ -519,9 +518,6 @@ async def _run_training(self, training_input: TrainingInputBatch): training_input.pop(key) training_input.metadata.pop("uids") - if self.cfg.trainer.algorithm.advantage_batch_normalize: - training_input = normalize_advantages_dict(training_input) - if self.cfg.trainer.dump_data_batch: # dump data to file with Timer("dump_data_batch"): diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index e22312c56e..f01eed964c 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -33,9 +33,9 @@ from skyrl.backends.skyrl_train.utils.ppo_utils import ( AdaptiveKLController, FixedKLController, + apply_loss_reduction_to_advantages_minibatch, compute_approx_kl, get_kl_controller, - normalize_advantages_dict, ) from skyrl.backends.skyrl_train.utils.torch_utils import masked_mean from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup @@ -278,9 +278,6 @@ async def train(self): training_input.pop(key) training_input.metadata.pop("uids") - if self.cfg.trainer.algorithm.advantage_batch_normalize: - training_input = normalize_advantages_dict(training_input) - if self.cfg.trainer.dump_data_batch: # dump data to file with Timer("dump_data_batch"): @@ -1057,6 +1054,37 @@ def apply_reward_kl_penalty( return data + @torch.no_grad() + def _normalize_advantages(self, data: TrainingInputBatch, mini_batch_size: int) -> TrainingInputBatch: + advantages = data["advantages"] + response_mask = data["response_mask"] + + # Step 1: Z-score normalization (if enabled) + if self.cfg.trainer.algorithm.advantage_batch_normalize: + num_actions = response_mask.sum() + mean = advantages.mean() + std = ((advantages - mean).pow(2) * response_mask).sum() + rstd = (std / num_actions).clamp(min=1e-8).rsqrt() + data["advantages"] = (advantages - mean) * rstd + + # Step 2: Loss reduction normalization per mini-batch + num_mini_batches = len(data) // mini_batch_size + normalized_advantages = torch.zeros_like(advantages) + for local_step in range(num_mini_batches): + start_idx = local_step * mini_batch_size + end_idx = (local_step + 1) * mini_batch_size + mini_batch = data[start_idx:end_idx] + normalized_advantages[start_idx:end_idx] = apply_loss_reduction_to_advantages_minibatch( + advantages=mini_batch["advantages"], + loss_mask=mini_batch["loss_mask"], + loss_reduction=self.cfg.trainer.algorithm.loss_reduction, + micro_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, + max_seq_len=self.cfg.trainer.algorithm.max_seq_len, + ) + + data["advantages"] = normalized_advantages + return data + def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[str, float]: """ Execute training step using forward_backward + optim_step. @@ -1078,6 +1106,8 @@ def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[s n_samples = self.cfg.generator.n_samples_per_prompt if model == "policy": mini_batch_size = self.cfg.trainer.policy_mini_batch_size * n_samples + # Normalize advantages for policy training; critic training does not need this + data = self._normalize_advantages(data, mini_batch_size) else: mini_batch_size = self.cfg.trainer.critic_mini_batch_size * n_samples @@ -1102,7 +1132,7 @@ def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[s # Reduce metrics across all mini-batches and epochs # pop out loss_fn_outputs since it's not a scalar metric and to avoid logging it all_metrics.pop("loss_fn_outputs", None) - reduced_metrics = reduce_metrics(all_metrics) + reduced_metrics = reduce_metrics(all_metrics, sum_loss_metrics=False) return reduced_metrics def train_critic_and_policy(self, data: TrainingInputBatch): diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py index 1fa7b1d47b..98d0b2befa 100644 --- a/skyrl/train/utils/utils.py +++ b/skyrl/train/utils/utils.py @@ -276,6 +276,7 @@ def validate_cfg(cfg: SkyRLTrainConfig): assert cfg.trainer.algorithm.loss_reduction in ( "token_mean", + "token_mean_legacy", "sequence_mean", "seq_mean_token_sum_norm", ), ( diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py index 94726a4f4d..1ebbba45a0 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py @@ -84,13 +84,16 @@ def get_test_training_batch(batch_size=4) -> TrainingInputBatch: sequences = [tokenizer.encode(sentence) for sentence in sentences] attention_masks = [[1] * len(seq) for seq in sequences] - num_actions = 10 + num_actions = 15 # max seq len 1 longer than the longest sequence so we always have some padding max_seq_length = max([len(seq) for seq in sequences]) + 7 pad_token_id = tokenizer.pad_token_id pad_before = [4, 0, 1, 6] * num_repeats pad_after = [max_seq_length - len(seq) - pad_before[i] for i, seq in enumerate(sequences)] + loss_masks = torch.stack( + [torch.cat([torch.ones(num_actions - pad_after[i]), torch.zeros(pad_after[i])]) for i in range(batch_size)] + ) for i, (pad_before, pad_after) in enumerate(zip(pad_before, pad_after)): sequences[i] = [pad_token_id] * pad_before + sequences[i] + [pad_token_id] * pad_after @@ -109,8 +112,8 @@ def get_test_training_batch(batch_size=4) -> TrainingInputBatch: "values": torch.tensor([[0.1] * num_actions] * batch_size), "returns": torch.tensor([[0.1] * num_actions] * batch_size), "advantages": torch.tensor([[0.5] * num_actions] * batch_size), - "loss_mask": torch.tensor([[1] * num_actions] * batch_size), - "response_mask": torch.tensor([[1] * num_actions] * batch_size), + "loss_mask": loss_masks, + "response_mask": loss_masks, } ) data.metadata = {"response_length": num_actions} @@ -449,7 +452,7 @@ async def test_megatron_lora_forward(ray_init_fixture, tp, pp, cp, ep, etp, gpus "tp2_pp2_policy_seq_packing_with_entropy_loss", "tp2_pp2_policy_lora", "tp2_pp2_policy_unpacked", - "tp2_cp2_policy_seq_packing", + "tp2_cp2_policy_seq_packing_no_entropy_loss", "tp4_pp1_cp1_ep4_etp1_policy_seq_packing", "tp4_pp1_cp1_ep4_etp1_policy_seq_packing_lora", ], @@ -462,7 +465,8 @@ async def test_megatron_train( Full test: initialize actor group, send dummy experience to training_step, validate output. """ cfg = get_test_actor_config(model_name=MODEL_NAME if ep == 1 else MOE_MODEL_NAME) - batch = get_test_training_batch(batch_size=gpus_per_node) + batch_size = gpus_per_node * 8 + batch = get_test_training_batch(batch_size=batch_size) cfg.trainer.strategy = "megatron" cfg.trainer.placement.policy_num_gpus_per_node = gpus_per_node @@ -472,6 +476,7 @@ async def test_megatron_train( cfg.trainer.policy.megatron_config.expert_model_parallel_size = ep cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = etp cfg.trainer.use_sample_packing = use_sample_packing + cfg.trainer.algorithm.use_kl_loss = False if use_entropy_loss: cfg.trainer.algorithm.use_entropy_loss = True cfg.trainer.algorithm.entropy_loss_coef = 0.01 @@ -489,7 +494,7 @@ async def test_megatron_train( cfg.trainer.algorithm.off_policy_correction.geo_mask_low = 0.98 # set batch sizes correctly - cfg.trainer.train_batch_size = gpus_per_node + cfg.trainer.train_batch_size = batch_size cfg.trainer.policy_mini_batch_size = gpus_per_node cfg.generator.n_samples_per_prompt = 1 cfg.trainer.micro_train_batch_size_per_gpu = 1 @@ -552,7 +557,7 @@ async def test_megatron_train( # Both FSDP and Megatron use forward_backward + optim_step (unified interface) batch.metadata["global_step"] = 0 - results_fsdp = ray.get(actor_group.async_run_ray_method("pass_through", "forward_backward", batch)) + results_fsdp = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", batch)) ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) # Get learning rate from worker lr_results = ray.get(actor_group.async_run_ray_method("pass_through", "get_lr")) @@ -568,7 +573,6 @@ async def test_megatron_train( "policy_lr", "loss_metrics/clip_ratio", "policy_entropy", - "policy_kl", "final_loss", ] if ep > 1: diff --git a/tests/backends/skyrl_train/gpu/test_grpo_sp_sanity.py b/tests/backends/skyrl_train/gpu/test_grpo_sp_sanity.py index 44b4003264..ddd35051e3 100644 --- a/tests/backends/skyrl_train/gpu/test_grpo_sp_sanity.py +++ b/tests/backends/skyrl_train/gpu/test_grpo_sp_sanity.py @@ -9,7 +9,6 @@ from loguru import logger from tqdm import tqdm -from skyrl.backends.skyrl_train.utils.ppo_utils import normalize_advantages_dict from skyrl.train.config import SkyRLTrainConfig from skyrl.train.entrypoints.main_base import BasePPOExp from skyrl.train.trainer import RayPPOTrainer @@ -117,9 +116,6 @@ def train(self): # remove some unwanted keys data.pop(batch_keys=["rewards"]) - if self.cfg.trainer.algorithm.advantage_batch_normalize: - data = normalize_advantages_dict(data) - # 4. train policy/critic model with Timer("train_critic_and_policy", self.all_timings): status = self.train_critic_and_policy(data) diff --git a/tests/backends/skyrl_train/utils/test_ppo_utils.py b/tests/backends/skyrl_train/utils/test_ppo_utils.py index 8f61a9bc8e..bd15e089e0 100644 --- a/tests/backends/skyrl_train/utils/test_ppo_utils.py +++ b/tests/backends/skyrl_train/utils/test_ppo_utils.py @@ -14,6 +14,7 @@ AdvantageEstimatorRegistry, FixedKLController, PolicyLossRegistry, + apply_loss_reduction_to_advantages_minibatch, compute_advantages_and_returns, compute_approx_kl, compute_gae_advantage_return, @@ -245,29 +246,17 @@ def test_compute_gae_advantage_return_lam(advantage_test_data): def test_reduce_loss(): - """Test the reduce_loss function with different reduction types.""" - # Test data: 2x3 loss tensor with different valid token counts per sequence + """Test that reduce_loss computes the masked sum correctly.""" loss = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - loss_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]]) # seq0 has 3 tokens, seq1 has 1 token + loss_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]]) - # Test token_mean: sum all valid losses / count valid tokens - # Valid losses: [1.0, 2.0, 3.0, 4.0], mean = 10.0/4 = 2.5 - result_token = reduce_loss(loss, loss_mask, "token_mean") - expected_token = torch.tensor(2.5) - assert torch.allclose(result_token, expected_token), f"Expected {expected_token}, got {result_token}" + # With mask: sum of valid losses = 1.0 + 2.0 + 3.0 + 4.0 = 10.0 + result = reduce_loss(loss, loss_mask) + assert torch.allclose(result, torch.tensor(10.0)) - # Test sequence_mean: mean of per-sequence means - # Seq 0: (1.0 + 2.0 + 3.0) / 3 = 2.0, Seq 1: 4.0 / 1 = 4.0, batch mean = (2.0 + 4.0) / 2 = 3.0 - result_seq = reduce_loss(loss, loss_mask, "sequence_mean") - expected_seq = torch.tensor(3.0) - assert torch.allclose(result_seq, expected_seq), f"Expected {expected_seq}, got {result_seq}" - - # Test seq_mean_token_sum_norm: sum per sequence / max_len, then batch mean - # Seq 0: (1.0 + 2.0 + 3.0) / 4 = 1.5, Seq 1: 4.0 / 4 = 1.0, batch mean = (1.5 + 1.0) / 2 = 1.25 - max_seq_len = 4 - result_max = reduce_loss(loss, loss_mask, "seq_mean_token_sum_norm", max_seq_len) - expected_max = torch.tensor(1.25) - assert torch.allclose(result_max, expected_max), f"Expected {expected_max}, got {result_max}" + # Without mask: sum of all losses = 1+2+3+4+5+6 = 21.0 + result_no_mask = reduce_loss(loss, None) + assert torch.allclose(result_no_mask, torch.tensor(21.0)) def test_adaptive_kl_controller_update(): @@ -558,3 +547,98 @@ def test_func(**kwargs): finally: ray.shutdown() + + +class TestApplyLossReductionToAdvantagesMinibatch: + """Tests for apply_loss_reduction_to_advantages_minibatch.""" + + def test_token_mean(self): + advantages = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + loss_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]) + # valid tokens: 1+2+4+5+6 = 18, count = 5, mean = 3.6 + scaled = apply_loss_reduction_to_advantages_minibatch( + advantages=advantages, + loss_mask=loss_mask, + loss_reduction="token_mean", + micro_batch_size=1, + max_seq_len=3, + ) + loss = reduce_loss(scaled, loss_mask) + assert torch.allclose(loss, torch.tensor(3.6)) + + def test_token_mean_all_masked(self): + """Token mean with all-zero mask should produce zero loss, not NaN.""" + advantages = torch.tensor([[1.0, 2.0]]) + loss_mask = torch.tensor([[0.0, 0.0]]) + scaled = apply_loss_reduction_to_advantages_minibatch( + advantages=advantages, + loss_mask=loss_mask, + loss_reduction="token_mean", + micro_batch_size=1, + max_seq_len=2, + ) + loss = reduce_loss(scaled, loss_mask) + assert torch.allclose(loss, torch.tensor(0.0)) + + def test_sequence_mean(self): + advantages = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + loss_mask = torch.tensor([[1.0, 1.0], [1.0, 0.0]]) + # seq 0: token mean = (1+2)/2 = 1.5 + # seq 1: token mean = 3/1 = 3.0 + # sequence mean = (1.5 + 3.0) / 2 = 2.25 + scaled = apply_loss_reduction_to_advantages_minibatch( + advantages=advantages, + loss_mask=loss_mask, + loss_reduction="sequence_mean", + micro_batch_size=1, + max_seq_len=2, + ) + loss = reduce_loss(scaled, loss_mask) + assert torch.allclose(loss, torch.tensor(2.25)) + + def test_seq_mean_token_sum_norm(self): + advantages = torch.tensor([[2.0, 3.0], [9.0, 12.0]]) + loss_mask = torch.tensor([[1.0, 1.0], [1.0, 0.0]]) + # seq 0: (2+3)/10 = 0.5 + # seq 1: 9/10 = 0.9 + # mean across batch = (0.5 + 0.9) / 2 = 0.7 + scaled = apply_loss_reduction_to_advantages_minibatch( + advantages=advantages, + loss_mask=loss_mask, + loss_reduction="seq_mean_token_sum_norm", + micro_batch_size=1, + max_seq_len=10, + ) + loss = reduce_loss(scaled, loss_mask) + assert torch.allclose(loss, torch.tensor(0.7)) + + def test_token_mean_legacy(self): + """Legacy token mean: per-microbatch token mean, then averaged across microbatches.""" + # 4 sequences, micro_batch_size=2 -> 2 microbatches + advantages = torch.tensor([[2.0, 4.0], [6.0, 8.0], [10.0, 12.0], [14.0, 16.0]]) + loss_mask = torch.tensor([[1.0, 1.0], [1.0, 0.0], [1.0, 1.0], [1.0, 1.0]]) + # microbatch 0: token mean = (2+4+6)/3 = 4 + # microbatch 1: token mean = (10+12+14+16)/4 = 13 + # average of microbatch means = (4.0 + 13.0) / 2 = 8.5 + scaled = apply_loss_reduction_to_advantages_minibatch( + advantages=advantages, + loss_mask=loss_mask, + loss_reduction="token_mean_legacy", + micro_batch_size=2, + max_seq_len=2, + ) + loss = reduce_loss(scaled, loss_mask) + assert torch.allclose(loss, torch.tensor(8.5)) + + def test_invalid_loss_reduction_raises(self): + """Invalid loss_reduction should raise ValueError.""" + advantages = torch.tensor([[1.0]]) + loss_mask = torch.tensor([[1.0]]) + with pytest.raises(ValueError, match="Invalid loss reduction type"): + apply_loss_reduction_to_advantages_minibatch( + advantages=advantages, + loss_mask=loss_mask, + loss_reduction="invalid", + micro_batch_size=1, + max_seq_len=1, + ) diff --git a/tests/backends/skyrl_train/workers/test_worker_utils.py b/tests/backends/skyrl_train/workers/test_worker_utils.py index a319eba7a4..2c01499faa 100644 --- a/tests/backends/skyrl_train/workers/test_worker_utils.py +++ b/tests/backends/skyrl_train/workers/test_worker_utils.py @@ -28,10 +28,22 @@ def test_reduce_metrics_min_suffix(self): assert result["is_ratio_min"] == 1.0 def test_reduce_metrics_mean_default(self): - """Keys without _max/_min suffix should use mean reduction.""" + """Keys without _max/_min/_loss suffix should use mean reduction.""" + metrics = {"entropy": [1.0, 2.0, 3.0]} + result = reduce_metrics(metrics) + assert result["entropy"] == 2.0 + + def test_reduce_metrics_loss_default_mean(self): + """_loss keys default to mean when sum_loss_metrics=False.""" metrics = {"policy_loss": [1.0, 2.0, 3.0]} result = reduce_metrics(metrics) - assert result["policy_loss"] == 2.0 # mean of [1, 2, 3] + assert result["policy_loss"] == 2.0 + + def test_reduce_metrics_sum_loss_metrics(self): + """_loss keys are summed when sum_loss_metrics=True.""" + metrics = {"policy_loss": [1.0, 2.0, 3.0]} + result = reduce_metrics(metrics, sum_loss_metrics=True) + assert result["policy_loss"] == 6.0 def test_reduce_metrics_mixed(self): """Test mixed metric types are reduced correctly.""" @@ -39,11 +51,13 @@ def test_reduce_metrics_mixed(self): "is_ratio_max": [1.0, 10.0], "is_ratio_min": [0.5, 2.0], "policy_loss": [1.0, 3.0], + "entropy": [1.0, 3.0], } - result = reduce_metrics(metrics) + result = reduce_metrics(metrics, sum_loss_metrics=True) assert result["is_ratio_max"] == 10.0 assert result["is_ratio_min"] == 0.5 - assert result["policy_loss"] == 2.0 + assert result["policy_loss"] == 4.0 # sum + assert result["entropy"] == 2.0 # mean def test_reduce_metrics_single_value(self): """Test reduction with single value lists.""" @@ -65,12 +79,13 @@ def test_reduce_metrics_empty_raises(self): class TestAllReduceMetrics: - def test_all_reduce_metrics_separates_by_suffix(self): + @pytest.mark.parametrize("sum_loss_metrics", [True, False]) + def test_all_reduce_metrics_separates_by_suffix(self, sum_loss_metrics): """Verify metrics are correctly separated by suffix and reduced with correct ops.""" strategy = MagicMock() # Mock all_reduce to return the input dict unchanged but track calls - def mock_all_reduce(d, op): + def mock_all_reduce(d, op, group=None): return {k: v for k, v in d.items()} strategy.all_reduce.side_effect = mock_all_reduce @@ -82,10 +97,10 @@ def mock_all_reduce(d, op): "entropy": 0.5, } - _ = all_reduce_metrics(metrics, strategy) + _ = all_reduce_metrics(metrics, strategy, sum_loss_metrics=sum_loss_metrics) - # Verify all_reduce was called 3 times - assert strategy.all_reduce.call_count == 3 + # Verify all_reduce was called 4 times + assert strategy.all_reduce.call_count == 4 # Check that the correct ops were used calls = strategy.all_reduce.call_args_list @@ -98,9 +113,19 @@ def mock_all_reduce(d, op): op = kwargs.get("op") if kwargs else args[1] ops_and_keys.append((op, set(data_dict.keys()))) - # Verify mean metrics (policy_loss, entropy) + # Verify mean metrics (entropy) mean_call = [c for c in ops_and_keys if c[0] == "mean"][0] - assert mean_call[1] == {"policy_loss", "entropy"} + if sum_loss_metrics: + assert mean_call[1] == {"entropy"} + else: + assert mean_call[1] == {"entropy", "policy_loss"} + + # Verify sum metrics (explicit sum_keys) + sum_call = [c for c in ops_and_keys if c[0] == "sum"][0] + if sum_loss_metrics: + assert sum_call[1] == {"policy_loss"} + else: + assert sum_call[1] == set() # Verify min metrics min_call = [c for c in ops_and_keys if c[0] == "min"][0] @@ -110,14 +135,40 @@ def mock_all_reduce(d, op): max_call = [c for c in ops_and_keys if c[0] == "max"][0] assert max_call[1] == {"is_ratio_max"} + def test_all_reduce_metrics_average_loss_metrics(self): + """Verify _loss keys are averaged when sum_loss_metrics=False.""" + strategy = MagicMock() + + # Mock all_reduce to return the input dict unchanged but track calls + def mock_all_reduce(d, op, group=None): + return {k: v for k, v in d.items()} + + strategy.all_reduce.side_effect = mock_all_reduce + + metrics = {"critic_loss": 1.5, "entropy": 0.5} + _ = all_reduce_metrics(metrics, strategy, sum_loss_metrics=False) + + # Both should be mean-reduced (critic_loss is NOT summed without the flag) + ops_and_keys = [] + for args, kwargs in strategy.all_reduce.call_args_list: + data_dict = args[0] + op = kwargs["op"] + if data_dict: + ops_and_keys.append((op, set(data_dict.keys()))) + + mean_call = [c for c in ops_and_keys if c[0] == "mean"][0] + assert mean_call[1] == {"critic_loss", "entropy"} + def test_all_reduce_metrics_returns_merged_results(self): """Verify results from all reductions are merged correctly.""" strategy = MagicMock() # Mock all_reduce to modify values based on op - def mock_all_reduce(d, op): + def mock_all_reduce(d, op, group=None): if op == "mean": return {k: v * 2 for k, v in d.items()} # Double for mean + elif op == "sum": + return {k: v * 4 for k, v in d.items()} # Quadruple for sum elif op == "min": return {k: v / 2 for k, v in d.items()} # Halve for min elif op == "max": @@ -130,49 +181,19 @@ def mock_all_reduce(d, op): "is_ratio_max": 10.0, "is_ratio_min": 0.1, "policy_loss": 1.5, + "entropy": 0.5, } - result = all_reduce_metrics(metrics, strategy) + result = all_reduce_metrics(metrics, strategy, sum_loss_metrics=True) # Check all keys are present assert "is_ratio_max" in result assert "is_ratio_min" in result assert "policy_loss" in result + assert "entropy" in result # Check values were transformed correctly assert result["is_ratio_max"] == 30.0 # 10.0 * 3 (max op) assert result["is_ratio_min"] == 0.05 # 0.1 / 2 (min op) - assert result["policy_loss"] == 3.0 # 1.5 * 2 (mean op) - - def test_all_reduce_metrics_only_max(self): - """Test with only _max metrics.""" - strategy = MagicMock() - strategy.all_reduce.side_effect = lambda d, op: d - - metrics = {"loss_max": 5.0, "ratio_max": 10.0} - - result = all_reduce_metrics(metrics, strategy) - - assert result == {"loss_max": 5.0, "ratio_max": 10.0} - - def test_all_reduce_metrics_only_min(self): - """Test with only _min metrics.""" - strategy = MagicMock() - strategy.all_reduce.side_effect = lambda d, op: d - - metrics = {"loss_min": 0.1, "ratio_min": 0.01} - - result = all_reduce_metrics(metrics, strategy) - - assert result == {"loss_min": 0.1, "ratio_min": 0.01} - - def test_all_reduce_metrics_only_mean(self): - """Test with only mean metrics (no _max/_min suffix).""" - strategy = MagicMock() - strategy.all_reduce.side_effect = lambda d, op: d - - metrics = {"policy_loss": 1.5, "entropy": 0.5} - - result = all_reduce_metrics(metrics, strategy) - - assert result == {"policy_loss": 1.5, "entropy": 0.5} + assert result["policy_loss"] == 6.0 # sum op + assert result["entropy"] == 1.0 # 0.5 * 2 (mean op) diff --git a/tests/train/algorithms/test_losses.py b/tests/train/algorithms/test_losses.py index 2f4b21a2ca..dab3434ebc 100644 --- a/tests/train/algorithms/test_losses.py +++ b/tests/train/algorithms/test_losses.py @@ -47,7 +47,6 @@ def test_policy_loss_dual_clip(): eps_clip_high=0.2, clip_ratio_c=3.0, policy_loss_type="dual_clip", - loss_reduction="token_mean", max_seq_len=4, off_policy_correction=NULL_OFF_POLICY_CORR, ) @@ -71,7 +70,7 @@ def test_policy_loss_dual_clip(): # For negative advantages, use dual clipped loss final_loss = torch.where(advantages < 0, min_loss, max_loss) # [-0.5, 1.0, 12.0] assert torch.allclose(final_loss, torch.tensor([[-0.5, 1.0, 12.0]], device=device), rtol=1e-3) - expected_loss = final_loss.mean() # -(-12.5/3) = 4.1667 + expected_loss = final_loss.sum() # Calculate actual loss actual_loss, _ = loss_fn(log_probs=log_probs, old_log_probs=old_log_probs, advantages=advantages, config=config) @@ -79,7 +78,7 @@ def test_policy_loss_dual_clip(): # Verify results torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-3, atol=1e-8) # close to hand calculated value - assert actual_loss.item() == pytest.approx(4.1667, abs=1e-4) + assert actual_loss.item() == pytest.approx(12.5, abs=1e-4) def test_policy_loss_cispo(): @@ -98,7 +97,6 @@ def test_policy_loss_cispo(): config = AlgorithmConfig( cispo=CISPOConfig(cispo_eps_clip_low=0.2, cispo_eps_clip_high=0.2), policy_loss_type="cispo", - loss_reduction="token_mean", max_seq_len=4, off_policy_correction=NULL_OFF_POLICY_CORR, ) @@ -119,9 +117,9 @@ def test_policy_loss_cispo(): # loss_per_token[0] = -(1.0 * 0.8 * -1.69315) = 1.35452 # loss_per_token[1] = -(-1.0 * 1.0 * -1.0) = -1.0 # loss_per_token[2] = -(-4.0 * 1.2 * -0.69741) = -3.347568 - # mean(loss) = (1.35452 - 1.0 - 3.347568) / 3 = -0.99768266666 + # sum(loss) = (1.35452 - 1.0 - 3.347568) = -2.9930 loss = -ratio.clamp(1 - 0.2, 1 + 0.2) * advantages * log_probs - expected_loss = loss.mean() + expected_loss = loss.sum() # Calculate actual loss actual_loss, _ = loss_fn( @@ -134,158 +132,7 @@ def test_policy_loss_cispo(): # Verify results torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-3, atol=1e-8) # close to hand calculated value - assert actual_loss.item() == pytest.approx(-0.99768266666, abs=1e-4) - - -def test_policy_loss_reduction_modes(): - """Tests different loss_reduction modes in PolicyLoss function. - - Note: token_mean and sequence_mean give the same result when all sequences - have the same length and no mask is applied, but differ when masking creates - different effective sequence lengths. - """ - - device = "cpu" - - clip_eps_low = 0.2 - clip_eps_high = 0.2 - - advantages = torch.tensor( - [ - [2.0, 2.0, 2.0], # sequence 1: consistently higher advantages - [1.0, 1.0, 1.0], # sequence 2: consistently lower advantages - ], - device=device, - ) - - old_log_probs = torch.tensor([[-1.0, -1.0, -1.0], [-1.0, -1.0, -1.0]], device=device) - - log_probs = torch.tensor( - [[-1.5, -0.5, -1.2], [-0.8, -1.3, -0.9]], # ratios ≈ [[0.61, 1.65, 0.83],[1.22, 0.74, 1.11]] - device=device, - ) - - # Create masks to test sequences with different numbers of valid tokens - loss_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]], device=device) - - # Create configs for different reduction modes - config_token = AlgorithmConfig( - eps_clip_low=clip_eps_low, - eps_clip_high=clip_eps_high, - clip_ratio_c=3.0, - policy_loss_type="regular", - loss_reduction="token_mean", - max_seq_len=4, - off_policy_correction=NULL_OFF_POLICY_CORR, - ) - - config_seq = AlgorithmConfig( - eps_clip_low=clip_eps_low, - eps_clip_high=clip_eps_high, - clip_ratio_c=3.0, - policy_loss_type="regular", - loss_reduction="sequence_mean", - max_seq_len=4, - off_policy_correction=NULL_OFF_POLICY_CORR, - ) - - # Get loss function - loss_fn = PolicyLossRegistry.get("regular") - - # Test token_mean without mask - loss_token_no_mask, _ = loss_fn(log_probs, old_log_probs, advantages, config_token) - - # Test token_mean with mask - loss_token_with_mask, _ = loss_fn(log_probs, old_log_probs, advantages, config_token, loss_mask) - - # Test sequence_mean without mask - loss_seq_no_mask, _ = loss_fn(log_probs, old_log_probs, advantages, config_seq) - - # Test sequence_mean with mask - loss_seq_with_mask, _ = loss_fn(log_probs, old_log_probs, advantages, config_seq, loss_mask) - - # Manual calculations to verify (using default PolicyLoss parameters) - ratio = torch.exp(log_probs - old_log_probs) - surr1 = ratio * advantages - surr2 = ratio.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages # clip_eps_low=0.2, clip_eps_high=0.2 - loss_per_token = -torch.min(surr1, surr2) - - # Expected token_mean without mask: mean of all tokens - expected_token_no_mask = loss_per_token.mean() - - # Expected token_mean with mask: masked mean of all tokens - expected_token_with_mask = (loss_per_token * loss_mask).sum() / (loss_mask.sum() + 1e-8) - - # Expected sequence_mean without mask: mean of sequence means - expected_seq_no_mask = loss_per_token.mean(dim=1).mean() - - # Expected sequence_mean with mask: mean of masked sequence means - seq_means_masked = (loss_per_token * loss_mask).sum(dim=1) / (loss_mask.sum(dim=1) + 1e-8) - expected_seq_with_mask = seq_means_masked.mean() - - # Verify results - torch.testing.assert_close(loss_token_no_mask, expected_token_no_mask, rtol=1e-5, atol=1e-8) - torch.testing.assert_close(loss_token_with_mask, expected_token_with_mask, rtol=1e-5, atol=1e-8) - torch.testing.assert_close(loss_seq_no_mask, expected_seq_no_mask, rtol=1e-5, atol=1e-8) - torch.testing.assert_close(loss_seq_with_mask, expected_seq_with_mask, rtol=1e-5, atol=1e-8) - - # Verify that the two reduction modes give the same results when sequences have equal length and no mask - assert torch.allclose( - loss_token_no_mask, loss_seq_no_mask, rtol=1e-5 - ), "token_mean and sequence_mean should give same results when sequences have equal length and no mask" - # But they should give different results when mask creates different effective sequence lengths - assert not torch.allclose( - loss_token_with_mask, loss_seq_with_mask, rtol=1e-3 - ), "token_mean and sequence_mean with mask should give different results" - - -def test_policy_loss_reduction_edge_cases(): - """Tests edge cases for loss_reduction modes.""" - - device = "cpu" - - # Test with single sequence (should give same result for both modes) - advantages = torch.tensor([[1.0, -1.0, 2.0]], device=device) - old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) - log_probs = torch.tensor([[-1.5, -0.5, -1.2]], device=device) - - # Create configs for different reduction modes - config_token = AlgorithmConfig( - eps_clip_low=0.2, - eps_clip_high=0.2, - clip_ratio_c=3.0, - policy_loss_type="regular", - loss_reduction="token_mean", - max_seq_len=4, - off_policy_correction=NULL_OFF_POLICY_CORR, - ) - - config_seq = AlgorithmConfig( - eps_clip_low=0.2, - eps_clip_high=0.2, - clip_ratio_c=3.0, - policy_loss_type="regular", - loss_reduction="sequence_mean", - max_seq_len=4, - off_policy_correction=NULL_OFF_POLICY_CORR, - ) - # Get loss function - loss_fn = PolicyLossRegistry.get("regular") - - loss_token, _ = loss_fn(log_probs, old_log_probs, advantages, config_token) - loss_seq, _ = loss_fn(log_probs, old_log_probs, advantages, config_seq) - - # With single sequence, both modes should give same result - torch.testing.assert_close(loss_token, loss_seq, rtol=1e-6, atol=1e-8) - - # Test with completely masked sequence - loss_mask = torch.tensor([[0.0, 0.0, 0.0]], device=device) - loss_token_masked, _ = loss_fn(log_probs, old_log_probs, advantages, config_token, loss_mask) - loss_seq_masked, _ = loss_fn(log_probs, old_log_probs, advantages, config_seq, loss_mask) - - # Should handle zero mask gracefully (due to +1e-8 in denominator) - assert torch.isfinite(loss_token_masked) - assert torch.isfinite(loss_seq_masked) + assert actual_loss.item() == pytest.approx(-2.9930, abs=1e-4) def test_gspo_importance_sampling_levels(): @@ -348,7 +195,6 @@ def test_gspo_importance_sampling_levels(): eps_clip_high=clip_eps_high, clip_ratio_c=3.0, policy_loss_type="regular", - loss_reduction="token_mean", max_seq_len=4, off_policy_correction=NULL_OFF_POLICY_CORR, ) @@ -361,7 +207,6 @@ def test_gspo_importance_sampling_levels(): eps_clip_high=clip_eps_high, clip_ratio_c=3.0, policy_loss_type="gspo", - loss_reduction="sequence_mean", # GSPO recommended reduction max_seq_len=4, off_policy_correction=NULL_OFF_POLICY_CORR, ) @@ -374,7 +219,7 @@ def test_gspo_importance_sampling_levels(): surr1_token = ratio_token * advantages surr2_token = ratio_token.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages loss_per_token_token = -torch.min(surr1_token, surr2_token) - expected_token = (loss_per_token_token * loss_mask).sum() / (loss_mask.sum() + 1e-8) + expected_token = (loss_per_token_token * loss_mask).sum() # Calculate token-level clipping ratio is_clipped_token = (-surr2_token > -surr1_token) & (loss_mask.bool()) @@ -390,8 +235,8 @@ def test_gspo_importance_sampling_levels(): surr1_sequence = ratio_sequence * advantages surr2_sequence = ratio_sequence.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages loss_per_token_sequence = -torch.min(surr1_sequence, surr2_sequence) - # GSPO uses sequence_mean reduction - expected_sequence = masked_mean(loss_per_token_sequence, loss_mask, dim=-1).mean() + # GSPO uses sum reduction + expected_sequence = loss_per_token_sequence.sum() # Calculate sequence-level clipping ratio is_clipped_sequence = (-surr2_sequence > -surr1_sequence) & (loss_mask.bool()) @@ -465,7 +310,6 @@ def test_clip_cov_policy_loss(): eps_clip_low=0.2, eps_clip_high=0.2, policy_loss_type="clip_cov", - loss_reduction="token_mean", max_seq_len=4, clip_cov=ClipCovConfig(clip_ratio=0.5, clip_cov_lb=-5.0, clip_cov_ub=5.0), # Large ratio for testing off_policy_correction=NULL_OFF_POLICY_CORR, @@ -487,7 +331,6 @@ def test_clip_cov_policy_loss(): eps_clip_low=0.2, eps_clip_high=0.2, policy_loss_type="regular", - loss_reduction="token_mean", max_seq_len=4, off_policy_correction=NULL_OFF_POLICY_CORR, ) @@ -525,7 +368,6 @@ def test_kl_cov_policy_loss(): # Create KL-Cov config config = AlgorithmConfig( policy_loss_type="kl_cov", - loss_reduction="token_mean", max_seq_len=4, kl_cov=KLCovConfig(kl_cov_frac=0.5, ppo_kl_coef=1.0), # Apply KL to 50% of tokens off_policy_correction=NULL_OFF_POLICY_CORR, @@ -546,7 +388,6 @@ def test_kl_cov_policy_loss(): eps_clip_low=0.2, eps_clip_high=0.2, policy_loss_type="regular", - loss_reduction="token_mean", max_seq_len=4, use_tis=False, off_policy_correction=NULL_OFF_POLICY_CORR, @@ -574,10 +415,9 @@ def test_sapo_policy_loss_basic(): # Ratios ≈ [exp(-0.5), exp(0.2), exp(-0.1)] ≈ [0.6065, 1.2214, 0.9048] log_probs = torch.tensor([[-1.5, -0.8, -1.1]], device=device) - # SAPO config: uses sequence_mean reduction and distinct tau_pos / tau_neg + # SAPO config with distinct tau_pos / tau_neg config = AlgorithmConfig( policy_loss_type="sapo", - loss_reduction="sequence_mean", max_seq_len=4, sapo=SAPOConfig(tau_pos=1.0, tau_neg=2.0), off_policy_correction=NULL_OFF_POLICY_CORR, @@ -609,8 +449,8 @@ def gate_function(x, tau): gates = gate_function(ratio, taus) loss_per_token = -gates * advantages - # sequence_mean reduction: per-sequence token mean, then batch mean - expected_loss = loss_per_token.mean(dim=-1).mean() + # sum reduction + expected_loss = loss_per_token.sum() torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-5, atol=1e-8) diff --git a/tests/train/test_trainer.py b/tests/train/test_trainer.py index faea4ea922..b252cb4825 100644 --- a/tests/train/test_trainer.py +++ b/tests/train/test_trainer.py @@ -447,7 +447,11 @@ def create_test_worker(worker_class): # Mock dependencies worker.strategy = MagicMock() worker.strategy.is_rank_0.return_value = False # Disable progress bars - worker.strategy.all_reduce.return_value = {"loss": 0.5, "lr": 1e-4} + worker.strategy.all_reduce.side_effect = lambda d, op, group=None: d # Return input dict unchanged + + # Mock device_mesh for DP group access + worker.device_mesh = MagicMock() + worker.device_mesh.get_group.return_value = None # No actual process group in tests # Always set model for all worker types worker.model = MagicMock()