diff --git a/src/prime_rl/trainer/model.py b/src/prime_rl/trainer/model.py index 910a978a66..679c52cad1 100644 --- a/src/prime_rl/trainer/model.py +++ b/src/prime_rl/trainer/model.py @@ -394,6 +394,7 @@ def get_load_balance_stats( model: nn.Module, reset_stats: bool = True, try_to_avoid_padding_experts: bool = True ) -> dict[str, Tensor | None]: per_layer_max_vio = [] + per_layer_routing_confidence = [] language_model = get_language_model(model) for transformer_block in language_model.layers: # This is necessary for models that have mixed dense layers @@ -401,16 +402,25 @@ def get_load_balance_stats( if block_mlp is None or not hasattr(block_mlp, "tokens_per_expert"): continue tokens_per_expert: torch.Tensor = block_mlp.tokens_per_expert + num_routed_tokens = tokens_per_expert.sum() / block_mlp.router.top_k if try_to_avoid_padding_experts: tokens_per_expert = tokens_per_expert.sort(dim=0, descending=True).values[block_mlp.router.top_k :] balanced_load = tokens_per_expert.mean() max_vio = (tokens_per_expert.max() - balanced_load) / balanced_load - per_layer_max_vio.append(max_vio.item()) + per_layer_max_vio.append(max_vio.detach()) + + routing_confidence = block_mlp.routing_confidence_sum / num_routed_tokens + per_layer_routing_confidence.append(routing_confidence.detach()) + if reset_stats: block_mlp.tokens_per_expert.zero_() + block_mlp.routing_confidence_sum.zero_() if len(per_layer_max_vio) == 0: - return {"max_vio": None} - return {"max_vio": torch.tensor(per_layer_max_vio, device=torch.device("cuda"))} + return {"max_vio": None, "routing_confidence": None} + return { + "max_vio": torch.stack(per_layer_max_vio), + "routing_confidence": torch.stack(per_layer_routing_confidence), + } def get_model( diff --git a/src/prime_rl/trainer/models/layers/moe.py b/src/prime_rl/trainer/models/layers/moe.py index c59a720539..14d46b2f89 100644 --- a/src/prime_rl/trainer/models/layers/moe.py +++ b/src/prime_rl/trainer/models/layers/moe.py @@ -387,6 +387,16 @@ def 
def _selected_probability_mass_sum(
    scores: torch.Tensor, top_scores: torch.Tensor, score_func: Literal["softmax", "sigmoid"]
) -> torch.Tensor:
    """Sum, over all tokens, the share of router score held by the selected experts.

    The result is a running-total contribution: callers add it to a
    ``routing_confidence_sum`` buffer and divide by the number of routed
    tokens later to obtain a mean.

    Args:
        scores: Router scores for every expert; the last dimension indexes
            experts (the routers in this file pass ``(bs*slen, num_experts)``).
        top_scores: Scores of the selected experts, taken *before* any route
            normalization/scaling; the last dimension indexes the selected
            experts (``(bs*slen, top_k)`` at the call sites in this file).
        score_func: Which activation produced ``scores``.

    Returns:
        A 0-dim tensor that does not require grad.

    Raises:
        ValueError: If ``score_func`` is neither ``"softmax"`` nor ``"sigmoid"``.
    """
    # Fail loudly instead of silently treating an unknown activation as sigmoid.
    if score_func not in ("softmax", "sigmoid"):
        raise ValueError(f"Unknown score function {score_func!r}; expected 'softmax' or 'sigmoid'")

    # This is a logging statistic only, so keep it out of the autograd graph.
    with torch.no_grad():
        if score_func == "softmax":
            # Assumes the caller already softmax-normalised `scores`, so the
            # selected scores are themselves the selected probability mass.
            return top_scores.sum()

        # Sigmoid scores are not normalised across experts: divide by each
        # token's total score. The epsilon guards against an all-zero row.
        per_token_total = scores.sum(dim=-1, keepdim=True) + 1e-20
        selected_prob_mass = top_scores / per_token_total
        return selected_prob_mass.sum(dim=-1).sum()
