diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py
index bf90853d..f5d23031 100644
--- a/fast_llm/data/dataset/gpt/sampled.py
+++ b/fast_llm/data/dataset/gpt/sampled.py
@@ -128,6 +128,7 @@ def _sample(self) -> None:
         # Calculate basic stats.
         documents_per_epoch = document_sizes.numel()
         tokens_per_epoch = document_sizes.sum().item()
+        # TODO MTP: Produce more labels to provide labels for the multi-token prediction heads?
         # We produce sequences of length `self._sequence_length + 1` so the last token has a label,
         # but we also include that last label in the following sample,
         # so we need `sequence_length * num_samples + 1` tokens in total.
diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py
index 83514c86..b12c1806 100644
--- a/fast_llm/engine/checkpoint/external.py
+++ b/fast_llm/engine/checkpoint/external.py
@@ -141,12 +141,16 @@ def import_weight(
         return weight


-class IgnoreWeightConverter(WeightConverter):
+class IgnoreImportWeightConverter(WeightConverter):
+    def __post_init__(self):
+        Assert.eq(len(self.fast_llm_name), 0)
+        Assert.gt(len(self.export_name), 0)
+
     def export_weight(
         self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
     ) -> tuple[torch.Tensor | SafeTensorSlice, ...]:
         raise RuntimeError(
-            f"IgnoreWeightConverter should not be used for export: {self.fast_llm_name}, {self.export_name}"
+            f"IgnoreImportWeightConverter should not be used for export: {self.fast_llm_name}, {self.export_name}"
         )

     def import_weight(
@@ -155,6 +159,24 @@ def import_weight(
         return ()


+class IgnoreExportWeightConverter(WeightConverter):
+    def __post_init__(self):
+        Assert.gt(len(self.fast_llm_name), 0)
+        Assert.eq(len(self.export_name), 0)
+
+    def export_weight(
+        self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
+    ) -> tuple[torch.Tensor | SafeTensorSlice, ...]:
+        return ()
+
+    def import_weight(
+        self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
+    ) -> tuple[torch.Tensor | SafeTensorSlice, ...]:
+        raise RuntimeError(
+            f"IgnoreExportWeightConverter should not be used for import: {self.fast_llm_name}, {self.export_name}"
+        )
+
+
 class CopyWeightConverter(WeightConverter):
     def export_weight(
         self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
@@ -198,7 +220,9 @@ def __init__(self, model: "FastLLMModel"):
             if weight_converter.fast_llm_name
         }
         self._import_converters = {
-            weight_converter.export_name[0]: weight_converter for weight_converter in weight_converters
+            weight_converter.export_name[0]: weight_converter
+            for weight_converter in weight_converters
+            if weight_converter.export_name
         }

     @classmethod
diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py
index f9cfe237..5288d49f 100644
--- a/fast_llm/engine/checkpoint/state_dict.py
+++ b/fast_llm/engine/checkpoint/state_dict.py
@@ -56,7 +56,9 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No
                 saver.add_tensor(self._get_key(exported_name, shard_name), exported_tensor)

         for shard_name, shard_state_dict in state_dict.items():
-            assert not shard_state_dict, (shard_name, list(state_dict))
+            assert (
+                not shard_state_dict
+            ), f"Un-handled entries after conversion: {({k: list(v) for k, v in state_dict.items()})}"

         index = saver.finalize()
         if self._model.config.distributed.rank == 0:
@@ -90,7 +92,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No
                     context.mark_as_loaded(loaded, (parameter_name, shard_name))

         for shard_name, shard_state_dict in state_dict.items():
-            assert not shard_state_dict, (shard_name, list(state_dict))
+            assert not shard_state_dict, (shard_name, list(shard_state_dict))

     @classmethod
     @abc.abstractmethod
diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py
index 8e3a467c..3bd79603 100644
--- a/fast_llm/layers/language_model/config.py
+++ b/fast_llm/layers/language_model/config.py
@@ -22,6 +22,12 @@ class LanguageModelLossNames:
     language_model_loss = "language_model_loss"
     z_loss = "z_loss"

+    @staticmethod
+    def multi_token_prediction_loss(index: int) -> str:
+        if index == 0:
+            return LanguageModelLossNames.language_model_loss
+        return f"language_model_loss_{index}"
+

 class LanguageModelKwargs:
     position_ids = "position_ids"
@@ -57,6 +63,12 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig):
     tie_word_embeddings: bool = Field(
         default=True, desc="Tie the output weights (logits) with the vocabulary embedding.", hint=FieldHint.core
     )
+    prediction_heads: int = Field(
+        default=1,
+        desc="Number of multi-token prediction heads.",
+        hint=FieldHint.feature,
+        valid=check_field(Assert.gt, 0),
+    )

     def _validate(self) -> None:
         if self.use_position_embeddings is None:
diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py
index efca95b4..1286121c 100644
--- a/fast_llm/layers/language_model/head.py
+++ b/fast_llm/layers/language_model/head.py
@@ -13,7 +13,7 @@
 from fast_llm.functional.config import CrossEntropyImpl, TritonConfig
 from fast_llm.functional.cross_entropy import cross_entropy_forward_backward
 from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward
-from fast_llm.layers.common.auxiliary_loss import z_loss
+from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss
 from fast_llm.layers.language_model.config import (
     LanguageModelBaseConfig,
     LanguageModelDimNames,
@@ -24,7 +24,9 @@
 from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs
 from fast_llm.logging import log_distributed_tensor
 from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_
-from fast_llm.utils import div
+from fast_llm.utils import Assert, div
+
+OUTPUT_WEIGHTS = "output_weights"


 class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer):
@@ -38,6 +40,7 @@ def __init__(
         self,
         config: LanguageModelBaseConfig,
         tensor_space: TensorSpace,
+        prediction_distance: int,
     ):
         super().__init__(config)
         self._debug_transformer = config.transformer.debug_transformer
@@ -56,23 +59,24 @@ def __init__(
         self,

         hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden)
+        self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance)
         self.final_norm = config.transformer.normalization.get_layer(hidden_dim)
         self._logits_scale_factor = config.logits_scale_factor
         self._z_loss_factor = config.logit_z_loss

-        # untie embedding weights
-        if not self._tie_word_embeddings:
-            vocab_dim = self._tensor_space.get_tensor_dim(
-                LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
-            )
-            self.output_weights = ParameterMeta.from_dims(
-                (vocab_dim, hidden_dim),
-                init_method=init_normal_(
-                    std=config.init_method_std_embed,
-                    min_val=config.init_method_min_embed,
-                    max_val=config.init_method_max_embed,
-                ),
-            )
+        # Distance of the target token prediction
+        # 0: next-token prediction
+        # >0: multi-token prediction (MTP)
+        Assert.geq(prediction_distance, 0)
+        self._prediction_distance = prediction_distance
+        self.is_last_head = self._prediction_distance == config.prediction_heads - 1
+        if self._prediction_distance > 0:
+            assert (
+                not self._sequence_parallel_logits
+            ), "Sequence parallel logits not supported for multi-token prediction."
+            assert not self._cross_entropy_splits, "Cross-entropy splits not supported for multi-token prediction."
+
+        self._init_output_weights(hidden_dim, config)

         self._cross_entropy_impl = config.cross_entropy_impl
         if self._cross_entropy_impl == CrossEntropyImpl.auto:
@@ -90,6 +94,23 @@ def __init__(
         if hasattr(self, "output_weights"):
             self.output_weights = self._config.transformer.peft.apply_weight(self.output_weights)

+    def _init_output_weights(self, hidden_dim: TensorDim, config) -> None:
+        # Only the first head defines the output weights
+        if self._tie_word_embeddings or self._prediction_distance > 0:
+            return
+        # untie embedding weights
+        vocab_dim = self._tensor_space.get_tensor_dim(
+            LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
+        )
+        self.output_weights = ParameterMeta.from_dims(
+            (vocab_dim, hidden_dim),
+            init_method=init_normal_(
+                std=config.init_method_std_embed,
+                min_val=config.init_method_min_embed,
+                max_val=config.init_method_max_embed,
+            ),
+        )
+
     def forward(
         self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None
     ) -> torch.Tensor:
@@ -100,6 +121,9 @@ def forward(
                 tensor_name="Loss",
                 reductions=((DistributedDimNames.data, ReduceOp.AVG),),  # noqa
             )
+        if not self.is_last_head:
+            # MTP: split the stacked input
+            shared_hidden, input_ = torch.unbind(input_, dim=0)
         # TODO: Pytorch copies the grads in backward for no reason (not sure if still the case)
         # TODO: Torch compile implementation sometimes break.
         # TODO: Double-check correctness, optimize a bit more.
@@ -107,26 +131,40 @@ def forward(
         # TODO: Skip cross-entropy backward if not needed.
         language_model_loss = self._forward(input_, kwargs, losses)
         if language_model_loss is not None:
-            losses[LanguageModelLossNames.language_model_loss].append(language_model_loss)
+            losses[self._loss_name].append(language_model_loss)
         # TODO: Return the model output when needed.
-        return language_model_loss
+        if self.is_last_head:
+            # Last head should return the loss for backward.
+            return language_model_loss
+        else:
+            # Backward hook to compute the gradient of the loss
+            shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0)
+            # MTP: Return shared_hidden to be used by the next head.
+            return shared_hidden

     def _forward_backward(
         self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        labels = kwargs[LanguageModelKwargs.labels].flatten() if LanguageModelKwargs.labels in kwargs else None
+        labels = kwargs[LanguageModelKwargs.labels] if LanguageModelKwargs.labels in kwargs else None
+        # MTP: Shift the labels
+        labels = labels[:, self._prediction_distance :].flatten() if labels is not None else None
         if self._sequence_parallel_logits:
             labels = split_op(labels, self._tensor_space.distributed.tensor_group, 0)
         do_grad = labels is not None and self.training
         input_ = input_.detach().requires_grad_(do_grad)
         with torch.enable_grad():
-            ln_output = self.final_norm(input_)
+            # MTP: truncate the input
+            if self._prediction_distance > 0:
+                truncated_input = input_[:, : -self._prediction_distance, :].contiguous()
+            else:
+                truncated_input = input_
+            ln_output = self.final_norm(truncated_input)

         grad_output = kwargs[TransformerKwargs.grad_output] / (
             self._group_size if self._sequence_parallel_logits else 1
         )

-        output_weights = kwargs[WORD_EMBEDDINGS_WEIGHT] if self._tie_word_embeddings else self.output_weights
+        output_weights = self._get_output_weights(kwargs)
         loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split(
             ln_output.detach(), labels, output_weights, grad_output, kwargs, losses
         )
@@ -137,6 +175,13 @@ def _forward_backward(
         else:
             return loss, None

+    def _get_output_weights(self, kwargs: dict) -> torch.Tensor:
+        if self._tie_word_embeddings:
+            return kwargs[WORD_EMBEDDINGS_WEIGHT]
+        if self._prediction_distance > 0:
+            return kwargs[OUTPUT_WEIGHTS]
+        return self.output_weights
+
     def _logits_cross_entropy_forward_backward_split(
         self,
         input_: torch.Tensor,
@@ -156,6 +201,7 @@
                 return None, None
         else:
             loss = None
+            # TODO MTP: allow a _cross_entropy_splits that is not a divisor of the sequence length
             split_size = div(labels.numel(), self._cross_entropy_splits)
             grad_output /= self._cross_entropy_splits
             logit_input = input_.flatten(0, -2)
diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py
index b65be23f..311403fc 100644
--- a/fast_llm/layers/transformer/transformer.py
+++ b/fast_llm/layers/transformer/transformer.py
@@ -6,7 +6,7 @@
 from fast_llm.core.distributed import set_generator
 from fast_llm.engine.base_model.base_model import Layer
 from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank
-from fast_llm.engine.config_utils.tensor_space import TensorSpace
+from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace
 from fast_llm.layers.transformer.attention import Attention
 from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs
 from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP
@@ -27,11 +27,14 @@ def __init__(
         config: TransformerConfig,
         tensor_space: TensorSpace,
         layer_index: int,
+        return_input: bool = False,
     ):
         super().__init__()
         self._config = config
         self._tensor_space = tensor_space
         self._dropout_p = self._config.hidden_dropout
+        # For multi-token prediction, return a stack of shared_hidden and transformer_output.
+        self._return_input = return_input
         self._layer_index = layer_index
         self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory

@@ -63,9 +66,10 @@ def name(self) -> str:
         return f"Transformer layer {self._layer_index}"

     def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict):
-        return TensorMeta.from_dims(
-            kwargs[TransformerKwargs.hidden_dims], tensor_name=f"{self.name} {name}", dtype=tensor.dtype
-        )
+        dims = kwargs[TransformerKwargs.hidden_dims]
+        if self._return_input:
+            dims = (TensorDim("stacked_input_output", 2),) + dims
+        return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype)

     def _debug_log(self, tensor: torch.Tensor | None, name: str, kwargs: dict[str, typing.Any], *, bias=None) -> None:
         if self._config.debug_transformer_memory:
@@ -103,6 +107,7 @@ def forward(
             )
         if self._debug_mode:
             self._debug_log(None, "Begin", kwargs)
+        fw_input = input_
         hidden_states = self.norm_1(input_)
         if self._debug_mode:
             self._debug_log(hidden_states, "Norm 1", kwargs)
@@ -123,4 +128,6 @@ def forward(
         hidden_states = self._bias_dropout_add(hidden_states, bias, input_)
         if self._debug_mode:
             self._debug_log(None, "MLP residual", kwargs, bias=bias)
+        if self._return_input:
+            hidden_states = torch.stack((fw_input, hidden_states), dim=0)
         return hidden_states
diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py
index 6d97b34c..5a21368f 100644
--- a/fast_llm/models/gpt/config.py
+++ b/fast_llm/models/gpt/config.py
@@ -35,6 +35,7 @@ class Starcoder2GPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
 class LlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
     name: typing.ClassVar[str] = "llama"

+
 class Qwen2GPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
     name: typing.ClassVar[str] = "qwen2"

diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py
index 51c8a3b7..30ae8041 100644
--- a/fast_llm/models/gpt/conversion.py
+++ b/fast_llm/models/gpt/conversion.py
@@ -11,8 +11,9 @@
     AutoStateDictCheckpointHandler,
     ConstantExportParamConverter,
     ConstantImportParamConverter,
+    IgnoreExportWeightConverter,
     IgnoreImportParamConverter,
-    IgnoreWeightConverter,
+    IgnoreImportWeightConverter,
     MappedConfigParamConverter,
     ParamConverter,
     RenameParamConverter,
@@ -29,9 +30,9 @@
     GPTArchitectureConfig,
     GPTModelConfig,
     LlamaGPTHuggingfaceCheckpointFormat,
-    Qwen2GPTHuggingfaceCheckpointFormat,
     MistralGPTHuggingfaceCheckpointFormat,
     MixtralGPTHuggingfaceCheckpointFormat,
+    Qwen2GPTHuggingfaceCheckpointFormat,
     Starcoder2GPTHuggingfaceCheckpointFormat,
 )
 from fast_llm.models.gpt.model import GPTModel
@@ -160,58 +161,110 @@ def _create_config_converters(cls) -> list[ParamConverter]:
     def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
         pass

-
     def _create_weight_converters(
         self,
     ) -> list[WeightConverter]:
         converters = []
         num_layers = self._model.config.base_model.transformer.num_layers
-        norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm
-        transformer_config: TransformerConfig = self._model.config.base_model.transformer

-        # Embedding and output
-        if self._model.config.base_model.tie_word_embeddings:
-            converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight"))
-            converters.append(IgnoreWeightConverter((), "lm_head.weight"))
-        else:
-            converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight"))
-            converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight"))
+        # Embeddings
+        converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight"))

-        # Final norm
-        converters += self._get_weight_and_bias_converters(
-            f"layers.{num_layers + 1}.final_norm", "model.norm", norm_bias
-        )
+        converters += self._create_lm_head_converters()

         for i in range(num_layers):
+            converters += self._create_transformer_layer_converters(i)
+
+        return converters
+
+    def _create_transformer_layer_converters(self, i: int, ignore_export: bool = False) -> list[WeightConverter]:
+        transformer_config: TransformerConfig = self._model.config.base_model.transformer
+        norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm
+        converters = []
+        names_bias_cls = [
             # Self-attn
-            converters += self._get_weight_and_bias_converters(
+            (
                 f"layers.{i+1}.self_attn.query",
                 f"model.layers.{i}.self_attn.q_proj",
                 transformer_config.add_attn_qkv_bias,
                 QueryWeightConverter,
-            )
-            converters += self._get_weight_and_bias_converters(
+            ),
+            (
                 f"layers.{i+1}.self_attn.key_value",
                 (f"model.layers.{i}.self_attn.k_proj", f"model.layers.{i}.self_attn.v_proj"),
                 transformer_config.add_attn_qkv_bias,
                 KeyValueWeightConverter,
-            )
-            converters += self._get_weight_and_bias_converters(
+            ),
+            (
                 f"layers.{i+1}.self_attn.dense",
                 f"model.layers.{i}.self_attn.o_proj",
                 transformer_config.add_attn_dense_bias,
+                WeightConverter,
+            ),
+            # Norm
+            (
+                f"layers.{i+1}.norm_1",
+                f"model.layers.{i}.input_layernorm",
+                norm_bias,
+                WeightConverter,
+            ),
+            (
+                f"layers.{i+1}.norm_2",
+                f"model.layers.{i}.post_attention_layernorm",
+                norm_bias,
+                WeightConverter,
+            ),
+        ]
+        for fast_llm_prefix, hf_prefix, use_bias, cls in names_bias_cls:
+            converters += self._get_weight_and_bias_converters(
+                fast_llm_prefix,
+                () if ignore_export else hf_prefix,
+                use_bias,
+                cls=IgnoreExportWeightConverter if ignore_export else cls,
             )
-            # Norm
+        # MLP
+        if ignore_export:
             converters += self._get_weight_and_bias_converters(
-                f"layers.{i+1}.norm_1", f"model.layers.{i}.input_layernorm", norm_bias
+                f"layers.{i+1}.mlp.layer_1", (), transformer_config.add_mlp_bias, cls=IgnoreExportWeightConverter
             )
             converters += self._get_weight_and_bias_converters(
-                f"layers.{i+1}.norm_2", f"model.layers.{i}.post_attention_layernorm", norm_bias
+                f"layers.{i+1}.mlp.layer_2", (), transformer_config.add_mlp_bias, cls=IgnoreExportWeightConverter
             )
-
-            # MLP
+            converters += [IgnoreExportWeightConverter(f"layers.{i+1}.mlp.router.weight", ())]
+        else:
             converters += self._get_mlp_converters(f"layers.{i+1}", f"model.layers.{i}")
+        return converters
+
+    def _create_lm_head_converters(self) -> list[WeightConverter]:
+        num_layers = self._model.config.base_model.transformer.num_layers
+        prediction_heads = self._model.config.base_model.prediction_heads
+        norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm
+        converters = []
+
+        # Next-token prediction head
+        # Final norm
+        converters += self._get_weight_and_bias_converters(
+            f"layers.{num_layers + 1}.final_norm", "model.norm", norm_bias
+        )
+        # Output weights
+        if self._model.config.base_model.tie_word_embeddings:
+            converters.append(IgnoreImportWeightConverter((), "lm_head.weight"))
+        else:
+            converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight"))
+
+        # MTP-heads > 0 are thrown away
+        for i in range(1, prediction_heads):
+            logger.warning(
+                f"The model weights for the multi-token prediction head {i} are discarded during conversion."
+            )
+            mtp_transformer_layer_index = num_layers - 1 + 2 * i
+            # MTP transformer layer
+            converters += self._create_transformer_layer_converters(mtp_transformer_layer_index, ignore_export=True)
+            # MTP output norm
+            converters += self._get_weight_and_bias_converters(
+                f"layers.{mtp_transformer_layer_index + 2}.final_norm", (), norm_bias, IgnoreExportWeightConverter
+            )

         return converters
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py
index d29fc28d..e878530c 100644
--- a/fast_llm/models/gpt/model.py
+++ b/fast_llm/models/gpt/model.py
@@ -12,7 +12,7 @@
 from fast_llm.engine.schedule.config import BatchConfig
 from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames
 from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding
-from fast_llm.layers.language_model.head import LanguageModelHead
+from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead
 from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor
 from fast_llm.layers.transformer.config import (
     RoutingType,
@@ -71,7 +71,35 @@ def __init__(
         else:
             self._flash_varlen_preprocessor = FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space)

+    def get_output_layers(self) -> list[Layer]:
+        return [
+            layer
+            for i in range(self._config.prediction_heads)
+            for layer in [
+                TransformerLayer(
+                    self._config.transformer,
+                    self._tensor_space,
+                    # TODO MTP: which index?
+                    layer_index=self._config.transformer.num_layers,
+                    # The last layer only returns the transformer output.
+                    # The previous layers return a stack of shared_hidden and transformer_output.
+                    return_input=i < self._config.prediction_heads - 1,
+                ),
+                LanguageModelHead(
+                    self._config,
+                    self._tensor_space,
+                    prediction_distance=i,
+                ),
+            ]
+        ]
+
     def get_layers(self) -> list[Layer]:
+        if self._config.transformer.num_layers == 0:
+            Assert.eq(self._config.prediction_heads, 1)
+            return [
+                LanguageModelEmbedding(self._config, self._tensor_space),
+                LanguageModelHead(self._config, self._tensor_space, 0),
+            ]
         return [
             LanguageModelEmbedding(self._config, self._tensor_space),
             *[
@@ -80,9 +108,9 @@ def get_layers(self) -> list[Layer]:
                     self._tensor_space,
                     layer_index=i + 1,
                 )
-                for i in range(self._config.transformer.num_layers)
+                for i in range(self._config.transformer.num_layers - 1)
             ],
-            LanguageModelHead(self._config, self._tensor_space),
+            *self.get_output_layers(),
         ]

     def setup(self, distributed: Distributed) -> None:
@@ -292,20 +320,33 @@ def transformer_layers(self) -> list[TransformerLayer]:

     @property
     def model_head(self) -> LanguageModelHead:
-        return self.layers[-1]
+        return self.layers[self.model_head_indices[0]]
+
+    @property
+    def model_head_indices(self) -> list[int]:
+        return sorted([len(self) - 1 - 2 * i for i in range(self._config.prediction_heads)])

     def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]:
-        return (
-            {WORD_EMBEDDINGS_WEIGHT: (self.embedding.word_embeddings_weight, (0, len(self) - 1))}
-            if self._config.tie_word_embeddings
-            else {}
-        )
+        if self._config.tie_word_embeddings:
+            return {
+                WORD_EMBEDDINGS_WEIGHT: (
+                    self.embedding.word_embeddings_weight,
+                    (0, *self.model_head_indices),
+                )
+            }
+        elif self._config.prediction_heads > 1:
+            return {
+                OUTPUT_WEIGHTS: (
+                    self.model_head.output_weights,
+                    tuple(self.model_head_indices),
+                )
+            }
+        else:
+            return {}

     @property
     def loss_defs(self) -> list[LossDef]:
-        loss_defs = [
-            LossDef(name=LanguageModelLossNames.language_model_loss, formatted_name="language model loss", count=1)
-        ]
+        loss_defs = []
         if (
             self._config.transformer.num_experts > 1
             and self._config.transformer.expert_routing_type == RoutingType.topk
@@ -327,6 +368,15 @@ def loss_defs(self) -> list[LossDef]:
             )
         if self._config.logit_z_loss:
             LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1)
+
+        for i in range(self._config.prediction_heads):
+            loss_defs.append(
+                LossDef(
+                    name=LanguageModelLossNames.multi_token_prediction_loss(i),
+                    formatted_name=f"language model loss {i}",
+                    count=1,
+                )
+            )
         return loss_defs
diff --git a/tests/common.py b/tests/common.py
index 8b8e57c3..0bc6ed88 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -194,6 +194,12 @@
 ]
 CONFIG_MIXTRAL_YARN_COMMON = CONFIG_MIXTRAL_YARN_FAST_LLM + ["model.distributed.training_dtype=bf16"]

+CONFIG_LLAMA_MTP_MEGATRON = None
+CONFIG_LLAMA_MTP_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [
+    "model.base_model.prediction_heads=4",
+]
+CONFIG_LLAMA_MTP_COMMON = CONFIG_LLAMA_MTP_FAST_LLM + ["model.distributed.training_dtype=bf16"]
+
 _CONFIGS = {
     "gpt2": ("gpt", CONFIG_GPT2_FAST_LLM, CONFIG_GPT2_MEGATRON, CONFIG_GPT2_COMMON, None),
     "sc1": ("gpt", CONFIG_SC1_FAST_LLM, CONFIG_SC1_MEGATRON, CONFIG_SC1_COMMON, None),
@@ -253,6 +259,13 @@
         CONFIG_MIXTRAL_YARN_COMMON,
         MixtralGPTHuggingfaceCheckpointFormat,
     ),
+    "llama-mtp": (
+        "gpt",
+        CONFIG_LLAMA_MTP_FAST_LLM,
+        CONFIG_LLAMA_MTP_MEGATRON,
+        CONFIG_LLAMA_MTP_COMMON,
+        LlamaGPTHuggingfaceCheckpointFormat,
+    ),
 }