Multi-token prediction #179

Merged: 23 commits, Apr 1, 2025
1 change: 1 addition & 0 deletions fast_llm/data/dataset/gpt/sampled.py
@@ -128,6 +128,7 @@ def _sample(self) -> None:
# Calculate basic stats.
documents_per_epoch = document_sizes.numel()
tokens_per_epoch = document_sizes.sum().item()
# TODO MTP: Produce more labels to provide labels for the multi-token prediction heads?
Collaborator:

I'm not following, could you explain what you mean here?
Surely you're shifting the sequence for the additional heads, are you not?

Contributor Author:

Yes. I shift the labels and truncate the inputs for the additional heads.
We could produce more labels than inputs to avoid truncating the input. Example with sequence-length=4 and mtp=2:

t1 t2 t3 t4 t5 t6  <- document
i1 i2 i3 i4 -- --  <- inputs
-- l1 l2 l3 l4 --  <- labels

Currently, the labels stop at l4, so the second head only processes [i1, i2, i3]. By adding l5, we can have a label for all the input tokens.
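
For illustration, a minimal sketch of that shift-and-truncate scheme in plain PyTorch (batch-first layout assumed; the helper name is hypothetical, not the actual Fast-LLM code):

```python
import torch


def shift_and_truncate(hidden: torch.Tensor, labels: torch.Tensor, prediction_distance: int):
    """Pair hidden states with labels for one prediction head.

    hidden: (batch, sequence_length, hidden_size) transformer output.
    labels: (batch, sequence_length) next-token labels (already shifted by one).
    prediction_distance: 0 for the standard next-token head, >0 for extra MTP heads.
    """
    if prediction_distance == 0:
        return hidden, labels
    # The last `prediction_distance` positions have no label within this sample.
    truncated_hidden = hidden[:, :-prediction_distance, :]
    # Position i is now paired with the token at i + 1 + prediction_distance.
    shifted_labels = labels[:, prediction_distance:]
    return truncated_hidden, shifted_labels
```

With sequence-length=4 and mtp=2, the second head (prediction_distance=1) ends up with three hidden positions and labels [l2, l3, l4], matching the [i1, i2, i3] example above.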

Collaborator:

Should we implement it here? It should be quite easy, we just need to replace a bunch of sequence_length+1 with sequence_length+prediction_heads.

Collaborator:

Let's time-box this. I do not want us to spend more than a couple of hours on adding 4 tokens to the sequence.

Contributor Author:

OK, I can take a quick look this afternoon, and then let's try to merge this PR.

Contributor Author:

I'll merge now and add this in a new PR, since I'd need to test it a little bit more.

# We produce sequences of length `self._sequence_length + 1` so the last token has a label,
# but we also include that last label in the following sample,
# so we need `sequence_length * num_samples + 1` tokens in total.
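
For context, a rough sketch of the token-count arithmetic behind this comment, including the proposed (and here still unimplemented) MTP extension discussed above; the helper is hypothetical:

```python
def tokens_needed(sequence_length: int, num_samples: int, prediction_heads: int = 1) -> int:
    # Today each sample spans `sequence_length + 1` tokens so the last input has a label,
    # and consecutive samples overlap by one token: sequence_length * num_samples + 1.
    # Producing extra labels for MTP would overlap by `prediction_heads` tokens instead.
    return sequence_length * num_samples + prediction_heads


# Current behaviour (prediction_heads=1): 3 samples of length 4 need 4 * 3 + 1 = 13 tokens.
```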
30 changes: 27 additions & 3 deletions fast_llm/engine/checkpoint/external.py
@@ -141,12 +141,16 @@ def import_weight(
return weight


class IgnoreWeightConverter(WeightConverter):
class IgnoreImportWeightConverter(WeightConverter):
def __post_init__(self):
Assert.eq(len(self.fast_llm_name), 0)
Assert.gt(len(self.export_name), 0)

def export_weight(
self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
) -> tuple[torch.Tensor | SafeTensorSlice, ...]:
raise RuntimeError(
f"IgnoreWeightConverter should not be used for export: {self.fast_llm_name}, {self.export_name}"
f"IgnoreImportWeightConverter should not be used for export: {self.fast_llm_name}, {self.export_name}"
)

def import_weight(
@@ -155,6 +159,24 @@ def import_weight(
return ()


class IgnoreExportWeightConverter(WeightConverter):
def __post_init__(self):
Assert.gt(len(self.fast_llm_name), 0)
Assert.eq(len(self.export_name), 0)

def export_weight(
self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
) -> tuple[torch.Tensor | SafeTensorSlice, ...]:
return ()

def import_weight(
self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
) -> tuple[torch.Tensor | SafeTensorSlice, ...]:
raise RuntimeError(
f"IgnoreExportWeightConverter should not be used for import: {self.fast_llm_name}, {self.export_name}"
)


class CopyWeightConverter(WeightConverter):
def export_weight(
self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
@@ -198,7 +220,9 @@ def __init__(self, model: "FastLLMModel"):
if weight_converter.fast_llm_name
}
self._import_converters = {
weight_converter.export_name[0]: weight_converter for weight_converter in weight_converters
weight_converter.export_name[0]: weight_converter
for weight_converter in weight_converters
if weight_converter.export_name
}

@classmethod
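
A hedged usage sketch of the two one-directional converters; the positional `(fast_llm_name, export_name)` construction and the weight names below are assumptions for illustration, not taken from this diff:

```python
import torch

from fast_llm.engine.checkpoint.external import IgnoreExportWeightConverter, IgnoreImportWeightConverter

dummy = (torch.zeros(1),)

# A weight that exists only in the exported checkpoint (hypothetical name): it is
# dropped on import, and it never appears in the export-side converter dict because
# its fast_llm_name is empty.
ignore_import = IgnoreImportWeightConverter((), ("model.rotary_emb.inv_freq",))
assert ignore_import.import_weight(dummy) == ()

# A weight that exists only in the Fast-LLM model (hypothetical name, e.g. an extra
# MTP head): it is dropped on export, and it never appears in the import-side
# converter dict because its export_name is empty.
ignore_export = IgnoreExportWeightConverter(("layers.N.output_weights",), ())
assert ignore_export.export_weight(dummy) == ()
```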
6 changes: 4 additions & 2 deletions fast_llm/engine/checkpoint/state_dict.py
@@ -56,7 +56,9 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No
saver.add_tensor(self._get_key(exported_name, shard_name), exported_tensor)

for shard_name, shard_state_dict in state_dict.items():
assert not shard_state_dict, (shard_name, list(state_dict))
assert (
not shard_state_dict
), f"Un-handled entries after conversion: {({k: list(v) for k, v in state_dict.items()})}"

index = saver.finalize()
if self._model.config.distributed.rank == 0:
@@ -90,7 +92,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No
context.mark_as_loaded(loaded, (parameter_name, shard_name))

for shard_name, shard_state_dict in state_dict.items():
assert not shard_state_dict, (shard_name, list(state_dict))
assert not shard_state_dict, (shard_name, list(shard_state_dict))

@classmethod
@abc.abstractmethod
12 changes: 12 additions & 0 deletions fast_llm/layers/language_model/config.py
@@ -22,6 +22,12 @@ class LanguageModelLossNames:
language_model_loss = "language_model_loss"
z_loss = "z_loss"

@staticmethod
def multi_token_prediction_loss(index: int) -> str:
if index == 0:
return LanguageModelLossNames.language_model_loss
return f"language_model_loss_{index}"


class LanguageModelKwargs:
position_ids = "position_ids"
@@ -57,6 +63,12 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig):
tie_word_embeddings: bool = Field(
default=True, desc="Tie the output weights (logits) with the vocabulary embedding.", hint=FieldHint.core
)
prediction_heads: int = Field(
default=1,
desc="Number of multi-token prediction heads.",
hint=FieldHint.feature,
valid=check_field(Assert.gt, 0),
)

def _validate(self) -> None:
if self.use_position_embeddings is None:
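
A quick illustration of how the new field and the loss naming fit together (a sketch; assumes the config module imports cleanly on its own):

```python
from fast_llm.layers.language_model.config import LanguageModelLossNames

# With prediction_heads=3, each head reports a separate loss:
assert LanguageModelLossNames.multi_token_prediction_loss(0) == "language_model_loss"
assert LanguageModelLossNames.multi_token_prediction_loss(1) == "language_model_loss_1"
assert LanguageModelLossNames.multi_token_prediction_loss(2) == "language_model_loss_2"
```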
86 changes: 66 additions & 20 deletions fast_llm/layers/language_model/head.py
@@ -13,7 +13,7 @@
from fast_llm.functional.config import CrossEntropyImpl, TritonConfig
from fast_llm.functional.cross_entropy import cross_entropy_forward_backward
from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward
from fast_llm.layers.common.auxiliary_loss import z_loss
from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss
from fast_llm.layers.language_model.config import (
LanguageModelBaseConfig,
LanguageModelDimNames,
@@ -24,7 +24,9 @@
from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs
from fast_llm.logging import log_distributed_tensor
from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_
from fast_llm.utils import div
from fast_llm.utils import Assert, div

OUTPUT_WEIGHTS = "output_weights"


class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer):
@@ -38,6 +40,7 @@ def __init__(
self,
config: LanguageModelBaseConfig,
tensor_space: TensorSpace,
prediction_distance: int,
):
super().__init__(config)
self._debug_transformer = config.transformer.debug_transformer
@@ -56,23 +59,24 @@ def __init__(

hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden)

self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance)
self.final_norm = config.transformer.normalization.get_layer(hidden_dim)
self._logits_scale_factor = config.logits_scale_factor
self._z_loss_factor = config.logit_z_loss

# untie embedding weights
if not self._tie_word_embeddings:
vocab_dim = self._tensor_space.get_tensor_dim(
LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
)
self.output_weights = ParameterMeta.from_dims(
(vocab_dim, hidden_dim),
init_method=init_normal_(
std=config.init_method_std_embed,
min_val=config.init_method_min_embed,
max_val=config.init_method_max_embed,
),
)
# Distance of the target token prediction
# 0: next-token prediction
# >0: multi-token prediction (MTP)
Assert.geq(prediction_distance, 0)
self._prediction_distance = prediction_distance
self.is_last_head = self._prediction_distance == config.prediction_heads - 1
if self._prediction_distance > 0:
assert (
not self._sequence_parallel_logits
), "Sequence parallel logits not supported for multi-token prediction."
assert not self._cross_entropy_splits, "Cross-entropy splits not supported for multi-token prediction."

self._init_output_weights(hidden_dim, config)

self._cross_entropy_impl = config.cross_entropy_impl
if self._cross_entropy_impl == CrossEntropyImpl.auto:
@@ -90,6 +94,23 @@ def __init__(
if hasattr(self, "output_weights"):
self.output_weights = self._config.transformer.peft.apply_weight(self.output_weights)

def _init_output_weights(self, hidden_dim: TensorDim, config) -> None:
# Only the first head defines the output weights
if self._tie_word_embeddings or self._prediction_distance > 0:
return
# untie embedding weights
vocab_dim = self._tensor_space.get_tensor_dim(
LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
)
self.output_weights = ParameterMeta.from_dims(
(vocab_dim, hidden_dim),
init_method=init_normal_(
std=config.init_method_std_embed,
min_val=config.init_method_min_embed,
max_val=config.init_method_max_embed,
),
)

def forward(
self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None
) -> torch.Tensor:
@@ -100,33 +121,50 @@ def forward(
tensor_name="Loss",
reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa
)
if not self.is_last_head:
# MTP: split the stacked input
shared_hidden, input_ = torch.unbind(input_, dim=0)
# TODO: Pytorch copies the grads in backward for no reason (not sure if still the case)
# TODO: Torch compile implementation sometimes break.
# TODO: Double-check correctness, optimize a bit more.
# TODO: Drop autograd entirely.
# TODO: Skip cross-entropy backward if not needed.
language_model_loss = self._forward(input_, kwargs, losses)
if language_model_loss is not None:
losses[LanguageModelLossNames.language_model_loss].append(language_model_loss)
losses[self._loss_name].append(language_model_loss)
# TODO: Return the model output when needed.
return language_model_loss
if self.is_last_head:
# Last head should return the loss for backward.
return language_model_loss
else:
# Backward hook to compute the gradient of the loss
shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0)
# MTP: Return shared_hidden to be used by the next head.
return shared_hidden

def _forward_backward(
self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None
) -> tuple[torch.Tensor, torch.Tensor | None]:
labels = kwargs[LanguageModelKwargs.labels].flatten() if LanguageModelKwargs.labels in kwargs else None
labels = kwargs[LanguageModelKwargs.labels] if LanguageModelKwargs.labels in kwargs else None
# MTP: Shift the labels
labels = labels[:, self._prediction_distance :].flatten() if labels is not None else None
if self._sequence_parallel_logits:
labels = split_op(labels, self._tensor_space.distributed.tensor_group, 0)
do_grad = labels is not None and self.training
input_ = input_.detach().requires_grad_(do_grad)
with torch.enable_grad():
ln_output = self.final_norm(input_)
# MTP: truncate the input
if self._prediction_distance > 0:
truncated_input = input_[:, : -self._prediction_distance, :].contiguous()
else:
truncated_input = input_
ln_output = self.final_norm(truncated_input)

grad_output = kwargs[TransformerKwargs.grad_output] / (
self._group_size if self._sequence_parallel_logits else 1
)

output_weights = kwargs[WORD_EMBEDDINGS_WEIGHT] if self._tie_word_embeddings else self.output_weights
output_weights = self._get_output_weights(kwargs)
loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split(
ln_output.detach(), labels, output_weights, grad_output, kwargs, losses
)
@@ -137,6 +175,13 @@ def _forward_backward(
else:
return loss, None

def _get_output_weights(self, kwargs: dict) -> torch.Tensor:
if self._tie_word_embeddings:
return kwargs[WORD_EMBEDDINGS_WEIGHT]
if self._prediction_distance > 0:
return kwargs[OUTPUT_WEIGHTS]
return self.output_weights

def _logits_cross_entropy_forward_backward_split(
self,
input_: torch.Tensor,
@@ -156,6 +201,7 @@ def _logits_cross_entropy_forward_backward_split(
return None, None
else:
loss = None
# TODO MTP: allow a _cross_entropy_splits that is not a divisor of the sequence length
Collaborator:

Why is this important?

Contributor Author:

Because the input is truncated for the additional heads, the sequence length at that point will usually not be a multiple of _cross_entropy_splits. For example, with sequence_length=4096 and prediction_distance=1, the truncated length is 4095, which is not divisible by a typical split count like 4.

Contributor Author:

This can also be addressed by creating more labels at the data-sampling phase to avoid truncation of the sequence

Collaborator:

so this ends up being an argument for not truncating the sequence, right?
i.e. resolving https://github.com/ServiceNow/Fast-LLM/pull/179/files#r2012178184

split_size = div(labels.numel(), self._cross_entropy_splits)
grad_output /= self._cross_entropy_splits
logit_input = input_.flatten(0, -2)
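
Putting the head-side changes together, a simplified, self-contained sketch of what one prediction head does per forward step. This is an illustration of the flow only: a plain cross-entropy stands in for the fused forward-backward path, and the function is not part of the real code.

```python
import torch
import torch.nn.functional as F


def _head_loss(hidden: torch.Tensor, labels: torch.Tensor, output_weights: torch.Tensor) -> torch.Tensor:
    # Stand-in for the fused logits + cross-entropy forward-backward in the real head.
    logits = hidden @ output_weights.t()
    return F.cross_entropy(logits.flatten(0, -2), labels.flatten())


def mtp_head_step(
    input_: torch.Tensor,          # (2, batch, seq, hidden) for non-last heads, else (batch, seq, hidden)
    labels: torch.Tensor,          # (batch, seq) next-token labels
    output_weights: torch.Tensor,  # (vocab_size, hidden), shared by all heads
    prediction_distance: int,
    is_last_head: bool,
    losses: dict[str, list],
    loss_name: str,
) -> torch.Tensor:
    if is_last_head:
        shared_hidden, hidden = None, input_
    else:
        # Non-last heads receive a stack of (shared_hidden, transformer_output).
        shared_hidden, hidden = torch.unbind(input_, dim=0)

    # Truncate the hidden states and shift the labels for this head's distance.
    if prediction_distance > 0:
        hidden = hidden[:, :-prediction_distance, :].contiguous()
        labels = labels[:, prediction_distance:]

    loss = _head_loss(hidden, labels, output_weights)
    losses[loss_name].append(loss)

    # The last head returns its loss for backward; earlier heads return the shared
    # hidden state so the following layers and heads can consume it.
    return loss if is_last_head else shared_hidden
```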
15 changes: 11 additions & 4 deletions fast_llm/layers/transformer/transformer.py
@@ -6,7 +6,7 @@
from fast_llm.core.distributed import set_generator
from fast_llm.engine.base_model.base_model import Layer
from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank
from fast_llm.engine.config_utils.tensor_space import TensorSpace
from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace
from fast_llm.layers.transformer.attention import Attention
from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs
from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP
@@ -27,11 +27,14 @@ def __init__(
config: TransformerConfig,
tensor_space: TensorSpace,
layer_index: int,
return_input: bool = False,
):
super().__init__()
self._config = config
self._tensor_space = tensor_space
self._dropout_p = self._config.hidden_dropout
# For multi-token prediction, return a stack of shared_hidden and transformer_output.
self._return_input = return_input

self._layer_index = layer_index
self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory
@@ -63,9 +66,10 @@ def name(self) -> str:
return f"Transformer layer {self._layer_index}"

def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict):
return TensorMeta.from_dims(
kwargs[TransformerKwargs.hidden_dims], tensor_name=f"{self.name} {name}", dtype=tensor.dtype
)
dims = kwargs[TransformerKwargs.hidden_dims]
if self._return_input:
dims = (TensorDim("stacked_input_output", 2),) + dims
return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype)

def _debug_log(self, tensor: torch.Tensor | None, name: str, kwargs: dict[str, typing.Any], *, bias=None) -> None:
if self._config.debug_transformer_memory:
@@ -103,6 +107,7 @@ def forward(
)
if self._debug_mode:
self._debug_log(None, "Begin", kwargs)
fw_input = input_
hidden_states = self.norm_1(input_)
if self._debug_mode:
self._debug_log(hidden_states, "Norm 1", kwargs)
@@ -123,4 +128,6 @@ def forward(
hidden_states = self._bias_dropout_add(hidden_states, bias, input_)
if self._debug_mode:
self._debug_log(None, "MLP residual", kwargs, bias=bias)
if self._return_input:
hidden_states = torch.stack((fw_input, hidden_states), dim=0)
return hidden_states
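
A small shape check of the return_input contract, with stand-in tensors (in the real code the stack happens at the end of TransformerLayer.forward and is undone by the next LanguageModelHead):

```python
import torch

batch, seq, hidden = 2, 8, 16
layer_input = torch.randn(batch, seq, hidden)    # the shared hidden entering the layer
layer_output = torch.randn(batch, seq, hidden)   # stand-in for the layer's actual output

# With return_input=True the layer returns its input stacked with its output,
# which is why the meta gains a leading "stacked_input_output" dimension of size 2.
stacked = torch.stack((layer_input, layer_output), dim=0)
assert stacked.shape == (2, batch, seq, hidden)

# The following (non-last) head splits the stack back apart.
shared_hidden, transformer_output = torch.unbind(stacked, dim=0)
assert shared_hidden.shape == (batch, seq, hidden)
```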
1 change: 1 addition & 0 deletions fast_llm/models/gpt/config.py
@@ -35,6 +35,7 @@ class Starcoder2GPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
class LlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
name: typing.ClassVar[str] = "llama"


class Qwen2GPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
name: typing.ClassVar[str] = "qwen2"