""" # scores shape (bs*slen, num_experts) assert routed_experts is None or routed_experts.shape[-1] == self.top_k, ( @@ -469,6 +481,8 @@ def forward( else: top_scores, selected_experts_indices = torch.topk(scores, k=self.top_k, dim=1) + routing_confidence_sum = _selected_probability_mass_sum(scores, top_scores, self.score_func) + if self.route_norm: denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20 top_scores = top_scores / denominator @@ -482,7 +496,7 @@ def forward( max=self.num_experts, ) - return top_scores, selected_experts_indices, num_tokens_per_expert + return top_scores, selected_experts_indices, num_tokens_per_expert, routing_confidence_sum def init_weights(self, init_std: float): nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std) @@ -600,6 +614,7 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): torch.zeros(num_experts, dtype=torch.float32), persistent=False, ) + self.register_buffer("routing_confidence_sum", torch.tensor(0.0, dtype=torch.float32), persistent=False) def set_ep_comm_backend(self, backend: EPCommBackend) -> None: self.ep_comm_backend = backend @@ -718,6 +733,7 @@ def forward( top_scores, selected_experts_indices, num_tokens_per_expert, + routing_confidence_sum, ) = self.router(x, self.expert_bias, routed_experts=routed_experts) # tokens_per_expert will be used to update the expert bias for load balancing. @@ -727,6 +743,7 @@ def forward( # routed expert compute below. 
with torch.no_grad(): self.tokens_per_expert.add_(num_tokens_per_expert) + self.routing_confidence_sum.add_(routing_confidence_sum) if self.ep_comm_backend == "deepep": routed_output = self._run_deepep_routed_experts(x, selected_experts_indices, top_scores) @@ -774,6 +791,7 @@ def init_weights( with torch.device(buffer_device): self.tokens_per_expert = torch.zeros(self.experts.num_experts, dtype=torch.float32) + self.routing_confidence_sum = torch.tensor(0.0, dtype=torch.float32) if self.load_balance_coeff is not None: self.expert_bias = torch.zeros(self.experts.num_experts, dtype=torch.float32) @@ -938,7 +956,7 @@ def __init__( def forward( self, x: torch.Tensor, expert_bias: torch.Tensor | None = None - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: scores = F.linear(x.float(), self.gate.float()).sigmoid() scores_for_choice = scores + self.e_score_correction_bias @@ -964,6 +982,7 @@ def forward( selected_experts_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] top_scores = scores.gather(1, selected_experts_indices) + routing_confidence_sum = _selected_probability_mass_sum(scores, top_scores, "sigmoid") if self.norm_topk_prob: denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20 @@ -976,7 +995,7 @@ def forward( max=self.num_experts, ) - return top_scores, selected_experts_indices, num_tokens_per_expert + return top_scores, selected_experts_indices, num_tokens_per_expert, routing_confidence_sum def init_weights(self, init_std: float): nn.init.trunc_normal_(self.gate, mean=0.0, std=init_std) @@ -1067,6 +1086,7 @@ def __init__( torch.zeros(num_experts, dtype=torch.float32), persistent=False, ) + self.register_buffer("routing_confidence_sum", torch.tensor(0.0, dtype=torch.float32), persistent=False) def set_ep_comm_backend(self, backend: EPCommBackend) -> None: self.ep_comm_backend = backend @@ -1161,10 +1181,13 @@ def forward(self, x: torch.Tensor, 
routed_experts: torch.Tensor | None = None) - bs, slen, dim = x.shape x_flat = x.view(-1, dim) - top_scores, selected_experts_indices, num_tokens_per_expert = self.router(x_flat, self.expert_bias) + top_scores, selected_experts_indices, num_tokens_per_expert, routing_confidence_sum = self.router( + x_flat, self.expert_bias + ) with torch.no_grad(): self.tokens_per_expert.add_(num_tokens_per_expert) + self.routing_confidence_sum.add_(routing_confidence_sum) if self.ep_comm_backend == "deepep": routed_output = self._run_deepep_routed_experts(x_flat, selected_experts_indices, top_scores) @@ -1196,5 +1219,6 @@ def init_weights(self, init_std: float, buffer_device: torch.device): with torch.device(buffer_device): self.tokens_per_expert = torch.zeros(self.experts.num_experts, dtype=torch.float32) + self.routing_confidence_sum = torch.tensor(0.0, dtype=torch.float32) if self.load_balance_coeff is not None: self.expert_bias = torch.zeros(self.experts.num_experts, dtype=torch.float32) diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index 4b2b932297..ec3f9bfbeb 100644 --- a/src/prime_rl/trainer/rl/train.py +++ b/src/prime_rl/trainer/rl/train.py @@ -494,6 +494,8 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: micro_step_message += f" | Mismatch KL: {tensors['mismatch_kl'][-1].mean().item():.4f}" if "max_vio" in tensors: micro_step_message += f" | Max Vio: {tensors['max_vio'][-1].mean().item():.4f}" + if "routing_confidence" in tensors: + micro_step_message += f" | Routing Conf.: {tensors['routing_confidence'][-1].mean().item():.4f}" logger.debug(micro_step_message) # Optionally, clip the gradients @@ -547,6 +549,8 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: step_message += f" | LR: {current_lr:.2e} | Throughput: {throughput:.0f} tokens/s | MFU: {mfu:.1f}% | Peak Mem.: {peak_memory:.1f} GiB" if "max_vio/mean" in tensor_stats: step_message += f" | Max Vio: {tensor_stats['max_vio/mean']:.4f}" + if "routing_confidence/mean" 
in tensor_stats: + step_message += f" | Routing Conf.: {tensor_stats['routing_confidence/mean']:.4f}" logger.success(step_message) # Log performance metrics diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index 1c12b342ee..1da7fe60af 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -381,7 +381,15 @@ def run_validation(step: int) -> None: step_loss_sum = torch.tensor(0.0, device="cuda") step_local_token_count = torch.tensor(0, dtype=torch.int64, device="cuda") nan_loss_count = torch.tensor(0, device="cuda") - batch_max_vio = torch.tensor(0.0, device="cuda") + is_moe_model = is_tt_moe_model(model) + moe_stats = ( + { + "max_vio": torch.tensor(0.0, device="cuda"), + "routing_confidence": torch.tensor(0.0, device="cuda"), + } + if is_moe_model + else {} + ) for micro_step in range(grad_accum_steps): micro_batch = next(dataiter) @@ -406,12 +414,16 @@ def run_validation(step: int) -> None: with maybe_record_function("backward"): scaled_loss.backward() - if is_tt_moe_model(model): - max_vio = get_load_balance_stats(model)["max_vio"] - if max_vio is not None: - max_vio = max_vio.mean() - dist.all_reduce(max_vio, op=dist.ReduceOp.MAX) - batch_max_vio += max_vio / grad_accum_steps + if is_moe_model: + for name, values in get_load_balance_stats(model).items(): + if values is None: + continue + value = values.mean() + reduce_op = dist.ReduceOp.MAX if name == "max_vio" else dist.ReduceOp.SUM + dist.all_reduce(value, op=reduce_op) + if reduce_op == dist.ReduceOp.SUM: + value /= dist.get_world_size() + moe_stats[name] += value / grad_accum_steps forward_backward_time = time.perf_counter() - forward_backward_start_time @@ -489,8 +501,11 @@ def run_validation(step: int) -> None: if grad_norm is not None: step_message += f" | Grad. 
Norm: {grad_norm:.4f}" step_message += f" | LR: {current_lr:.2e} | Throughput: {throughput:.0f} tokens/s | MFU: {mfu:.1f}% | Peak Mem.: {peak_memory:.1f}/{max_memory:.1f} GiB ({peak_memory / max_memory * 100:.1f}%)" - if is_tt_moe_model(model) and batch_max_vio.item() > 0: - step_message += f" | Max Vio: {batch_max_vio.item():.4f}" + if is_moe_model: + for name, label in (("max_vio", "Max Vio"), ("routing_confidence", "Routing Conf.")): + value = moe_stats[name].item() + if value > 0: + step_message += f" | {label}: {value:.4f}" logger.success(step_message) # Log progress metrics @@ -558,8 +573,9 @@ def run_validation(step: int) -> None: disk_metrics["step"] = progress.step monitor.log(disk_metrics, step=progress.step) - if is_tt_moe_model(model) and batch_max_vio.item() > 0: - monitor.log({"max_vio/mean": batch_max_vio.item(), "step": progress.step}, step=progress.step) + moe_log_metrics = {f"{name}/mean": value.item() for name, value in moe_stats.items() if value.item() > 0} + if moe_log_metrics: + monitor.log({**moe_log_metrics, "step": progress.step}, step=progress.step) is_first_step = False progress.step += 1