Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/prime_rl/trainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,23 +394,33 @@ def get_load_balance_stats(
model: nn.Module, reset_stats: bool = True, try_to_avoid_padding_experts: bool = True
) -> dict[str, Tensor | None]:
per_layer_max_vio = []
per_layer_routing_confidence = []
language_model = get_language_model(model)
for transformer_block in language_model.layers:
# This is necessary for models that have mixed dense layers
block_mlp = getattr(transformer_block, "mlp", None)
if block_mlp is None or not hasattr(block_mlp, "tokens_per_expert"):
continue
tokens_per_expert: torch.Tensor = block_mlp.tokens_per_expert
num_routed_tokens = tokens_per_expert.sum() / block_mlp.router.top_k
if try_to_avoid_padding_experts:
tokens_per_expert = tokens_per_expert.sort(dim=0, descending=True).values[block_mlp.router.top_k :]
balanced_load = tokens_per_expert.mean()
max_vio = (tokens_per_expert.max() - balanced_load) / balanced_load
per_layer_max_vio.append(max_vio.item())
per_layer_max_vio.append(max_vio.detach())

routing_confidence = block_mlp.routing_confidence_sum / num_routed_tokens
per_layer_routing_confidence.append(routing_confidence.detach())

if reset_stats:
block_mlp.tokens_per_expert.zero_()
block_mlp.routing_confidence_sum.zero_()
if len(per_layer_max_vio) == 0:
return {"max_vio": None}
return {"max_vio": torch.tensor(per_layer_max_vio, device=torch.device("cuda"))}
return {"max_vio": None, "routing_confidence": None}
return {
"max_vio": torch.stack(per_layer_max_vio),
"routing_confidence": torch.stack(per_layer_routing_confidence),
}


