From 0966bb87c522b7b86cb986de8ee95d36c69d048b Mon Sep 17 00:00:00 2001 From: AleHD Date: Mon, 4 Mar 2024 19:13:51 +0100 Subject: [PATCH 01/13] Implemented wandb entity configuration --- src/nanotron/config/config.py | 2 ++ src/nanotron/trainer.py | 1 + 2 files changed, 3 insertions(+) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index bf816ed1..e7d50e2b 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -149,6 +149,7 @@ class GeneralArgs: Args: project: Name of the project (a project gather several runs in common tensorboard/hub-folders) + entity: Weights and bias entity name (optional) run: Name of the run step: Global step (updated when we save the checkpoint) consumed_train_samples: Number of samples consumed during training (should be actually just step*batch_size) @@ -156,6 +157,7 @@ class GeneralArgs: """ project: str + entity: Optional[str] = None run: Optional[str] = None seed: Optional[int] = None step: Optional[int] = None diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 2f653735..68b545ed 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -254,6 +254,7 @@ def pre_training(self, *args, **kwargs): wandb.init( project=self.config.general.project, name=f"{current_time}_{self.config.general.project}_{self.config.general.run}", + entity=self.config.general.entity, config={"nanotron_config": self.config.as_dict()}, ) From 88934e8276bc92e2839c5998d700d110e150afb2 Mon Sep 17 00:00:00 2001 From: Alex Hagele Date: Thu, 11 Apr 2024 10:11:27 +0000 Subject: [PATCH 02/13] move moe in src of nanotron, update config --- src/nanotron/config/models_config.py | 3 + src/nanotron/models/llama.py | 112 ++++-- src/nanotron/models/moe.py | 578 +++++++++++++++++++++++++++ 3 files changed, 669 insertions(+), 24 deletions(-) create mode 100644 src/nanotron/models/moe.py diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 610acc06..048746d3 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -40,6 +40,9 @@ class LlamaConfig: tie_word_embeddings: bool = False use_cache: bool = True vocab_size: int = 32000 + # MoE specific + moe_num_experts: int = 1 + num_experts_per_tok: int = 1 def __post_init__(self): # for backward compatibility diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index b930e0eb..421bdf64 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -14,8 +14,9 @@ # limitations under the License. """ PyTorch LLaMa model. 
""" -from typing import Dict, Optional, Union import math +from typing import Dict, Optional, Union + import torch from flash_attn import bert_padding from flash_attn.flash_attn_interface import ( @@ -31,6 +32,7 @@ from nanotron.generation.generate_store import AttachableStore from nanotron.logging import log_rank from nanotron.models import NanotronModel +from nanotron.models.moe import dMoE from nanotron.nn.activations import ACT2FN from nanotron.nn.layer_norm import TritonRMSNorm from nanotron.parallel import ParallelContext @@ -181,7 +183,12 @@ def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] class CoreAttention(nn.Module): - def __init__(self, config: LlamaConfig, parallel_config: Optional[ParallelismArgs], layer_idx: int): + def __init__( + self, + config: LlamaConfig, + parallel_config: Optional[ParallelismArgs], + layer_idx: int, + ): super().__init__() # TODO @thomasw21: GPT has a weird `d_kv` config which I'm guessing is essentically a `d_qkv` assert ( @@ -202,10 +209,28 @@ def forward( kv_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size) ): # TODO @thomasw21: Compute once, instead of computing for each layers. - cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) - cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) - torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:]) - torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:]) + cu_seqlens_q = torch.zeros( + (q_sequence_mask.shape[0] + 1), + dtype=torch.int32, + device=query_states.device, + ) + cu_seqlens_k = torch.zeros( + (kv_sequence_mask.shape[0] + 1), + dtype=torch.int32, + device=query_states.device, + ) + torch.cumsum( + q_sequence_mask.sum(-1, dtype=torch.int32), + dim=0, + dtype=torch.int32, + out=cu_seqlens_q[1:], + ) + torch.cumsum( + kv_sequence_mask.sum(-1, dtype=torch.int32), + dim=0, + dtype=torch.int32, + out=cu_seqlens_k[1:], + ) # TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not # what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache. 
@@ -524,7 +549,7 @@ def forward( value_states, rotary_cos=None, rotary_sin=None, - # TODO @nouamane: seems like this doesnt help to indicate padding in (for first iteration it's just 0) + # TODO @nouamane: seems like this doesn't help to indicate padding in (for first iteration it's just 0) cache_seqlens=position_offsets.contiguous(), softmax_scale=None, causal=True, @@ -589,6 +614,7 @@ def __init__( self, config: LlamaConfig, parallel_config: Optional[ParallelismArgs], + parallel_context: ParallelContext, tp_pg: dist.ProcessGroup, layer_idx: int, ): @@ -602,7 +628,14 @@ def __init__( ) self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + if config.moe_num_experts > 1: + self.mlp = dMoE( + config=config, + parallel_config=parallel_config, + parallel_context=parallel_context, + ) + else: + self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) def forward( self, @@ -628,14 +661,19 @@ def forward( class Embedding(nn.Module, AttachableStore): - def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]): + def __init__( + self, + tp_pg: dist.ProcessGroup, + config: LlamaConfig, + parallel_config: Optional[ParallelismArgs], + ): super().__init__() self.token_embedding = TensorParallelEmbedding( num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, padding_idx=config.pad_token_id, pg=tp_pg, - mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE, + mode=(parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE), ) self.pg = tp_pg @@ -699,6 +737,7 @@ def __init__( "config": config, "parallel_config": parallel_config, "tp_pg": parallel_context.tp_pg, + "parallel_context": parallel_context, "layer_idx": layer_idx, }, module_input_keys={"hidden_states", "sequence_mask"}, @@ -711,7 +750,10 @@ def __init__( self.final_layer_norm = PipelineBlock( p2p=self.p2p, module_builder=TritonRMSNorm, - module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps}, + module_kwargs={ + "hidden_size": config.hidden_size, + "eps": config.rms_norm_eps, + }, module_input_keys={"input"}, module_output_keys={"hidden_states"}, ) # TODO @@ -830,7 +872,10 @@ def forward( # Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision. # https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38 loss = sharded_cross_entropy( - sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float + sharded_logits, + label_ids.transpose(0, 1).contiguous(), + group=self.tp_pg, + dtype=torch.float, ).transpose(0, 1) # TODO @thomasw21: It's unclear what kind of normalization we want to do. 
loss = masked_mean(loss, label_mask, dtype=torch.float) @@ -848,7 +893,11 @@ def __init__( random_states: Optional[RandomStates] = None, ): super().__init__() - self.model = LlamaModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config) + self.model = LlamaModel( + config=config, + parallel_context=parallel_context, + parallel_config=parallel_config, + ) self.loss = PipelineBlock( p2p=self.model.p2p, module_builder=Loss, @@ -881,7 +930,7 @@ def forward( label_mask=label_mask, )["loss"] return {"loss": loss} - + @torch.no_grad() def init_model_randomly(self, config): """Initialize model parameters randomly. @@ -898,12 +947,12 @@ def init_model_randomly(self, config): std = config.model.init_method.std sigma = config.model.init_method.std num_layers = config.model.model_config.num_hidden_layers - + for param_name, param in model.named_parameters(): assert isinstance(param, NanotronParameter) - - module_name, param_name = param_name.rsplit('.', 1) - + + module_name, param_name = param_name.rsplit(".", 1) + if param.is_tied: tied_info = param.get_tied_info() full_param_name = tied_info.get_full_name_from_module_id_to_prefix( @@ -940,25 +989,40 @@ def init_model_randomly(self, config): module.bias.zero_() else: raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, nn.Linear): + fan_in = None + if "weight" == param_name: + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight) + torch.nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + elif "bias" == param_name: + bound = 1 / math.sqrt(fan_in) if (fan_in is not None and fan_in > 0) else 0 + torch.nn.init.uniform_(module.bias, -bound, bound) + else: + raise ValueError(f"Who the fuck is {param_name}?") elif isinstance(module, TensorParallelEmbedding): nn.init.normal_(module.weight, mean=0.0, std=std) else: - raise Exception(f"Parameter {full_param_name} was not intialized") + raise Exception(f"Parameter {full_param_name} was not initialized") assert full_param_name not in initialized_parameters initialized_parameters.add(full_param_name) - + assert initialized_parameters == { - param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) - if param.is_tied - else name + ( + param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) + if param.is_tied + else name + ) for name, param in model.named_parameters() }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" def get_embeddings_lm_head_tied_names(self): """Get the names of the tied embeddings and lm_head weights""" if self.config.tie_word_embeddings is True: - return ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"] + return [ + "model.token_position_embeddings.pp_block.token_embedding.weight", + "model.lm_head.pp_block.weight", + ] else: return [] diff --git a/src/nanotron/models/moe.py b/src/nanotron/models/moe.py new file mode 100644 index 00000000..b4744f66 --- /dev/null +++ b/src/nanotron/models/moe.py @@ -0,0 +1,578 @@ +""" LlaMa model with MoEs""" + +import warnings +from functools import partial +from typing import Optional, Tuple + +import numpy as np +import stk +import torch +import torch.nn.functional as F +from megablocks.layers import weight_parallel as wp +from megablocks.layers.activation_fn import act_fn +from torch import nn + +from nanotron import distributed as 
dist +from nanotron import logging +from nanotron.config import LlamaConfig as Config +from nanotron.config import ParallelismArgs +from nanotron.nn.activations import ACT2FN +from nanotron.parallel.context import ParallelContext +from nanotron.parallel.sharded_parameters import ( + SplitConfig, + mark_all_parameters_in_module_as_sharded, +) +from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode +from nanotron.parallel.tensor_parallel.nn import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, +) + +try: + import megablocks.ops as ops + from megablocks.layers.all_to_all import all_to_all +except ImportError: + warnings.warn("Please install megablocks to use MoEs: `pip install megablocks`") + + +logger = logging.get_logger(__name__) + + +class dMoE(torch.nn.Module): + def __init__( + self, + config: Config, + parallel_context: "ParallelContext", + parallel_config: Optional[ParallelismArgs], + ): + super().__init__() + self.config = config + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + if self.tp_mode == TensorParallelLinearMode.REDUCE_SCATTER: + logging.warn_once( + logger=logger, + msg="TensorParallelLinearMode.REDUCE_SCATTER is still experimental for MoEs. Use at your own risk.", + rank=0, + ) + + # Token router. + self.gate = LearnedRouter(config) + + # Expert computation helper. + self.experts = ParallelDroplessMLP( + config, + use_bias=False, + parallel_context=parallel_context, + parallel_config=parallel_config, + ) + + def forward(self, hidden_states: torch.Tensor): + """ + Args: + x: input tensor of shape [sequence_length, batch_size, hidden_size] + """ + # Compute the expert scores and assignments. + # TODO: support sequence parallelism + batch_size, sequence_length, _ = hidden_states.size() + x = hidden_states.view(-1, self.config.hidden_size) + scores, expert_weights, top_experts = self.gate(x) + + # Compute the experts. + x = self.experts(x, scores, expert_weights, top_experts) + return {"hidden_states": x.reshape(batch_size, sequence_length, -1)} + + +# Adapted from megablocks.layers.router.LearnedRouter +class LearnedRouter(torch.nn.Module): + def __init__(self, config: Config): + super().__init__() + self.layer = torch.nn.Linear(config.hidden_size, config.moe_num_experts, bias=False) + self.config = config + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + router_logits = self.layer(x) # (batch * sequence_length, n_experts) + scores = F.softmax(router_logits, dim=-1, dtype=torch.float32) # TODO: fuse? 
+ + if self.config.num_experts_per_tok == 1: + expert_weights, expert_indices = scores.max(dim=-1, keepdim=True) + else: + expert_weights, expert_indices = torch.topk(scores, self.config.num_experts_per_tok, dim=-1) + + return scores, expert_weights, expert_indices.int() + + +# Adapted from megablocks.layers.mlp.ParallelDroplessMLP +class ParallelDroplessMLP(torch.nn.Module): + def __init__( + self, + config: Config, + use_bias: bool, + parallel_context: "ParallelContext", + parallel_config: Optional[ParallelismArgs], + ): + super().__init__() + self.config = config + self.use_bias = use_bias + + self.expert_pg_size = parallel_context.expert_pg.size() + self.expert_parallel_group = parallel_context.expert_pg + + self.hidden_sharding_degree = self.expert_pg_size // min(self.expert_pg_size, self.config.moe_num_experts) + self.experts_per_rank = self.config.moe_num_experts // min(self.expert_pg_size, self.config.moe_num_experts) + + self.num_experts = config.moe_num_experts + self.num_experts_per_tok = self.config.num_experts_per_tok + + # Calculate the number of bits needed to represent the expert indices + # so that we can pass it to radix sort. + self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) + + if use_bias: + self.bias = torch.nn.Parameter(torch.empty(config.hidden_size)) + + # Select the forward function for the operating mode. + self.forward_fn = self.parallel_forward_once if self.expert_pg_size > 1 else self.forward_once + + self.blocking = 128 + + if self.experts_per_rank == 1: + self.mlp = MLP( + config=config, + parallel_config=parallel_config, + tp_pg=parallel_context.tp_pg, + ) + else: + self.mlp = SparseMLP( + config=config, + parallel_config=parallel_config, + parallel_context=parallel_context, + ) + + max_column_index = (self.config.intermediate_size * self.num_experts) // self.blocking + self.transpose_sort_end_bit = max(int(np.ceil(np.log2(max_column_index))), 1) + + def indices_and_bins(self, top_expert): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + top_expert = top_expert.int() + bin_ids, indices = ops.sort(top_expert, self.sort_end_bit) + tokens_per_expert = ops.histogram(top_expert, self.num_experts) + + # Calculate the bin bounds for the sorted tokens. + bins = inclusive_cumsum(tokens_per_expert, 0) + return indices, bin_ids, bins, tokens_per_expert + + def indices_and_padded_bins(self, top_experts): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + tokens_per_expert = ops.histogram(top_experts, self.num_experts) + + # Round the token counts up to the block size used in + # the matrix muliplications. Calculate the starting + # position of each bin. + padded_tokens_per_expert = ops.round_up(tokens_per_expert, self.blocking) + padded_bins = inclusive_cumsum(padded_tokens_per_expert, 0) + + # Calculate the bin bounds for the sorted tokens. + bins = inclusive_cumsum(tokens_per_expert, 0) + return indices, bin_ids, bins, padded_bins, tokens_per_expert + + def forward_once(self, x, expert_weights, top_experts): # TODO: sparse + with torch.no_grad(): + ( + indices, + bin_ids, + bins, + padded_bins, + tokens_per_expert, + ) = self.indices_and_padded_bins(top_experts) + + # Route the tokens for MoE computation. 
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.num_experts_per_tok) + + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + x = ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + self.num_experts_per_tok, + -1, + ) + return x, tokens_per_expert + + def parallel_forward_once(self, x, expert_weights, top_experts): + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = self.indices_and_bins(top_experts) + repeated_tokens_per_expert = ops.repeat(tokens_per_expert, (self.hidden_sharding_degree,)) + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert) + tpe_handle = torch.distributed.all_to_all_single( + parallel_tokens_per_expert, + repeated_tokens_per_expert, + group=self.expert_parallel_group, + async_op=True, + ) + + x = ops.gather(x, indices, bin_ids, bins, self.num_experts_per_tok) + + # Compute the number of tokens that will be received from each + # device and permute the input data across the devices. + with torch.no_grad(): + tpe_handle.wait() + + # Reshape to [expert_pg_size, num_experts_per_rank]. + repeated_tokens_per_expert = repeated_tokens_per_expert.view(self.expert_pg_size, self.experts_per_rank) + parallel_tokens_per_expert = parallel_tokens_per_expert.view(self.expert_pg_size, self.experts_per_rank) + + send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) + + # Convert the send/recv counts to lists. + send_counts = send_counts.tolist() + recv_counts = recv_counts.tolist() + tokens_received = sum(recv_counts) + + x = ops.repeat(x, (self.hidden_sharding_degree, 1)) + + # Start the cross-device permutation asynchronously so we can + # overlap communication with computation. + parallel_x, parallel_x_handle = all_to_all( + x, recv_counts, send_counts, self.expert_parallel_group, async_op=True + ) + + with torch.no_grad(): + replicate_bins = inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0) + + # Construct the expert indices for the permuted tokens. + parallel_top_expert = torch.remainder( + torch.arange( + self.num_experts * self.hidden_sharding_degree, + dtype=torch.int32, + device=indices.device, + ), + self.experts_per_rank, + ) + parallel_top_expert = ops.replicate( + parallel_top_expert.unsqueeze(dim=0), replicate_bins, tokens_received + ).flatten() + + parallel_bin_ids, parallel_indices = ops.sort(parallel_top_expert, self.sort_end_bit) + + # Calculate the bins boundaries from the token counts. + parallel_tokens_per_expert = parallel_tokens_per_expert.sum(dim=0, dtype=torch.int) + parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + parallel_x_handle.wait() + parallel_x = self.permute_and_compute( + parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + None, # expert_weights + parallel_bins, + num_experts_per_tok=1, + ) + + # Un-permute the tokens across the devices. + x, _ = all_to_all(parallel_x, send_counts, recv_counts, self.expert_parallel_group) + + # Reduce along the hidden sharding to get the final outputs. 
+ shape = (self.hidden_sharding_degree, -1, self.config.hidden_size) + x = ops.sum(x.view(shape), dim=0) + + # Un-permute locally to setup for the next series of operations. + x = ops.scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + self.num_experts_per_tok, + ) + return x, tokens_per_expert.flatten() + + def forward(self, x, scores, expert_weights, top_experts): + """ + Args: + x: input tensor of shape [sequence_length, batch_size, hidden_size] + scores: tensor of shape [sequence_length * batch_size, n_experts] + expert_weights: tensor of shape [sequence_length * batch_size, num_experts_per_tok] + top_experts: tensor of shape [sequence_length * batch_size, num_experts_per_tok] + """ + # Compute the experts. + x, tokens_per_expert = self.forward_fn(x, expert_weights.flatten(), top_experts.flatten()) + + if self.use_bias: + return x + self.bias + return x + + def permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + num_experts_per_tok, + ): + # Round the token counts up to the block size used in the matrix + # multiplication. Calculate the starting position of each bin. + padded_tokens_per_expert = ops.round_up(tokens_per_expert, self.blocking) + padded_bins = inclusive_cumsum(padded_tokens_per_expert, 0) + + # Route the tokens for MoE computation. + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, num_experts_per_tok) + + # Perform the expert computation. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + return ops.padded_scatter(x, indices, bin_ids, expert_weights, bins, padded_bins, num_experts_per_tok) + + def sparse_transpose(self, size, row_indices, column_indices, offsets): + block_columns = size[1] // self.blocking + _, gather_indices = ops.sort(column_indices.int(), self.transpose_sort_end_bit) + column_indices_t = row_indices.gather(0, gather_indices.long()) + block_offsets_t = gather_indices.int() + + zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) + nnz_per_column = ops.histogram(column_indices, block_columns) + nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) + offsets_t = torch.cat([zero, nnz_per_column]) + return column_indices_t, offsets_t, block_offsets_t + + def topology(self, x, padded_bins): + padded_tokens, _ = x.size() + assert padded_tokens % self.blocking == 0 + assert self.config.intermediate_size % self.blocking == 0 + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // self.blocking + blocks_per_row = self.config.intermediate_size // self.blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. + column_indices = ops.topology(padded_bins, self.blocking, block_rows, blocks_per_row) + + # TODO(tgale): This is unused. Remove the need for this in stk. + # For now, use meta init to save the device memory. 
+ data = torch.empty( + column_indices.numel(), + self.blocking, + self.blocking, + dtype=x.dtype, + device="meta", + ) + shape = (padded_tokens, self.config.intermediate_size * self.experts_per_rank) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( + shape, row_indices, column_indices, offsets + ) + return stk.Matrix( + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + ) + + +class ScaleGradient(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd + def forward(ctx, x, scale): + ctx.scale = scale + return x + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, grad): + return grad * ctx.scale, None + + +scale_gradient = ScaleGradient.apply + + +class ExpertParallel(nn.Module): + """ + ExpertParallel serves to scale the gradients of the expert weights because unlike DP the gradients are not averaged across the expert parallel group. + """ + + def __init__(self, module, expert_parallel_size: int): + super().__init__() + self.module = module + self.expert_parallel_size = expert_parallel_size + + def forward(self, *args, **kwargs): + self.scale_gradients() + return self.module(*args, **kwargs) + + def scale_gradients(self): + scale_gradient(self.module, 1 / self.expert_parallel_size) + + +class SparseMLP(nn.Module): + def __init__( + self, + config: Config, + parallel_config: Optional[ParallelismArgs], + parallel_context: "ParallelContext", + ): + super().__init__() + + self.expert_pg_size = parallel_config.expert_parallel_size if parallel_config is not None else 1 + self.experts_per_rank = config.moe_num_experts // min(self.expert_pg_size, config.moe_num_experts) + self.tp_pg = parallel_context.tp_pg + + self.w1 = ExpertParallel( + nn.Linear( + config.hidden_size, + config.intermediate_size * self.experts_per_rank // self.tp_pg.size(), + bias=False, + ), + expert_parallel_size=self.expert_pg_size, + ) + self.w2 = ExpertParallel( + nn.Linear( + config.hidden_size, + config.intermediate_size * self.experts_per_rank // self.tp_pg.size(), + bias=False, + ), + expert_parallel_size=self.expert_pg_size, + ) + + if self.tp_pg.size() == 1: + self.w1.module.weight.data = self.w1.module.weight.data.T.contiguous() + + # TODO @nouamane: jit + self.act = ACT2FN[config.hidden_act] + self.sdd = partial(wp.sdd_nt, group=self.tp_pg) if self.tp_pg.size() > 1 else stk.ops.sdd + self.dsd = partial(wp.dsd_nn, group=self.tp_pg) if self.tp_pg.size() > 1 else stk.ops.dsd + + def forward(self, x, topo): + self.w1.scale_gradients(), self.w2.scale_gradients() + x = self.sdd(x.contiguous(), self.w1.module.weight, topo) + activation_fn_out = act_fn(x, self.act) + return self.dsd(activation_fn_out, self.w2.module.weight) + + +class MLP(nn.Module): + def __init__( + self, + config: Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + ): + super().__init__() + + tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + self.expert_pg_size = parallel_config.expert_parallel_size + self.experts_per_rank = config.moe_num_experts // min(self.expert_pg_size, config.moe_num_experts) + + assert self.experts_per_rank == 1, "moe.MLP only supports 1 expert per rank, otherwise use moe.SparseMLP" + + self.w1 = ExpertParallel( + 
TensorParallelColumnLinear( + config.hidden_size, + config.intermediate_size * self.experts_per_rank, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + ), + expert_parallel_size=self.expert_pg_size, + ) + + self.w2 = ExpertParallel( + TensorParallelRowLinear( + config.intermediate_size * self.experts_per_rank, + config.hidden_size, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication + and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, + ), + expert_parallel_size=self.expert_pg_size, + ) + # TODO @nouamane: jit + self.act = partial(F.gelu, approximate="tanh") + + def forward(self, hidden_states, topo): # [seq_length, batch_size, hidden_dim] + merged_states = self.w1(hidden_states) + hidden_states = self.w2(self.act(merged_states)) + return hidden_states + + +def inclusive_cumsum(x, dim): + scalar = ops.inclusive_cumsum(x, dim) + return scalar.view(1) if not len(scalar.size()) else scalar + + +class SparseGLU(SparseMLP): + def __init__( + self, + config: Config, + parallel_config: Optional[ParallelismArgs], + parallel_context: "ParallelContext", + ): + super().__init__(config, parallel_config, parallel_context) + self.w3 = ExpertParallel( + nn.Linear( + config.hidden_size, + config.intermediate_size * self.experts_per_rank // self.tp_pg.size(), + bias=False, + ), + expert_parallel_size=self.expert_pg_size, + ) + if self.tp_pg.size() == 1: + self.w3.module.weight.data = self.w3.module.weight.data.T.contiguous() + + mark_all_parameters_in_module_as_sharded( + self, + pg=parallel_context.tp_and_expert_pg, + split_config=SplitConfig(split_dim=0), + ) + + def forward(self, x, topo): + # We need to scale gradients manually since we don't call the linears forward + self.w1.scale_gradients(), self.w2.scale_gradients(), self.w3.scale_gradients() + x = x.contiguous() + x1 = self.sdd(x, self.w1.module.weight, topo) + x2 = self.sdd(x, self.w3.module.weight, topo) + x = stk.ops.mul(act_fn(x1, self.act), x2) + return self.dsd(x, self.w2.module.weight) From d37a66c38635a51f77eff4d2bbea42f73cbe49ed Mon Sep 17 00:00:00 2001 From: Alex Hagele Date: Thu, 11 Apr 2024 15:58:17 +0000 Subject: [PATCH 03/13] 1: switch to swiglu (llama), 2: fix router weights, 3: add load balancing loss and adapt logging --- src/nanotron/config/models_config.py | 1 + src/nanotron/models/llama.py | 18 +- src/nanotron/models/moe.py | 91 +++++++- .../parallel/pipeline_parallel/engine.py | 18 +- src/nanotron/trainer.py | 209 ++++++++++++++---- 5 files changed, 282 insertions(+), 55 deletions(-) diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 048746d3..2222c1c2 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -43,6 +43,7 @@ class LlamaConfig: # MoE specific moe_num_experts: int = 1 num_experts_per_tok: int = 1 + moe_loss_weight: float = 0.01 def __post_init__(self): # for backward compatibility diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 421bdf64..5632583e 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -32,7 +32,11 @@ from nanotron.generation.generate_store import AttachableStore from nanotron.logging import log_rank from nanotron.models import NanotronModel -from nanotron.models.moe import dMoE +from nanotron.models.moe import ( + batched_load_balancing_loss, + clear_load_balancing_stats, + dMoE, +) from nanotron.nn.activations import ACT2FN from nanotron.nn.layer_norm import 
TritonRMSNorm from nanotron.parallel import ParallelContext @@ -651,8 +655,8 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"] - hidden_states = hidden_states + residual + mlp_output = self.mlp(hidden_states=hidden_states) + hidden_states = mlp_output["hidden_states"] + residual return { "hidden_states": hidden_states, @@ -928,8 +932,12 @@ def forward( sharded_logits=sharded_logits, label_ids=label_ids, label_mask=label_mask, - )["loss"] - return {"loss": loss} + ) + if self.config.moe_num_experts > 1: + aux_loss = batched_load_balancing_loss(self.config, self.parallel_context.pp_pg.size()) + loss["load_balancing_loss"] = aux_loss + clear_load_balancing_stats() + return loss @torch.no_grad() def init_model_randomly(self, config): diff --git a/src/nanotron/models/moe.py b/src/nanotron/models/moe.py index b4744f66..f9a253e2 100644 --- a/src/nanotron/models/moe.py +++ b/src/nanotron/models/moe.py @@ -16,6 +16,7 @@ from nanotron import logging from nanotron.config import LlamaConfig as Config from nanotron.config import ParallelismArgs +from nanotron.config.models_config import LlamaConfig from nanotron.nn.activations import ACT2FN from nanotron.parallel.context import ParallelContext from nanotron.parallel.sharded_parameters import ( @@ -38,6 +39,78 @@ logger = logging.get_logger(__name__) +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_stats(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_stats(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_stats(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def batched_load_balancing_loss( + # from config + config: LlamaConfig, + pipeline_parallel_size: int, +): + tokens_per_expert, expert_scores = zip(*get_load_balancing_stats()) + # tokens_per_expert[i].shape = (num_experts) + # expert_scores[i].shape = (tokens, num_experts) + num_hidden_layers = config.num_hidden_layers + moe_num_experts = config.moe_num_experts + moe_loss_weight = config.moe_loss_weight + num_experts_per_token = config.num_experts_per_tok + + num_layers_per_pipeline_stage = num_hidden_layers // pipeline_parallel_size + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} token_per_experts " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{num_hidden_layers}\npipeline_model_parallel_size = " + f"{pipeline_parallel_size}\n" + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} expert_scores " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{num_hidden_layers}\npipeline_model_parallel_size = " + f"{pipeline_parallel_size}\n" + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all(x.ndim == 1 and x.numel() == moe_num_experts for x in tokens_per_expert) + + tokens = expert_scores[0].shape[0] + assert all((x.ndim == 2 and x.shape[1] == moe_num_experts and x.shape[0] == tokens) for x in expert_scores) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + # TODO @haeggee: conversion to float before mean? 
+ # expert_scores = torch.cat(expert_scores, dim=1).float().mean(dim=0) + expert_scores = torch.cat(expert_scores, dim=1).mean(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = moe_num_experts * moe_loss_weight + scale_denominator = num_hidden_layers * tokens * num_experts_per_token + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + class dMoE(torch.nn.Module): def __init__( self, @@ -94,10 +167,16 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Te scores = F.softmax(router_logits, dim=-1, dtype=torch.float32) # TODO: fuse? if self.config.num_experts_per_tok == 1: - expert_weights, expert_indices = scores.max(dim=-1, keepdim=True) + expert_weights, expert_indices = router_logits.max(dim=-1, keepdim=True) else: - expert_weights, expert_indices = torch.topk(scores, self.config.num_experts_per_tok, dim=-1) - + expert_weights, expert_indices = torch.topk(router_logits, self.config.num_experts_per_tok, dim=-1) + # IMPORTANT step to normalize, otherwise weights are very low + expert_weights = expert_weights / torch.norm( + expert_weights, + p=1, + dim=-1, + keepdim=True, + ) return scores, expert_weights, expert_indices.int() @@ -142,7 +221,7 @@ def __init__( tp_pg=parallel_context.tp_pg, ) else: - self.mlp = SparseMLP( + self.mlp = SparseGLU( config=config, parallel_config=parallel_config, parallel_context=parallel_context, @@ -315,6 +394,8 @@ def forward(self, x, scores, expert_weights, top_experts): """ # Compute the experts. 
x, tokens_per_expert = self.forward_fn(x, expert_weights.flatten(), top_experts.flatten()) + if self.training: + save_load_balancing_stats((tokens_per_expert, scores)) if self.use_bias: return x + self.bias @@ -530,7 +611,7 @@ def __init__( expert_parallel_size=self.expert_pg_size, ) # TODO @nouamane: jit - self.act = partial(F.gelu, approximate="tanh") + self.act = ACT2FN[config.hidden_act] def forward(self, hidden_states, topo): # [seq_length, batch_size, hidden_dim] merged_states = self.w1(hidden_states) diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index ca9df312..10f51380 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -2,6 +2,9 @@ from typing import Dict, Iterable, Optional, Union import torch +from torch import nn as torch_nn +from torch.nn.parallel import DistributedDataParallel + from nanotron import distributed as dist from nanotron import logging from nanotron.distributed import ProcessGroup @@ -12,8 +15,6 @@ from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.utils import ContextManagers -from torch import nn as torch_nn -from torch.nn.parallel import DistributedDataParallel logger = logging.get_logger(__name__) @@ -49,8 +50,12 @@ def forward( # We normalize our loss if not isinstance(output["loss"], TensorPointer): - output["loss"] = output["loss"] / self.nb_microbatches - + output = {k: v / self.nb_microbatches for k, v in output.items()} + if len(output) > 1: + output["original_loss"] = output["loss"].clone().detach() + for k, v in output.items(): + if k != "loss" and k != "original_loss": + output["loss"] += v # Add output as activations that require backward pass if not isinstance(output["loss"], TensorPointer): assert output["loss"].requires_grad @@ -65,7 +70,10 @@ def _get_fwd_context(model: torch_nn.Module): return context def backward( - self, context: ContextManagers, state: PipelineTrainBatchState, grad_accumulator: Optional[GradientAccumulator] + self, + context: ContextManagers, + state: PipelineTrainBatchState, + grad_accumulator: Optional[GradientAccumulator], ): # Increment the number of backwards state.nb_backwards += 1 diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 68b545ed..785aca76 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -125,7 +125,9 @@ def __init__( super().__init__() self.config = get_config_from_file( - config_or_config_file, config_class=config_class, model_config_class=model_config_class + config_or_config_file, + config_class=config_class, + model_config_class=model_config_class, ) self.model_config = self.config.model.model_config if model_class is not None: @@ -170,7 +172,9 @@ def __init__( # Init optimizer self.optimizer, self.grad_accumulator = init_optimizer_and_grad_accumulator( - model=self.model, optimizer_args=self.config.optimizer, parallel_context=self.parallel_context + model=self.model, + optimizer_args=self.config.optimizer, + parallel_context=self.parallel_context, ) if self.init_checkpoint_path is not None: load_optimizer( @@ -196,7 +200,8 @@ def __init__( # Define iteration start state if self.init_checkpoint_path is not None: checkpoint_metadata = load_meta( - parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path + parallel_context=self.parallel_context, + root_folder=self.init_checkpoint_path, ) 
log_rank(str(checkpoint_metadata), logger=logger, level=logging.INFO, rank=0) self.start_iteration_step = checkpoint_metadata.metas["last_train_step"] @@ -210,10 +215,7 @@ def __init__( # Setup tensorboard write and log writers on output rank self.logger_ranks = self.parallel_context.get_global_rank( - ep_rank=0, - pp_rank=self.unwrapped_model.output_pp_rank, - dp_rank=0, - tp_rank=0 + ep_rank=0, pp_rank=self.unwrapped_model.output_pp_rank, dp_rank=0, tp_rank=0 ).flatten() self.loggerwriter = self.setup_log_writers() @@ -319,13 +321,19 @@ def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str): if dataloader is not None: self.current_dataloader = sanity_check_dataloader( - dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + dataloader=dataloader, + parallel_context=self.parallel_context, + config=self.config, ) def train( self, dataloader_or_dls: Dict[ - str, Union[Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], Tuple[Iterator, ...]] + str, + Union[ + Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], + Tuple[Iterator, ...], + ], ], **kwargs, ) -> None: @@ -358,13 +366,17 @@ def train( self._update_dataloader_based_on_training_stages(dataloader_or_dls) # Training step - outputs, loss_avg = self.training_step(dataloader=self.current_dataloader) + outputs, loss_avg, aux_losses = self.training_step(dataloader=self.current_dataloader) # Training Logs self.consumed_train_samples += self.global_batch_size if (self.iteration_step - 1) % self.config.logging.iteration_step_info_interval == 0: - self.train_step_logs(outputs=outputs, loss_avg=loss_avg) + self.train_step_logs( + outputs=outputs, + loss_avg=loss_avg, + aux_losses=aux_losses, + ) # Checkpoint if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0: @@ -377,7 +389,12 @@ def train( def training_step( self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]] ) -> Tuple[Iterable[Dict], Optional[torch.Tensor]]: - before_tbi_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator) + before_tbi_sanity_checks( + self.config, + self.parallel_context, + self.unwrapped_model, + self.grad_accumulator, + ) if self.iteration_step < 5: log_memory(logger=logger) @@ -393,7 +410,12 @@ def training_step( if self.iteration_step < 5: log_memory(logger=logger) - after_tbi_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator) + after_tbi_sanity_checks( + self.config, + self.parallel_context, + self.unwrapped_model, + self.grad_accumulator, + ) if isinstance(self.model, DistributedDataParallel) and self.grad_accumulator is not None: # Wait for fp32 grads allreduce to finish to make sure grads are synced across DP @@ -437,7 +459,10 @@ def training_step( ) before_optim_step_sanity_checks( - self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator + self.config, + self.parallel_context, + self.unwrapped_model, + self.grad_accumulator, ) # Compute DP average loss and overlap with optimizer step @@ -447,7 +472,25 @@ def training_step( [output["loss"] for output in outputs] ).sum() # already divided by n_micro_batches_per_batch # sync loss across DP - handle = dist.all_reduce(loss_avg, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG) + handle = dist.all_reduce( + loss_avg, + group=self.parallel_context.dp_pg, + async_op=True, + op=dist.ReduceOp.AVG, + ) + aux_losses = {} + for k in outputs[0].keys(): + if k != "loss": + aux_losses[k] = 
torch.stack( + [output[k] for output in outputs] + ).sum() # already divided by n_micro_batches_per_batch + # sync loss across DP + handle = dist.all_reduce( + aux_losses[k], + group=self.parallel_context.dp_pg, + async_op=True, + op=dist.ReduceOp.AVG, + ) else: loss_avg = None handle = None @@ -459,14 +502,19 @@ def training_step( # Update the learning rate self.lr_scheduler.step() - after_optim_step_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator) + after_optim_step_sanity_checks( + self.config, + self.parallel_context, + self.unwrapped_model, + self.grad_accumulator, + ) if handle is not None: handle.wait() self.post_train_step() - return outputs, loss_avg + return outputs, loss_avg, aux_losses def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]: outputs = self.pipeline_engine.validate_batch_iter( @@ -480,6 +528,7 @@ def train_step_logs( self, outputs: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]], loss_avg: Optional[torch.Tensor], + aux_losses: Optional[dict] = {}, ) -> None: # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607 dist.barrier() @@ -502,15 +551,24 @@ def train_step_logs( log_entries = [ # LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"), LogItem( - "consumed_tokens", self.consumed_train_samples * self.config.tokens.sequence_length, "human_format" + "consumed_tokens", + self.consumed_train_samples * self.config.tokens.sequence_length, + "human_format", ), # , "12d"), - LogItem("elapsed_time_per_iteration_ms", elapsed_time_per_iteration_ms, "human_format"), # , ".1f"), + LogItem( + "elapsed_time_per_iteration_ms", + elapsed_time_per_iteration_ms, + "human_format", + ), # , ".1f"), LogItem("tokens_per_sec", tokens_per_sec, "human_format"), # , "1.6E"), LogItem( - "tokens_per_sec_per_gpu", tokens_per_sec / self.parallel_context.world_pg.size(), "human_format" + "tokens_per_sec_per_gpu", + tokens_per_sec / self.parallel_context.world_pg.size(), + "human_format", ), # , "1.6E"), LogItem("global_batch_size", self.global_batch_size, "human_format"), # , "5d"), LogItem("lm_loss", loss_avg.item(), "human_format"), # , "1.6E"), + *[LogItem(k, v.item(), "human_format") for k, v in aux_losses.items()], LogItem("lr", lr, "human_format"), # , ".3E"), LogItem("model_tflops_per_gpu", model_tflops, "human_format"), # , ".2f"), LogItem("hardware_tflops_per_gpu", hardware_tflops, "human_format"), # , ".2f"), @@ -525,10 +583,14 @@ def train_step_logs( log_entries.extend( [ LogItem( - "cuda_memory_allocated", torch.cuda.memory_allocated(), "human_format" + "cuda_memory_allocated", + torch.cuda.memory_allocated(), + "human_format", ), # / 1024**2, ".2f"), LogItem( - "cuda_max_memory_reserved", torch.cuda.max_memory_reserved(), "human_format" + "cuda_max_memory_reserved", + torch.cuda.max_memory_reserved(), + "human_format", ), # / 1024**2, ".2f"), LogItem("hd_total_memory_tb", total, "human_format"), # / (2**40), ".2f"), LogItem("hd_used_memory_tb", used, "human_format"), # / (2**40), ".2f"), @@ -591,8 +653,18 @@ def init_model(self) -> Union[NanotronModel, DistributedDataParallel]: ) self.model_config.max_position_embeddings = self.config.tokens.sequence_length - log_rank("Config:\n" + pformat(self.config), logger=logger, level=logging.INFO, rank=0) - 
log_rank("Model Config:\n" + pformat(self.model_config), logger=logger, level=logging.INFO, rank=0) + log_rank( + "Config:\n" + pformat(self.config), + logger=logger, + level=logging.INFO, + rank=0, + ) + log_rank( + "Model Config:\n" + pformat(self.model_config), + logger=logger, + level=logging.INFO, + rank=0, + ) model = self._init_model_instance() model = self._load_model_checkpoint(model) @@ -622,9 +694,16 @@ def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel: reloaded_from_checkpoint = False if self.init_checkpoint_path is not None: # Reload from a training checkpoint - log_rank(f"Loading weights from {self.init_checkpoint_path}", logger=logger, level=logging.INFO, rank=0) + log_rank( + f"Loading weights from {self.init_checkpoint_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) self.param_shard_metadata = load_weights( - model=unwrapped_model, parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path + model=unwrapped_model, + parallel_context=self.parallel_context, + root_folder=self.init_checkpoint_path, ) reloaded_from_checkpoint = True if not reloaded_from_checkpoint: @@ -688,17 +767,41 @@ def _init_model( module.init_rotary_embeddings() # Mark some parameters as tied - self._mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config) + self._mark_tied_parameters( + model=model, + parallel_context=parallel_context, + parallel_config=parallel_config, + ) # count number of parameters num_params = sum(p.numel() for p in model.parameters()) size_params = sum(p.numel() * p.element_size() for p in model.parameters()) total_params = torch.tensor(num_params, device="cuda") total_size = torch.tensor(size_params, device="cuda") - dist.all_reduce(total_params, group=parallel_context.tp_pg, async_op=False, op=dist.ReduceOp.SUM) # TP - dist.all_reduce(total_params, group=parallel_context.pp_pg, async_op=False, op=dist.ReduceOp.SUM) # PP - dist.all_reduce(total_size, group=parallel_context.tp_pg, async_op=False, op=dist.ReduceOp.SUM) - dist.all_reduce(total_size, group=parallel_context.pp_pg, async_op=False, op=dist.ReduceOp.SUM) + dist.all_reduce( + total_params, + group=parallel_context.tp_pg, + async_op=False, + op=dist.ReduceOp.SUM, + ) # TP + dist.all_reduce( + total_params, + group=parallel_context.pp_pg, + async_op=False, + op=dist.ReduceOp.SUM, + ) # PP + dist.all_reduce( + total_size, + group=parallel_context.tp_pg, + async_op=False, + op=dist.ReduceOp.SUM, + ) + dist.all_reduce( + total_size, + group=parallel_context.pp_pg, + async_op=False, + op=dist.ReduceOp.SUM, + ) # TODO @nouamanetazi: better memory logs log_rank( @@ -778,7 +881,12 @@ def save_checkpoint(self) -> Path: checkpoint_path.mkdir(parents=True, exist_ok=True) dist.barrier(self.parallel_context.world_pg) - log_rank(f"Saving checkpoint at {checkpoint_path}", logger=logger, level=logging.WARNING, rank=0) + log_rank( + f"Saving checkpoint at {checkpoint_path}", + logger=logger, + level=logging.WARNING, + rank=0, + ) checkpoint_metadata = { "last_train_step": self.iteration_step, # TODO: @nouamanetazi: Add more metadata to the checkpoint to be able to resume dataloader states properly @@ -809,7 +917,9 @@ def save_checkpoint(self) -> Path: config=self.config, ) save_random_states( - random_states=self.random_states, parallel_context=self.parallel_context, root_folder=checkpoint_path + random_states=self.random_states, + parallel_context=self.parallel_context, + root_folder=checkpoint_path, ) with open(checkpoints_path / 
"latest.txt", mode="w") as fo: fo.write(f"{self.iteration_step}") @@ -830,11 +940,17 @@ def _mark_tied_parameters( parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs] = None, ): - mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config) + mark_tied_parameters( + model=model, + parallel_context=parallel_context, + parallel_config=parallel_config, + ) def mark_tied_parameters( - model: NanotronModel, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs] = None + model: NanotronModel, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs] = None, ): # Tie embeddings embeddings_lm_head_tied_names = model.get_embeddings_lm_head_tied_names() @@ -854,7 +970,10 @@ def mark_tied_parameters( for target in embeddings_lm_head_tied_names ] tie_parameters( - root_module=model, ties=shared_embeddings, parallel_context=parallel_context, reduce_op=dist.ReduceOp.SUM + root_module=model, + ties=shared_embeddings, + parallel_context=parallel_context, + reduce_op=dist.ReduceOp.SUM, ) # Tie custom params @@ -869,7 +988,9 @@ def mark_tied_parameters( def mark_unsharded_params_as_tied_across_tp( - model: NanotronModel, parallel_context: ParallelContext, parallel_config: "ParallelismArgs" + model: NanotronModel, + parallel_context: ParallelContext, + parallel_config: "ParallelismArgs", ): for module_name, module in model.named_modules(): for param_name, param in module.named_parameters(recurse=False): @@ -905,12 +1026,17 @@ def mark_unsharded_params_as_tied_across_tp( reduce_op = dist.ReduceOp.SUM tie_parameters( - root_module=model, ties=shared_weights, parallel_context=parallel_context, reduce_op=reduce_op + root_module=model, + ties=shared_weights, + parallel_context=parallel_context, + reduce_op=reduce_op, ) def mark_unsharded_params_as_tied_across_expert( - model: NanotronModel, parallel_context: ParallelContext, parallel_config: "ParallelismArgs" + model: NanotronModel, + parallel_context: ParallelContext, + parallel_config: "ParallelismArgs", ): for module_name, module in model.named_modules(): for param_name, param in module.named_parameters(recurse=False): @@ -939,5 +1065,8 @@ def mark_unsharded_params_as_tied_across_expert( reduce_op = None tie_parameters( - root_module=model, ties=shared_weights, parallel_context=parallel_context, reduce_op=reduce_op + root_module=model, + ties=shared_weights, + parallel_context=parallel_context, + reduce_op=reduce_op, ) From 1cc9cb31ea09e6be1efb1cf5466b98b18a3bf947 Mon Sep 17 00:00:00 2001 From: Alex Hagele Date: Sun, 14 Apr 2024 17:00:49 +0000 Subject: [PATCH 04/13] fix pipeline bug with moe load balancing loss; correct compute estimation includes moes --- src/nanotron/models/llama.py | 10 ++++-- src/nanotron/models/moe.py | 33 +++++++------------ .../parallel/pipeline_parallel/engine.py | 33 ++++++++++++------- src/nanotron/trainer.py | 6 ++-- 4 files changed, 44 insertions(+), 38 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 5632583e..60f9f35c 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -823,10 +823,14 @@ def get_block_compute_costs(self): model_config = self.config d_ff = model_config.intermediate_size d_qkv = model_config.hidden_size // model_config.num_attention_heads + attention_cost = 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size + mlp_cost = 3 * d_ff * model_config.hidden_size + if model_config.moe_num_experts > 1: + mlp_cost *= 
model_config.num_experts_per_tok # active experts + mlp_cost += model_config.hidden_size * model_config.moe_num_experts # routing block_compute_costs = { # CausalSelfAttention (qkv proj + attn out) + MLP - LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size - + 3 * d_ff * model_config.hidden_size, + LlamaDecoderLayer: attention_cost + mlp_cost, # This is the last lm_head TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, } @@ -934,7 +938,7 @@ def forward( label_mask=label_mask, ) if self.config.moe_num_experts > 1: - aux_loss = batched_load_balancing_loss(self.config, self.parallel_context.pp_pg.size()) + aux_loss = batched_load_balancing_loss(self.config) loss["load_balancing_loss"] = aux_loss clear_load_balancing_stats() return loss diff --git a/src/nanotron/models/moe.py b/src/nanotron/models/moe.py index f9a253e2..a5b532ef 100644 --- a/src/nanotron/models/moe.py +++ b/src/nanotron/models/moe.py @@ -60,7 +60,6 @@ def clear_load_balancing_stats(): def batched_load_balancing_loss( # from config config: LlamaConfig, - pipeline_parallel_size: int, ): tokens_per_expert, expert_scores = zip(*get_load_balancing_stats()) # tokens_per_expert[i].shape = (num_experts) @@ -70,22 +69,6 @@ def batched_load_balancing_loss( moe_loss_weight = config.moe_loss_weight num_experts_per_token = config.num_experts_per_tok - num_layers_per_pipeline_stage = num_hidden_layers // pipeline_parallel_size - if len(tokens_per_expert) != num_layers_per_pipeline_stage: - raise ValueError( - f"Expected {num_layers_per_pipeline_stage} token_per_experts " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{num_hidden_layers}\npipeline_model_parallel_size = " - f"{pipeline_parallel_size}\n" - ) - if len(expert_scores) != num_layers_per_pipeline_stage: - raise ValueError( - f"Expected {num_layers_per_pipeline_stage} expert_scores " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{num_hidden_layers}\npipeline_model_parallel_size = " - f"{pipeline_parallel_size}\n" - ) - # Verify the shape of the tokens_per_expert and expert_scores tensors. assert all(x.ndim == 1 and x.numel() == moe_num_experts for x in tokens_per_expert) @@ -99,10 +82,6 @@ def batched_load_balancing_loss( expert_scores = torch.cat(expert_scores, dim=1).mean(dim=0) tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) - expected_values = num_layers_per_pipeline_stage * moe_num_experts - assert tokens_per_expert.numel() == expected_values - assert expert_scores.numel() == expected_values - # Calculate the total scale across all factors. 
# loss_weight * num_experts / (num_layers * tokens * top_k) scale_numerator = moe_num_experts * moe_loss_weight @@ -610,6 +589,18 @@ def __init__( ), expert_parallel_size=self.expert_pg_size, ) + + self.w1 = ExpertParallel( + TensorParallelColumnLinear( + config.hidden_size, + config.intermediate_size * self.experts_per_rank, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + ), + expert_parallel_size=self.expert_pg_size, + ) # TODO @nouamane: jit self.act = ACT2FN[config.hidden_act] diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index 10f51380..7aacacd5 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -11,7 +11,9 @@ from nanotron.logging import log_rank from nanotron.optim.gradient_accumulator import GradientAccumulator from nanotron.parallel.data_parallel.utils import ddp_trigger_sync_in_bwd -from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model +from nanotron.parallel.pipeline_parallel.context_manager import ( + attach_pipeline_state_to_model, +) from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.utils import ContextManagers @@ -49,13 +51,16 @@ def forward( output = {"loss": output} # We normalize our loss - if not isinstance(output["loss"], TensorPointer): - output = {k: v / self.nb_microbatches for k, v in output.items()} - if len(output) > 1: - output["original_loss"] = output["loss"].clone().detach() - for k, v in output.items(): - if k != "loss" and k != "original_loss": - output["loss"] += v + for k, v in output.items(): + if not isinstance(v, TensorPointer): + output[k] = v / self.nb_microbatches + + # inside the model, we can have load balancing losses for some pipeline + # ranks without the final loss. still need to backpropagate through them + # TODO @haeggee just loop over all keys and check if they are TensorPointer? 
+ if "load_balancing_loss" in output and not isinstance(output["load_balancing_loss"], TensorPointer): + assert output["load_balancing_loss"].requires_grad + state.register_activation_requiring_backward(output["load_balancing_loss"]) # Add output as activations that require backward pass if not isinstance(output["loss"], TensorPointer): assert output["loss"].requires_grad @@ -165,6 +170,8 @@ def validate_batch_iter( # Store the loss for each microbatch if not isinstance(output["loss"], TensorPointer): output = {k: v.detach() for k, v in output.items()} + if "load_balancing_loss" in output and not isinstance(output["load_balancing_loss"], TensorPointer): + output["load_balancing_loss"] = output["load_balancing_loss"].detach() outputs.append(output) return outputs @@ -277,8 +284,9 @@ def train_batch_iter( send_activation() # Store the loss for each microbatch - if not isinstance(output["loss"], TensorPointer): - output = {k: v.detach() for k, v in output.items()} + for k, v in output.items(): + if not isinstance(v, TensorPointer): + output[k] = v.detach() outputs.append(output) for micro_batch in batch: @@ -290,8 +298,9 @@ def train_batch_iter( output = {"loss": output} # Store the loss for each microbatch - if not isinstance(output["loss"], TensorPointer): - output = {k: v.detach() for k, v in output.items()} + for k, v in output.items(): + if not isinstance(v, TensorPointer): + output[k] = v.detach() outputs.append(output) # One backward diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 8a5ac338..660243f5 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -282,7 +282,9 @@ def _update_dataloader_based_on_training_stages(self, dataloaders: List[DataLoad else: dataloader = dataloaders self.current_dataloader = sanity_check_dataloader( - dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + dataloader=dataloader, + parallel_context=self.parallel_context, + config=self.config, ) return @@ -477,6 +479,7 @@ def training_step( self.grad_accumulator, ) + aux_losses = {} # Compute DP average loss and overlap with optimizer step if isinstance(outputs[0]["loss"], torch.Tensor): # This is an average on only one data rank. 
@@ -490,7 +493,6 @@ def training_step( async_op=True, op=dist.ReduceOp.AVG, ) - aux_losses = {} for k in outputs[0].keys(): if k != "loss": aux_losses[k] = torch.stack( From 5a06abbf7dd41ea2a211893527bf792a526f00bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20H=C3=A4gele?= Date: Thu, 2 May 2024 15:02:02 +0200 Subject: [PATCH 05/13] proper swiglu for mlp in expert parallel --- src/nanotron/models/moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nanotron/models/moe.py b/src/nanotron/models/moe.py index a5b532ef..879c6885 100644 --- a/src/nanotron/models/moe.py +++ b/src/nanotron/models/moe.py @@ -590,7 +590,7 @@ def __init__( expert_parallel_size=self.expert_pg_size, ) - self.w1 = ExpertParallel( + self.w3 = ExpertParallel( TensorParallelColumnLinear( config.hidden_size, config.intermediate_size * self.experts_per_rank, @@ -606,7 +606,7 @@ def __init__( def forward(self, hidden_states, topo): # [seq_length, batch_size, hidden_dim] merged_states = self.w1(hidden_states) - hidden_states = self.w2(self.act(merged_states)) + hidden_states = self.w2(self.act(merged_states) * self.w3(hidden_states)) return hidden_states From d0bacc75eb30142939c05e9ede8ad57419fcfbc9 Mon Sep 17 00:00:00 2001 From: Alex Hagele Date: Thu, 6 Jun 2024 08:57:00 +0000 Subject: [PATCH 06/13] add nn.linear module in init parametrization (for moe router) --- src/nanotron/scaling/parametrization.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/nanotron/scaling/parametrization.py b/src/nanotron/scaling/parametrization.py index e6241651..380ad460 100644 --- a/src/nanotron/scaling/parametrization.py +++ b/src/nanotron/scaling/parametrization.py @@ -34,6 +34,7 @@ class StandardParametrizator(Parametrizator): def __init__(self, config: ModelArgs): super().__init__(config) self.MODULE_TO_PARAMETRIZE = { + nn.Linear: self._parametrize_column_linear, TensorParallelColumnLinear: self._parametrize_column_linear, TensorParallelRowLinear: self._parametrize_row_linear, TritonRMSNorm: self._parametrize_layer_norm, From 86ac964e73efeb57b263431bb14b784a23e6a878 Mon Sep 17 00:00:00 2001 From: Alex Hagele Date: Thu, 13 Jun 2024 09:25:14 +0000 Subject: [PATCH 07/13] refactor: aux loss computed in forward of each mlp -> enables correct logs bc loss forwarded through layers --- src/nanotron/models/llama.py | 57 ++++++++++++++++++++++++------------ src/nanotron/models/moe.py | 56 ++++++++++++----------------------- 2 files changed, 57 insertions(+), 56 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index e929b67b..0ace114d 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -14,7 +14,7 @@ # limitations under the License. 
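# Illustrative sketch (not part of the patch): the SwiGLU layout that patch 05
# above restores in the expert MLP, i.e. w2(act(w1(x)) * w3(x)). Plain nn.Linear
# stands in for the ExpertParallel/TensorParallelColumnLinear wrappers, and a
# SiLU activation is assumed, as is typical for LLaMA-style configs.
import torch
from torch import nn
import torch.nn.functional as F

class SwiGLUMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)  # gate projection
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)  # up projection
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)  # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))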
"""PyTorch LLaMa model.""" -from typing import Dict, Optional, Union, List +from typing import Dict, Optional, Union import torch from torch import nn @@ -27,8 +27,6 @@ from nanotron.logging import log_rank from nanotron.models import NanotronModel from nanotron.models.moe import ( - batched_load_balancing_loss, - clear_load_balancing_stats, dMoE, ) from nanotron.nn.activations import ACT2FN @@ -45,7 +43,10 @@ TensorParallelRowLinear, ) from nanotron.random import RandomStates -from nanotron.scaling.parametrization import SpectralMupParametrizator, StandardParametrizator +from nanotron.scaling.parametrization import ( + SpectralMupParametrizator, + StandardParametrizator, +) from nanotron.utils import checkpoint_method logger = logging.get_logger(__name__) @@ -652,13 +653,16 @@ def __init__( parallel_config=parallel_config, parallel_context=parallel_context, ) + self._is_moe = True else: self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + self._is_moe = False def forward( self, hidden_states: Union[torch.Tensor, TensorPointer], sequence_mask: Union[torch.Tensor, TensorPointer], + aux_loss: Union[torch.Tensor, TensorPointer], ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -672,9 +676,13 @@ def forward( mlp_output = self.mlp(hidden_states=hidden_states) hidden_states = mlp_output["hidden_states"] + residual + if self._is_moe: + aux_loss = aux_loss + mlp_output["aux_loss"] + return { "hidden_states": hidden_states, "sequence_mask": output["sequence_mask"], + "aux_loss": aux_loss, } @@ -710,7 +718,9 @@ def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_ # Format input in `[seq_length, batch_size]` to support high TP with low batch_size input_ids = input_ids.transpose(0, 1) input_embeds = self.token_embedding(input_ids) - return {"input_embeds": input_embeds} + return { + "input_embeds": input_embeds, + } class LlamaModel(nn.Module): @@ -758,8 +768,8 @@ def __init__( "parallel_context": parallel_context, "layer_idx": layer_idx, }, - module_input_keys={"hidden_states", "sequence_mask"}, - module_output_keys={"hidden_states", "sequence_mask"}, + module_input_keys={"hidden_states", "sequence_mask", "aux_loss"}, + module_output_keys={"hidden_states", "sequence_mask", "aux_loss"}, ) for layer_idx in range(config.num_hidden_layers) ] @@ -805,22 +815,34 @@ def forward( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + aux_loss: Union[torch.Tensor, TensorPointer], ): - return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask)[0] + sharded_logits, hidden_states, aux_loss = self.forward_with_hidden_states( + input_ids=input_ids, + input_mask=input_mask, + aux_loss=aux_loss, + ) + return {"sharded_logits": sharded_logits, "aux_loss": aux_loss} def forward_with_hidden_states( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + aux_loss: Union[torch.Tensor, TensorPointer], ): # all tensors are optional as most ranks don't need anything from the dataloader. 
- output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask) + output = self.token_position_embeddings( + input_ids=input_ids, + input_mask=input_mask, + ) hidden_encoder_states = { "hidden_states": output["input_embeds"], "sequence_mask": input_mask, + "aux_loss": aux_loss, } + for encoder_block in self.decoder: hidden_encoder_states = encoder_block(**hidden_encoder_states) @@ -830,7 +852,7 @@ def forward_with_hidden_states( fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] - return fp32_sharded_logits, hidden_states + return fp32_sharded_logits, hidden_states, hidden_encoder_states["aux_loss"] def get_block_compute_costs(self): """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" @@ -943,19 +965,18 @@ def forward( label_ids: Union[torch.Tensor, TensorPointer], label_mask: Union[torch.Tensor, TensorPointer], ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: - sharded_logits = self.model( - input_ids=input_ids, - input_mask=input_mask, + init_zero_aux_loss = ( # aux_loss is used for load balancing + torch.zeros(1, device=input_ids.device) + if not isinstance(input_ids, TensorPointer) + else TensorPointer(self.input_pp_rank) ) + output = self.model(input_ids=input_ids, input_mask=input_mask, aux_loss=init_zero_aux_loss) loss = self.loss( - sharded_logits=sharded_logits, + sharded_logits=output["sharded_logits"], label_ids=label_ids, label_mask=label_mask, ) - if self.config.moe_num_experts > 1: - aux_loss = batched_load_balancing_loss(self.config) - loss["load_balancing_loss"] = aux_loss - clear_load_balancing_stats() + loss["load_balancing_loss"] = output["aux_loss"] return loss @torch.no_grad() diff --git a/src/nanotron/models/moe.py b/src/nanotron/models/moe.py index 879c6885..d26a67a1 100644 --- a/src/nanotron/models/moe.py +++ b/src/nanotron/models/moe.py @@ -16,7 +16,6 @@ from nanotron import logging from nanotron.config import LlamaConfig as Config from nanotron.config import ParallelismArgs -from nanotron.config.models_config import LlamaConfig from nanotron.nn.activations import ACT2FN from nanotron.parallel.context import ParallelContext from nanotron.parallel.sharded_parameters import ( @@ -39,48 +38,24 @@ logger = logging.get_logger(__name__) -_LOAD_BALANCING_LOSS = [] - - -def save_load_balancing_stats(loss): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.append(loss) - - -def get_load_balancing_stats(): - global _LOAD_BALANCING_LOSS - return _LOAD_BALANCING_LOSS - - -def clear_load_balancing_stats(): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.clear() - - -def batched_load_balancing_loss( - # from config - config: LlamaConfig, -): - tokens_per_expert, expert_scores = zip(*get_load_balancing_stats()) - # tokens_per_expert[i].shape = (num_experts) - # expert_scores[i].shape = (tokens, num_experts) +def load_balancing_loss(tokens_per_expert, expert_scores, config: Config): + # tokens_per_expert.shape = (num_experts) + # expert_scores.shape = (tokens, num_experts) num_hidden_layers = config.num_hidden_layers moe_num_experts = config.moe_num_experts moe_loss_weight = config.moe_loss_weight num_experts_per_token = config.num_experts_per_tok # Verify the shape of the tokens_per_expert and expert_scores tensors. 
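# Illustrative sketch (not part of the patch): the aux-loss threading pattern
# introduced by this refactor. Each decoder block receives the running aux_loss
# as an input and returns it (an MoE block adds its load-balancing term), so the
# auxiliary loss is forwarded through the pipeline alongside the hidden states.
# The helper below is a simplified stand-in, assuming blocks with the same
# input/output keys as in the diff above.
import torch

def run_decoder(blocks, hidden_states, sequence_mask):
    state = {
        "hidden_states": hidden_states,
        "sequence_mask": sequence_mask,
        "aux_loss": torch.zeros(1, device=hidden_states.device),
    }
    for block in blocks:
        state = block(**state)  # MoE blocks return aux_loss plus their own term
    return state["hidden_states"], state["aux_loss"]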
- assert all(x.ndim == 1 and x.numel() == moe_num_experts for x in tokens_per_expert) + assert tokens_per_expert.ndim == 1 and tokens_per_expert.numel() == moe_num_experts - tokens = expert_scores[0].shape[0] - assert all((x.ndim == 2 and x.shape[1] == moe_num_experts and x.shape[0] == tokens) for x in expert_scores) + tokens = expert_scores.shape[0] + assert expert_scores.ndim == 2 and expert_scores.shape[1] == moe_num_experts - # Concatenate the contributions of each layer and convert to - # the correct types and formats for the dot product. # TODO @haeggee: conversion to float before mean? - # expert_scores = torch.cat(expert_scores, dim=1).float().mean(dim=0) - expert_scores = torch.cat(expert_scores, dim=1).mean(dim=0) - tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + # expert_scores = expert_scores.float().mean(dim=0) + expert_scores = expert_scores.mean(dim=0) + tokens_per_expert = tokens_per_expert.to(expert_scores.dtype) # Calculate the total scale across all factors. # loss_weight * num_experts / (num_layers * tokens * top_k) @@ -130,8 +105,11 @@ def forward(self, hidden_states: torch.Tensor): scores, expert_weights, top_experts = self.gate(x) # Compute the experts. - x = self.experts(x, scores, expert_weights, top_experts) - return {"hidden_states": x.reshape(batch_size, sequence_length, -1)} + x, aux_loss = self.experts(x, scores, expert_weights, top_experts) + return { + "hidden_states": x.reshape(batch_size, sequence_length, -1), + "aux_loss": aux_loss, + } # Adapted from megablocks.layers.router.LearnedRouter @@ -374,11 +352,13 @@ def forward(self, x, scores, expert_weights, top_experts): # Compute the experts. x, tokens_per_expert = self.forward_fn(x, expert_weights.flatten(), top_experts.flatten()) if self.training: - save_load_balancing_stats((tokens_per_expert, scores)) + aux_loss = load_balancing_loss(tokens_per_expert, scores, self.config) + else: + aux_loss = torch.zeros(1, device=x.device) if self.use_bias: return x + self.bias - return x + return x, aux_loss def permute_and_compute( self, From 620a884020b96f42f8b87ee53b1b57fde07ee730 Mon Sep 17 00:00:00 2001 From: Alex Hagele Date: Tue, 2 Jul 2024 13:33:10 +0000 Subject: [PATCH 08/13] fix generation since introduction of moes --- run_generate.py | 7 +++++-- src/nanotron/generation/decode.py | 6 ++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/run_generate.py b/run_generate.py index f389770d..2dd592b1 100644 --- a/run_generate.py +++ b/run_generate.py @@ -57,8 +57,9 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--ckpt-path", type=Path, required=True, help="Checkpoint path") parser.add_argument("--dp", type=int, default=1) - parser.add_argument("--pp", type=int, default=0) - parser.add_argument("--tp", type=int, default=0) + parser.add_argument("--pp", type=int, default=1) + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--ep", type=int, default=1) parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum number of new tokens to generate") return parser.parse_args() @@ -79,6 +80,7 @@ def main(): pp_engine=OneForwardOneBackwardPipelineEngine(), tp_mode=TensorParallelLinearMode.ALL_REDUCE, tp_linear_async_communication=False, + expert_parallel_size=args.ep or config.parallelism.expert_parallel_size, ) # Initialise all process groups @@ -86,6 +88,7 @@ def main(): data_parallel_size=parallel_config.dp, pipeline_parallel_size=parallel_config.pp, tensor_parallel_size=parallel_config.tp, + 
expert_parallel_size=parallel_config.expert_parallel_size, ) # Set log levels diff --git a/src/nanotron/generation/decode.py b/src/nanotron/generation/decode.py index 6ab71fad..905c92db 100644 --- a/src/nanotron/generation/decode.py +++ b/src/nanotron/generation/decode.py @@ -257,7 +257,8 @@ def decode_text( sharded_logits = model( input_ids=state.new_input_ids, input_mask=state.new_input_mask, - ) + aux_loss=torch.zeros(1, device=state.new_input_ids.device), + )["sharded_logits"] else: if isinstance(state.new_input_ids, torch.Tensor): batch_generated_ids = torch.cat(state.generation_ids, dim=-1) @@ -268,7 +269,8 @@ def decode_text( sharded_logits = model( input_ids=batch_generated_ids, input_mask=batch_generated_mask, - ) + aux_loss=torch.zeros(1, device=state.new_input_ids.device), + )["sharded_logits"] if isinstance(sharded_logits, torch.Tensor) and logits_are_batch_first: sharded_logits = sharded_logits.transpose(0, 1) From 29df08614682963c877ac7399a19e438f5139780 Mon Sep 17 00:00:00 2001 From: Alex Hagele Date: Tue, 2 Jul 2024 13:59:11 +0000 Subject: [PATCH 09/13] fix generation in case of pipeline parallel --- run_generate.py | 6 +++--- src/nanotron/generation/decode.py | 12 ++++++++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/run_generate.py b/run_generate.py index 2dd592b1..8408f80d 100644 --- a/run_generate.py +++ b/run_generate.py @@ -57,9 +57,9 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--ckpt-path", type=Path, required=True, help="Checkpoint path") parser.add_argument("--dp", type=int, default=1) - parser.add_argument("--pp", type=int, default=1) - parser.add_argument("--tp", type=int, default=1) - parser.add_argument("--ep", type=int, default=1) + parser.add_argument("--pp", type=int, default=0) + parser.add_argument("--tp", type=int, default=0) + parser.add_argument("--ep", type=int, default=0) parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum number of new tokens to generate") return parser.parse_args() diff --git a/src/nanotron/generation/decode.py b/src/nanotron/generation/decode.py index 905c92db..069743c6 100644 --- a/src/nanotron/generation/decode.py +++ b/src/nanotron/generation/decode.py @@ -257,7 +257,11 @@ def decode_text( sharded_logits = model( input_ids=state.new_input_ids, input_mask=state.new_input_mask, - aux_loss=torch.zeros(1, device=state.new_input_ids.device), + aux_loss=( + torch.zeros(1, device=state.new_input_ids.device) + if is_decoder_input_rank + else TensorPointer(decoder_input_rank) + ), )["sharded_logits"] else: if isinstance(state.new_input_ids, torch.Tensor): @@ -269,7 +273,11 @@ def decode_text( sharded_logits = model( input_ids=batch_generated_ids, input_mask=batch_generated_mask, - aux_loss=torch.zeros(1, device=state.new_input_ids.device), + aux_loss=( + torch.zeros(1, device=state.new_input_ids.device) + if is_decoder_input_rank + else TensorPointer(decoder_input_rank) + ), )["sharded_logits"] if isinstance(sharded_logits, torch.Tensor) and logits_are_batch_first: From 34ba0ab7af23e27faa9754a057500eeed097a75e Mon Sep 17 00:00:00 2001 From: Alex Hagele Date: Tue, 2 Jul 2024 16:55:59 +0000 Subject: [PATCH 10/13] fix expert weights, move lbl in log space for stability, start z-loss code --- src/nanotron/models/moe.py | 84 ++++++++++++++++++++++++++++++-------- 1 file changed, 68 insertions(+), 16 deletions(-) diff --git a/src/nanotron/models/moe.py b/src/nanotron/models/moe.py index d26a67a1..8502931e 100644 --- a/src/nanotron/models/moe.py +++ 
b/src/nanotron/models/moe.py @@ -38,9 +38,28 @@ logger = logging.get_logger(__name__) -def load_balancing_loss(tokens_per_expert, expert_scores, config: Config): - # tokens_per_expert.shape = (num_experts) - # expert_scores.shape = (tokens, num_experts) +def log_mean(x, dim): + return torch.logsumexp(x, dim=dim) - torch.log(torch.tensor(x.shape[dim], dtype=torch.float32)) + + +def load_balancing_loss(router_logits, tokens_per_expert, config: Config) -> torch.Tensor: + """Computes auxiliary load balancing loss as in Switch Transformer. + + See Switch Transformer (https://arxiv.org/abs/2101.03961). This function + implements the loss function presented in equations (4) - (6). It aims to + penalize those cases where the routing between experts is unbalanced. + + Args: + logits: logits assigned to each expert per token. Shape: + [batch_size * sequence_length, num_experts]. + tokens_per_expert: [num_selected_experts] + + config: Config + + Returns: + The auxiliary loss. + """ + # tokens = batch_size * sequence_length num_hidden_layers = config.num_hidden_layers moe_num_experts = config.moe_num_experts moe_loss_weight = config.moe_loss_weight @@ -49,12 +68,16 @@ def load_balancing_loss(tokens_per_expert, expert_scores, config: Config): # Verify the shape of the tokens_per_expert and expert_scores tensors. assert tokens_per_expert.ndim == 1 and tokens_per_expert.numel() == moe_num_experts - tokens = expert_scores.shape[0] - assert expert_scores.ndim == 2 and expert_scores.shape[1] == moe_num_experts + tokens = router_logits.shape[0] + assert router_logits.ndim == 2 and router_logits.shape[1] == moe_num_experts + + # compute router probability per expert in log space for numerical stability + logprobs = F.log_softmax(router_logits, dim=-1) + # take mean probability over batch + # shape [num_experts] + logprobs = log_mean(logprobs, dim=0) + expert_scores = torch.exp(logprobs) - # TODO @haeggee: conversion to float before mean? - # expert_scores = expert_scores.float().mean(dim=0) - expert_scores = expert_scores.mean(dim=0) tokens_per_expert = tokens_per_expert.to(expert_scores.dtype) # Calculate the total scale across all factors. @@ -65,6 +88,34 @@ def load_balancing_loss(tokens_per_expert, expert_scores, config: Config): return scale * torch.dot(tokens_per_expert, expert_scores) +def router_z_loss(router_logits, config: Config) -> torch.Tensor: + """ + The router z-loss was introduced in ST-MoE + (https://arxiv.org/abs/2202.08906). It encourages router logits to remain + small in an effort to improve stability. + + Args: + router_logits: [batch_size * sequence_length, num_experts] + router logits + config: Config + + Returns: + Scalar router z-loss. + """ + num_hidden_layers = config.num_hidden_layers + tokens = router_logits.shape[0] + z_loss_weight = config.moe_z_loss_weight + + log_z = torch.logsumexp(router_logits, dim=-1) + z_loss = log_z**2 + + scale_numerator = z_loss_weight + scale_denominator = num_hidden_layers * tokens + scale = scale_numerator / scale_denominator + + return scale * z_loss.sum(dim=0) + + class dMoE(torch.nn.Module): def __init__( self, @@ -102,10 +153,10 @@ def forward(self, hidden_states: torch.Tensor): # TODO: support sequence parallelism batch_size, sequence_length, _ = hidden_states.size() x = hidden_states.view(-1, self.config.hidden_size) - scores, expert_weights, top_experts = self.gate(x) + router_logits, expert_weights, top_experts = self.gate(x) # Compute the experts. 
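# Illustrative sketch (not part of the patch): the load-balancing loss above,
# reproduced standalone on toy tensors to show the log-space mean trick.
# The sizes, the loss weight, and the bincount stand-in for tokens_per_expert
# are assumptions made for this example only.
import torch
import torch.nn.functional as F

def log_mean(x, dim):
    # log(mean(exp(x))) computed stably via logsumexp
    return torch.logsumexp(x, dim=dim) - torch.log(torch.tensor(x.shape[dim], dtype=torch.float32))

tokens, num_experts, top_k = 8, 4, 2          # toy sizes
num_hidden_layers, moe_loss_weight = 2, 0.01  # toy config values

router_logits = torch.randn(tokens, num_experts)
top_experts = torch.topk(F.softmax(router_logits, dim=-1), top_k, dim=-1).indices
# stand-in for the per-expert token counts returned by the expert computation
tokens_per_expert = torch.bincount(top_experts.flatten(), minlength=num_experts).float()

# mean router probability per expert, averaged over tokens in log space for stability
expert_scores = torch.exp(log_mean(F.log_softmax(router_logits, dim=-1), dim=0))

# loss_weight * num_experts / (num_layers * tokens * top_k)
scale = (num_experts * moe_loss_weight) / (num_hidden_layers * tokens * top_k)
aux_loss = scale * torch.dot(tokens_per_expert, expert_scores)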
- x, aux_loss = self.experts(x, scores, expert_weights, top_experts) + x, aux_loss = self.experts(x, router_logits, expert_weights, top_experts) return { "hidden_states": x.reshape(batch_size, sequence_length, -1), "aux_loss": aux_loss, @@ -124,9 +175,9 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Te scores = F.softmax(router_logits, dim=-1, dtype=torch.float32) # TODO: fuse? if self.config.num_experts_per_tok == 1: - expert_weights, expert_indices = router_logits.max(dim=-1, keepdim=True) + expert_weights, expert_indices = scores.max(dim=-1, keepdim=True) else: - expert_weights, expert_indices = torch.topk(router_logits, self.config.num_experts_per_tok, dim=-1) + expert_weights, expert_indices = torch.topk(scores, self.config.num_experts_per_tok, dim=-1) # IMPORTANT step to normalize, otherwise weights are very low expert_weights = expert_weights / torch.norm( expert_weights, @@ -134,7 +185,7 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Te dim=-1, keepdim=True, ) - return scores, expert_weights, expert_indices.int() + return router_logits, expert_weights, expert_indices.int() # Adapted from megablocks.layers.mlp.ParallelDroplessMLP @@ -341,18 +392,19 @@ def parallel_forward_once(self, x, expert_weights, top_experts): ) return x, tokens_per_expert.flatten() - def forward(self, x, scores, expert_weights, top_experts): + def forward(self, x, router_logits, expert_weights, top_experts): """ Args: x: input tensor of shape [sequence_length, batch_size, hidden_size] - scores: tensor of shape [sequence_length * batch_size, n_experts] + router_logits: tensor of shape [sequence_length * batch_size, n_experts] expert_weights: tensor of shape [sequence_length * batch_size, num_experts_per_tok] top_experts: tensor of shape [sequence_length * batch_size, num_experts_per_tok] """ # Compute the experts. 
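# Illustrative sketch (not part of the patch): why the router renormalizes its
# top-k weights, as in the LearnedRouter change above. The top-k slice of a
# softmax does not sum to 1, so the weights are rescaled by their L1 norm.
# Sizes below are toy values chosen for the example.
import torch
import torch.nn.functional as F

router_logits = torch.randn(6, 4)  # [tokens, num_experts]
scores = F.softmax(router_logits, dim=-1, dtype=torch.float32)
expert_weights, expert_indices = torch.topk(scores, k=2, dim=-1)

print(expert_weights.sum(dim=-1))  # generally well below 1 before renormalization
expert_weights = expert_weights / torch.norm(expert_weights, p=1, dim=-1, keepdim=True)
print(expert_weights.sum(dim=-1))  # ~1.0 after dividing by the L1 norm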
x, tokens_per_expert = self.forward_fn(x, expert_weights.flatten(), top_experts.flatten()) if self.training: - aux_loss = load_balancing_loss(tokens_per_expert, scores, self.config) + aux_loss = load_balancing_loss(router_logits, tokens_per_expert, self.config) + # z_loss = router_z_loss(router_logits, self.config) else: aux_loss = torch.zeros(1, device=x.device) From b2420e1c7dad1048b18070a2d601cc44f5e41dcd Mon Sep 17 00:00:00 2001 From: Alex Hagele Date: Wed, 3 Jul 2024 11:39:20 +0000 Subject: [PATCH 11/13] add z-loss, make aux_losses a dict to be bit cleaner --- src/nanotron/config/models_config.py | 1 + src/nanotron/models/llama.py | 55 ++++++++++++------- src/nanotron/models/moe.py | 14 +++-- .../parallel/pipeline_parallel/engine.py | 22 +++----- 4 files changed, 54 insertions(+), 38 deletions(-) diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index a9bb91a8..df7eef22 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -55,6 +55,7 @@ class LlamaConfig: moe_num_experts: int = 1 num_experts_per_tok: int = 1 moe_loss_weight: float = 0.01 + moe_z_loss_weight: float = 0.001 def __post_init__(self): # NOTE: user don't set self._init_method, ModelArgs will set it diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 0ace114d..aa3ea067 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -662,8 +662,8 @@ def forward( self, hidden_states: Union[torch.Tensor, TensorPointer], sequence_mask: Union[torch.Tensor, TensorPointer], - aux_loss: Union[torch.Tensor, TensorPointer], - ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]], + ) -> Dict[str, Union[torch.Tensor, TensorPointer, Dict[str, Union[torch.Tensor, TensorPointer]]],]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -677,12 +677,14 @@ def forward( hidden_states = mlp_output["hidden_states"] + residual if self._is_moe: - aux_loss = aux_loss + mlp_output["aux_loss"] + for key, value in mlp_output.items(): + if key != "hidden_states": + aux_losses[key] = aux_losses[key] + value return { "hidden_states": hidden_states, "sequence_mask": output["sequence_mask"], - "aux_loss": aux_loss, + "aux_losses": aux_losses, } @@ -768,8 +770,8 @@ def __init__( "parallel_context": parallel_context, "layer_idx": layer_idx, }, - module_input_keys={"hidden_states", "sequence_mask", "aux_loss"}, - module_output_keys={"hidden_states", "sequence_mask", "aux_loss"}, + module_input_keys={"hidden_states", "sequence_mask", "aux_losses"}, + module_output_keys={"hidden_states", "sequence_mask", "aux_losses"}, ) for layer_idx in range(config.num_hidden_layers) ] @@ -815,20 +817,20 @@ def forward( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] - aux_loss: Union[torch.Tensor, TensorPointer], + aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]], ): - sharded_logits, hidden_states, aux_loss = self.forward_with_hidden_states( + sharded_logits, hidden_states, aux_losses = self.forward_with_hidden_states( input_ids=input_ids, input_mask=input_mask, - aux_loss=aux_loss, + aux_losses=aux_losses, ) - return {"sharded_logits": sharded_logits, "aux_loss": aux_loss} + return {"sharded_logits": sharded_logits, "aux_losses": aux_losses} def forward_with_hidden_states( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] 
input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] - aux_loss: Union[torch.Tensor, TensorPointer], + aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]], ): # all tensors are optional as most ranks don't need anything from the dataloader. @@ -840,7 +842,7 @@ def forward_with_hidden_states( hidden_encoder_states = { "hidden_states": output["input_embeds"], "sequence_mask": input_mask, - "aux_loss": aux_loss, + "aux_losses": aux_losses, } for encoder_block in self.decoder: @@ -852,7 +854,7 @@ def forward_with_hidden_states( fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] - return fp32_sharded_logits, hidden_states, hidden_encoder_states["aux_loss"] + return fp32_sharded_logits, hidden_states, hidden_encoder_states["aux_losses"] def get_block_compute_costs(self): """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" @@ -965,18 +967,33 @@ def forward( label_ids: Union[torch.Tensor, TensorPointer], label_mask: Union[torch.Tensor, TensorPointer], ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: - init_zero_aux_loss = ( # aux_loss is used for load balancing - torch.zeros(1, device=input_ids.device) - if not isinstance(input_ids, TensorPointer) - else TensorPointer(self.input_pp_rank) + # aux_losses are used for load balancing in case of MoEs + aux_losses = { + "load_balancing_loss": ( + torch.zeros(1, device=input_ids.device) + if not isinstance(input_ids, TensorPointer) + else TensorPointer(self.input_pp_rank) + ), + "z_loss": ( + torch.zeros(1, device=input_ids.device) + if not isinstance(input_ids, TensorPointer) + else TensorPointer(self.input_pp_rank) + ), + } + output = self.model( + input_ids=input_ids, + input_mask=input_mask, + aux_losses=aux_losses, ) - output = self.model(input_ids=input_ids, input_mask=input_mask, aux_loss=init_zero_aux_loss) loss = self.loss( sharded_logits=output["sharded_logits"], label_ids=label_ids, label_mask=label_mask, ) - loss["load_balancing_loss"] = output["aux_loss"] + + # add all aux_losses to the main loss dictionary + for key, value in output["aux_losses"].items(): + loss[key] = value return loss @torch.no_grad() diff --git a/src/nanotron/models/moe.py b/src/nanotron/models/moe.py index 8502931e..1b415fa3 100644 --- a/src/nanotron/models/moe.py +++ b/src/nanotron/models/moe.py @@ -156,10 +156,11 @@ def forward(self, hidden_states: torch.Tensor): router_logits, expert_weights, top_experts = self.gate(x) # Compute the experts. - x, aux_loss = self.experts(x, router_logits, expert_weights, top_experts) + x, lbl_loss, z_loss = self.experts(x, router_logits, expert_weights, top_experts) return { "hidden_states": x.reshape(batch_size, sequence_length, -1), - "aux_loss": aux_loss, + "load_balancing_loss": lbl_loss, + "z_loss": z_loss, } @@ -403,14 +404,15 @@ def forward(self, x, router_logits, expert_weights, top_experts): # Compute the experts. 
x, tokens_per_expert = self.forward_fn(x, expert_weights.flatten(), top_experts.flatten()) if self.training: - aux_loss = load_balancing_loss(router_logits, tokens_per_expert, self.config) - # z_loss = router_z_loss(router_logits, self.config) + lbl_loss = load_balancing_loss(router_logits, tokens_per_expert, self.config) + z_loss = router_z_loss(router_logits, self.config) else: - aux_loss = torch.zeros(1, device=x.device) + lbl_loss = torch.zeros(1, device=x.device) + z_loss = torch.zeros(1, device=x.device) if self.use_bias: return x + self.bias - return x, aux_loss + return x, lbl_loss, z_loss def permute_and_compute( self, diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index 7aacacd5..d4859d35 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -55,16 +55,14 @@ def forward( if not isinstance(v, TensorPointer): output[k] = v / self.nb_microbatches - # inside the model, we can have load balancing losses for some pipeline - # ranks without the final loss. still need to backpropagate through them - # TODO @haeggee just loop over all keys and check if they are TensorPointer? - if "load_balancing_loss" in output and not isinstance(output["load_balancing_loss"], TensorPointer): - assert output["load_balancing_loss"].requires_grad - state.register_activation_requiring_backward(output["load_balancing_loss"]) - # Add output as activations that require backward pass - if not isinstance(output["loss"], TensorPointer): - assert output["loss"].requires_grad - state.register_activation_requiring_backward(output["loss"]) + # the outputs are either + # - token prediction loss ["loss"] + # - auxiliary losses ["load_balancing_loss", "z_loss"] + # that we need to backpropagate through, so register activations + for loss_key, output_tensor in output.items(): + if not isinstance(output_tensor, TensorPointer): + assert output_tensor.requires_grad + state.register_activation_requiring_backward(output_tensor) return output @staticmethod @@ -167,11 +165,9 @@ def validate_batch_iter( if not isinstance(output, dict): output = {"loss": output} - # Store the loss for each microbatch + # Store the loss(es) for each microbatch if not isinstance(output["loss"], TensorPointer): output = {k: v.detach() for k, v in output.items()} - if "load_balancing_loss" in output and not isinstance(output["load_balancing_loss"], TensorPointer): - output["load_balancing_loss"] = output["load_balancing_loss"].detach() outputs.append(output) return outputs From 7c37b69477e32cd69cd9f2ceec28b994b962e4b9 Mon Sep 17 00:00:00 2001 From: Alex Hagele Date: Wed, 3 Jul 2024 13:41:35 +0000 Subject: [PATCH 12/13] forgot division by num_experts for z-loss --- src/nanotron/models/moe.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/nanotron/models/moe.py b/src/nanotron/models/moe.py index 1b415fa3..75c9e3a8 100644 --- a/src/nanotron/models/moe.py +++ b/src/nanotron/models/moe.py @@ -103,14 +103,18 @@ def router_z_loss(router_logits, config: Config) -> torch.Tensor: Scalar router z-loss. 
""" num_hidden_layers = config.num_hidden_layers + moe_num_experts = config.moe_num_experts + tokens = router_logits.shape[0] + assert router_logits.ndim == 2 and router_logits.shape[1] == moe_num_experts + z_loss_weight = config.moe_z_loss_weight log_z = torch.logsumexp(router_logits, dim=-1) z_loss = log_z**2 scale_numerator = z_loss_weight - scale_denominator = num_hidden_layers * tokens + scale_denominator = num_hidden_layers * tokens * moe_num_experts scale = scale_numerator / scale_denominator return scale * z_loss.sum(dim=0) From 3cde7de6f264ace4f88aa420e98fd1ce2f19a309 Mon Sep 17 00:00:00 2001 From: Alex Hagele Date: Wed, 3 Jul 2024 13:45:29 +0000 Subject: [PATCH 13/13] fix bug of dummy aux_loss activations registered for backward in case of dense model --- src/nanotron/models/llama.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index aa3ea067..c106e222 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -992,8 +992,9 @@ def forward( ) # add all aux_losses to the main loss dictionary - for key, value in output["aux_losses"].items(): - loss[key] = value + if self.config.moe_num_experts > 1: + for key, value in output["aux_losses"].items(): + loss[key] = value return loss @torch.no_grad()