def get_model(
Expand Down
36 changes: 30 additions & 6 deletions src/prime_rl/trainer/models/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,16 @@ def init_weights(self, init_std: float):
nn.init.zeros_(self.down_proj_bias)


def _selected_probability_mass_sum(
    scores: torch.Tensor, top_scores: torch.Tensor, score_func: Literal["softmax", "sigmoid"]
) -> torch.Tensor:
    """Sum, over all tokens, the share of router score mass held by the selected experts.

    Args:
        scores: Router scores for every expert. Reduced along the last
            dimension, so that dimension is expected to index experts
            (callers pass a ``(bs*slen, num_experts)`` tensor).
        top_scores: Scores of the experts chosen for each token, taken before
            any route normalization or scaling (callers pass a
            ``(bs*slen, top_k)`` tensor).
        score_func: Scoring function that produced ``scores``. ``"softmax"``
            sums ``top_scores`` directly; any other value takes the
            normalizing path used for ``"sigmoid"``.

    Returns:
        A scalar tensor holding the summed selected-expert mass. It is
        computed under ``torch.no_grad()`` and therefore carries no gradient.
    """
    with torch.no_grad():
        if score_func != "softmax":
            # Scores are not assumed to sum to one per token here, so each
            # token's selected scores are divided by that token's total score.
            # The epsilon guards against division by an all-zero row.
            per_token_total = scores.sum(dim=-1, keepdim=True) + 1e-20
            selected_fraction = top_scores / per_token_total
            return selected_fraction.sum(dim=-1).sum()
        # NOTE(review): presumably softmax rows already sum to one, so the
        # selected scores are used as the probability mass as-is.
        return top_scores.sum()


class TokenChoiceTopKRouter(nn.Module):
"""This class implements token-choice routing. In token-choice top-K routing, each token is
routed to top K experts based on the router scores.
Expand Down Expand Up @@ -420,7 +430,7 @@ def __init__(

def forward(
self, x: torch.Tensor, expert_bias: torch.Tensor | None = None, routed_experts: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``.
Expand All @@ -429,13 +439,15 @@ def forward(
routed_experts (torch.Tensor | None, optional): Optional tensor with shape ``(bs * slen, top_k)``.

Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
- top_scores (torch.Tensor):
Routing scores for selected experts with shape ``(bs*slen, top_k)``.
- selected_experts_indices (torch.Tensor):
Expert indices selected for each token with shape ``(bs*slen, top_k)``.
- num_tokens_per_expert (torch.Tensor):
Number of tokens assigned to each expert with shape ``(num_experts,)``.
- routing_confidence_sum (torch.Tensor):
Sum over tokens of the selected-expert probability mass before route normalization/scaling.
"""
# scores shape (bs*slen, num_experts)
assert routed_experts is None or routed_experts.shape[-1] == self.top_k, (
Expand Down Expand Up @@ -469,6 +481,8 @@ def forward(
else:
top_scores, selected_experts_indices = torch.topk(scores, k=self.top_k, dim=1)

routing_confidence_sum = _selected_probability_mass_sum(scores, top_scores, self.score_func)

if self.route_norm:
denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
top_scores = top_scores / denominator
Expand All @@ -482,7 +496,7 @@ def forward(
max=self.num_experts,
)

return top_scores, selected_experts_indices, num_tokens_per_expert
return top_scores, selected_experts_indices, num_tokens_per_expert, routing_confidence_sum

def init_weights(self, init_std: float):
nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)
Expand Down Expand Up @@ -600,6 +614,7 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
torch.zeros(num_experts, dtype=torch.float32),
persistent=False,
)
self.register_buffer("routing_confidence_sum", torch.tensor(0.0, dtype=torch.float32), persistent=False)

def set_ep_comm_backend(self, backend: EPCommBackend) -> None:
self.ep_comm_backend = backend
Expand Down Expand Up @@ -718,6 +733,7 @@ def forward(
top_scores,
selected_experts_indices,
num_tokens_per_expert,
routing_confidence_sum,
) = self.router(x, self.expert_bias, routed_experts=routed_experts)

# tokens_per_expert will be used to update the expert bias for load balancing.
Expand All @@ -727,6 +743,7 @@ def forward(
# routed expert compute below.
with torch.no_grad():
self.tokens_per_expert.add_(num_tokens_per_expert)
self.routing_confidence_sum.add_(routing_confidence_sum)

if self.ep_comm_backend == "deepep":
routed_output = self._run_deepep_routed_experts(x, selected_experts_indices, top_scores)
Expand Down Expand Up @@ -774,6 +791,7 @@ def init_weights(

with torch.device(buffer_device):
self.tokens_per_expert = torch.zeros(self.experts.num_experts, dtype=torch.float32)
self.routing_confidence_sum = torch.tensor(0.0, dtype=torch.float32)
if self.load_balance_coeff is not None:
self.expert_bias = torch.zeros(self.experts.num_experts, dtype=torch.float32)

Expand Down Expand Up @@ -938,7 +956,7 @@ def __init__(

def forward(
self, x: torch.Tensor, expert_bias: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
scores = F.linear(x.float(), self.gate.float()).sigmoid()
scores_for_choice = scores + self.e_score_correction_bias

Expand All @@ -964,6 +982,7 @@ def forward(

selected_experts_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
top_scores = scores.gather(1, selected_experts_indices)
routing_confidence_sum = _selected_probability_mass_sum(scores, top_scores, "sigmoid")

if self.norm_topk_prob:
denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
Expand All @@ -976,7 +995,7 @@ def forward(
max=self.num_experts,
)

return top_scores, selected_experts_indices, num_tokens_per_expert
return top_scores, selected_experts_indices, num_tokens_per_expert, routing_confidence_sum

def init_weights(self, init_std: float):
nn.init.trunc_normal_(self.gate, mean=0.0, std=init_std)
Expand Down Expand Up @@ -1067,6 +1086,7 @@ def __init__(
torch.zeros(num_experts, dtype=torch.float32),
persistent=False,
)
self.register_buffer("routing_confidence_sum", torch.tensor(0.0, dtype=torch.float32), persistent=False)

def set_ep_comm_backend(self, backend: EPCommBackend) -> None:
self.ep_comm_backend = backend
Expand Down Expand Up @@ -1161,10 +1181,13 @@ def forward(self, x: torch.Tensor, routed_experts: torch.Tensor | None = None) -
bs, slen, dim = x.shape
x_flat = x.view(-1, dim)

top_scores, selected_experts_indices, num_tokens_per_expert = self.router(x_flat, self.expert_bias)
top_scores, selected_experts_indices, num_tokens_per_expert, routing_confidence_sum = self.router(
x_flat, self.expert_bias
)

with torch.no_grad():
self.tokens_per_expert.add_(num_tokens_per_expert)
self.routing_confidence_sum.add_(routing_confidence_sum)

if self.ep_comm_backend == "deepep":
routed_output = self._run_deepep_routed_experts(x_flat, selected_experts_indices, top_scores)
Expand Down Expand Up @@ -1196,5 +1219,6 @@ def init_weights(self, init_std: float, buffer_device: torch.device):

with torch.device(buffer_device):
self.tokens_per_expert = torch.zeros(self.experts.num_experts, dtype=torch.float32)
self.routing_confidence_sum = torch.tensor(0.0, dtype=torch.float32)
if self.load_balance_coeff is not None:
self.expert_bias = torch.zeros(self.experts.num_experts, dtype=torch.float32)
4 changes: 4 additions & 0 deletions src/prime_rl/trainer/rl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,8 @@ def load_run_checkpoint(_optimizer, idx: int) -> None:
micro_step_message += f" | Mismatch KL: {tensors['mismatch_kl'][-1].mean().item():.4f}"
if "max_vio" in tensors:
micro_step_message += f" | Max Vio: {tensors['max_vio'][-1].mean().item():.4f}"
if "routing_confidence" in tensors:
micro_step_message += f" | Routing Conf.: {tensors['routing_confidence'][-1].mean().item():.4f}"
logger.debug(micro_step_message)

# Optionally, clip the gradients
Expand Down Expand Up @@ -547,6 +549,8 @@ def load_run_checkpoint(_optimizer, idx: int) -> None:
step_message += f" | LR: {current_lr:.2e} | Throughput: {throughput:.0f} tokens/s | MFU: {mfu:.1f}% | Peak Mem.: {peak_memory:.1f} GiB"
if "max_vio/mean" in tensor_stats:
step_message += f" | Max Vio: {tensor_stats['max_vio/mean']:.4f}"
if "routing_confidence/mean" in tensor_stats:
step_message += f" | Routing Conf.: {tensor_stats['routing_confidence/mean']:.4f}"
logger.success(step_message)

# Log performance metrics
Expand Down
38 changes: 27 additions & 11 deletions src/prime_rl/trainer/sft/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,15 @@ def run_validation(step: int) -> None:
step_loss_sum = torch.tensor(0.0, device="cuda")
step_local_token_count = torch.tensor(0, dtype=torch.int64, device="cuda")
nan_loss_count = torch.tensor(0, device="cuda")
batch_max_vio = torch.tensor(0.0, device="cuda")
is_moe_model = is_tt_moe_model(model)
moe_stats = (
{
"max_vio": torch.tensor(0.0, device="cuda"),
"routing_confidence": torch.tensor(0.0, device="cuda"),
}
if is_moe_model
else {}
)
for micro_step in range(grad_accum_steps):
micro_batch = next(dataiter)

Expand All @@ -406,12 +414,16 @@ def run_validation(step: int) -> None:
with maybe_record_function("backward"):
scaled_loss.backward()

if is_tt_moe_model(model):
max_vio = get_load_balance_stats(model)["max_vio"]
if max_vio is not None:
max_vio = max_vio.mean()
dist.all_reduce(max_vio, op=dist.ReduceOp.MAX)
batch_max_vio += max_vio / grad_accum_steps
if is_moe_model:
for name, values in get_load_balance_stats(model).items():
if values is None:
continue
value = values.mean()
reduce_op = dist.ReduceOp.MAX if name == "max_vio" else dist.ReduceOp.SUM
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

eheheeh not a fan of this

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm what part? The if statement?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I don't like it, but I'm also not sure how to do better.

dist.all_reduce(value, op=reduce_op)
if reduce_op == dist.ReduceOp.SUM:
value /= dist.get_world_size()
moe_stats[name] += value / grad_accum_steps

forward_backward_time = time.perf_counter() - forward_backward_start_time

Expand Down Expand Up @@ -489,8 +501,11 @@ def run_validation(step: int) -> None:
if grad_norm is not None:
step_message += f" | Grad. Norm: {grad_norm:.4f}"
step_message += f" | LR: {current_lr:.2e} | Throughput: {throughput:.0f} tokens/s | MFU: {mfu:.1f}% | Peak Mem.: {peak_memory:.1f}/{max_memory:.1f} GiB ({peak_memory / max_memory * 100:.1f}%)"
if is_tt_moe_model(model) and batch_max_vio.item() > 0:
step_message += f" | Max Vio: {batch_max_vio.item():.4f}"
if is_moe_model:
for name, label in (("max_vio", "Max Vio"), ("routing_confidence", "Routing Conf.")):
value = moe_stats[name].item()
if value > 0:
step_message += f" | {label}: {value:.4f}"
logger.success(step_message)

# Log progress metrics
Expand Down Expand Up @@ -558,8 +573,9 @@ def run_validation(step: int) -> None:
disk_metrics["step"] = progress.step
monitor.log(disk_metrics, step=progress.step)

if is_tt_moe_model(model) and batch_max_vio.item() > 0:
monitor.log({"max_vio/mean": batch_max_vio.item(), "step": progress.step}, step=progress.step)
moe_log_metrics = {f"{name}/mean": value.item() for name, value in moe_stats.items() if value.item() > 0}
if moe_log_metrics:
monitor.log({**moe_log_metrics, "step": progress.step}, step=progress.step)

is_first_step = False
progress.step += 1
Expand Down
Loading