Multi-token prediction #179
Changes from all commits
a42cb0a
96f3607
54a52ca
9a2188f
3988fa3
a099ef6
56e1ca4
e6b5e27
fb37e29
fb5ef34
525c701
78e3c3e
a68a8d1
71968a6
feaf9a6
e22cb83
7185072
afacb9b
2edc4c2
f5703c1
1f0e134
c1b1450
c85be14
@@ -13,7 +13,7 @@
from fast_llm.functional.config import CrossEntropyImpl, TritonConfig
from fast_llm.functional.cross_entropy import cross_entropy_forward_backward
from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward
from fast_llm.layers.common.auxiliary_loss import z_loss
from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss
from fast_llm.layers.language_model.config import (
    LanguageModelBaseConfig,
    LanguageModelDimNames,
@@ -24,7 +24,9 @@
from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs
from fast_llm.logging import log_distributed_tensor
from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_
from fast_llm.utils import div
from fast_llm.utils import Assert, div

OUTPUT_WEIGHTS = "output_weights"


class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer):
@@ -38,6 +40,7 @@ def __init__(
        self,
        config: LanguageModelBaseConfig,
        tensor_space: TensorSpace,
        prediction_distance: int,
    ):
        super().__init__(config)
        self._debug_transformer = config.transformer.debug_transformer
@@ -56,23 +59,24 @@ def __init__(

        hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden)

        self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance)

        self.final_norm = config.transformer.normalization.get_layer(hidden_dim)
        self._logits_scale_factor = config.logits_scale_factor
        self._z_loss_factor = config.logit_z_loss

        # untie embedding weights
        if not self._tie_word_embeddings:
            vocab_dim = self._tensor_space.get_tensor_dim(
                LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
            )
            self.output_weights = ParameterMeta.from_dims(
                (vocab_dim, hidden_dim),
                init_method=init_normal_(
                    std=config.init_method_std_embed,
                    min_val=config.init_method_min_embed,
                    max_val=config.init_method_max_embed,
                ),
            )
        # Distance of the target token prediction
        # 0: next-token prediction
        # >0: multi-token prediction (MTP)
        Assert.geq(prediction_distance, 0)
        self._prediction_distance = prediction_distance
        self.is_last_head = self._prediction_distance == config.prediction_heads - 1
        if self._prediction_distance > 0:
            assert (
                not self._sequence_parallel_logits
            ), "Sequence parallel logits not supported for multi-token prediction."
            assert not self._cross_entropy_splits, "Cross-entropy splits not supported for multi-token prediction."

        self._init_output_weights(hidden_dim, config)

        self._cross_entropy_impl = config.cross_entropy_impl
        if self._cross_entropy_impl == CrossEntropyImpl.auto:
@@ -90,6 +94,23 @@ def __init__(
        if hasattr(self, "output_weights"):
            self.output_weights = self._config.transformer.peft.apply_weight(self.output_weights)

    def _init_output_weights(self, hidden_dim: TensorDim, config) -> None:
        # Only the first head defines the output weights
        if self._tie_word_embeddings or self._prediction_distance > 0:
            return
        # untie embedding weights
        vocab_dim = self._tensor_space.get_tensor_dim(
            LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
        )
        self.output_weights = ParameterMeta.from_dims(
            (vocab_dim, hidden_dim),
            init_method=init_normal_(
                std=config.init_method_std_embed,
                min_val=config.init_method_min_embed,
                max_val=config.init_method_max_embed,
            ),
        )

    def forward(
        self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None
    ) -> torch.Tensor:
@@ -100,33 +121,50 @@ def forward(
                tensor_name="Loss",
                reductions=((DistributedDimNames.data, ReduceOp.AVG),),  # noqa
            )
        if not self.is_last_head:
            # MTP: split the stacked input
            shared_hidden, input_ = torch.unbind(input_, dim=0)
        # TODO: Pytorch copies the grads in backward for no reason (not sure if still the case)
        # TODO: Torch compile implementation sometimes break.
        # TODO: Double-check correctness, optimize a bit more.
        # TODO: Drop autograd entirely.
        # TODO: Skip cross-entropy backward if not needed.
        language_model_loss = self._forward(input_, kwargs, losses)
        if language_model_loss is not None:
            losses[LanguageModelLossNames.language_model_loss].append(language_model_loss)
            losses[self._loss_name].append(language_model_loss)
        # TODO: Return the model output when needed.
        return language_model_loss
        if self.is_last_head:
            # Last head should return the loss for backward.
            return language_model_loss
        else:
            # Backward hook to compute the gradient of the loss
            shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0)
            # MTP: Return shared_hidden to be used by the next head.
            return shared_hidden

    def _forward_backward(
        self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        labels = kwargs[LanguageModelKwargs.labels].flatten() if LanguageModelKwargs.labels in kwargs else None
        labels = kwargs[LanguageModelKwargs.labels] if LanguageModelKwargs.labels in kwargs else None
        # MTP: Shift the labels
        labels = labels[:, self._prediction_distance :].flatten() if labels is not None else None
        if self._sequence_parallel_logits:
            labels = split_op(labels, self._tensor_space.distributed.tensor_group, 0)
        do_grad = labels is not None and self.training
        input_ = input_.detach().requires_grad_(do_grad)
        with torch.enable_grad():
            ln_output = self.final_norm(input_)
            # MTP: truncate the input
            if self._prediction_distance > 0:
                truncated_input = input_[:, : -self._prediction_distance, :].contiguous()
            else:
                truncated_input = input_
            ln_output = self.final_norm(truncated_input)

        grad_output = kwargs[TransformerKwargs.grad_output] / (
            self._group_size if self._sequence_parallel_logits else 1
        )

        output_weights = kwargs[WORD_EMBEDDINGS_WEIGHT] if self._tie_word_embeddings else self.output_weights
        output_weights = self._get_output_weights(kwargs)
        loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split(
            ln_output.detach(), labels, output_weights, grad_output, kwargs, losses
        )
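Note on the chaining in forward() above: for every head except the last, the shared hidden state is passed through AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0), so the hidden state is returned unchanged while a backward hook makes this head's loss contribute gradients once the last head's loss is backpropagated. The sketch below shows one way such a hook can be written; the class name, signature, and behaviour here are illustrative assumptions and may differ from the actual fast_llm.layers.common.auxiliary_loss.AuxiliaryLoss.

```python
import torch


class AuxiliaryLossSketch(torch.autograd.Function):
    """Identity on the hidden state; injects `scale` as the gradient of `loss` in backward."""

    @staticmethod
    def forward(ctx, hidden: torch.Tensor, loss: torch.Tensor, scale: float) -> torch.Tensor:
        ctx.scale = scale
        ctx.loss_shape = loss.shape
        return hidden

    @staticmethod
    def backward(ctx, grad_hidden: torch.Tensor):
        # The gradient flowing back from the later heads passes through to the shared
        # hidden state unchanged; the attached loss receives a constant gradient of
        # `scale`, so its own computation graph is backpropagated as well.
        grad_loss = torch.full(ctx.loss_shape, ctx.scale, dtype=grad_hidden.dtype, device=grad_hidden.device)
        return grad_hidden, grad_loss, None


# Usage analogous to the forward() above (names are illustrative):
# shared_hidden = AuxiliaryLossSketch.apply(shared_hidden, language_model_loss, 1.0)
```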
@@ -137,6 +175,13 @@ def _forward_backward(
        else:
            return loss, None

    def _get_output_weights(self, kwargs: dict) -> torch.Tensor:
        if self._tie_word_embeddings:
            return kwargs[WORD_EMBEDDINGS_WEIGHT]
        if self._prediction_distance > 0:
            return kwargs[OUTPUT_WEIGHTS]
        return self.output_weights

    def _logits_cross_entropy_forward_backward_split(
        self,
        input_: torch.Tensor,
@@ -156,6 +201,7 @@ def _logits_cross_entropy_forward_backward_split(
                return None, None
        else:
            loss = None
            # TODO MTP: allow a _cross_entropy_splits that is not a divisor of the sequence length
Review thread on this TODO:
- Why is this important?
- Because for the additional heads, the input is truncated, so the sequence length at that point will usually not be a multiple of the number of cross-entropy splits (illustrated below).
- This can also be addressed by creating more labels at the data-sampling phase to avoid truncation of the sequence.
- So this ends up being an argument for not truncating the sequence, right?
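To make the divisor issue concrete, here is a small illustration with made-up numbers; strict_div below only mimics the exact-divisibility check that fast_llm.utils.div appears to enforce in the code that follows, it is not the real helper.

```python
# Hypothetical numbers, for illustration only.
def strict_div(a: int, b: int) -> int:
    # Mimics a strict integer division that asserts exact divisibility.
    assert a % b == 0, f"{a} is not divisible by {b}"
    return a // b


sequence_length = 4096
cross_entropy_splits = 4

# Head 0 (next-token prediction): the full sequence splits evenly.
strict_div(sequence_length, cross_entropy_splits)  # 1024

# Head 1 (prediction_distance=1): the input is truncated by one position, so the
# label count is no longer a multiple of the split count.
strict_div(sequence_length - 1, cross_entropy_splits)  # raises AssertionError (4095 % 4 != 0)
```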
            split_size = div(labels.numel(), self._cross_entropy_splits)
            grad_output /= self._cross_entropy_splits
            logit_input = input_.flatten(0, -2)
Review thread:
- I'm not following, could you explain what you mean here? Surely you're shifting the sequence for the additional heads, are you not?
- Yes. I shift the labels and truncate the inputs for the additional heads. We could produce more labels than inputs to avoid truncating the input. Example with sequence-length=4 and mtp=2 (sketched after this thread): currently the labels stop at l4, so the second head only processes [i1, i2, i3]. By adding l5, we can have a label for all the input tokens.
- Should we implement this here? It should be quite easy, just need to replace a bunch of sequence_length+1 with sequence_length+prediction_heads.
- Let's time-box this. I do not want us to spend more than a couple of hours on adding 4 tokens to the sequence.
- Ok, I can have a quick look at this this afternoon, and let's try to merge this PR afterwards.
- Will merge now, and add this to a new PR, since I'd need to test it a little bit more.
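A small sketch of the example discussed above (sequence-length=4, two prediction heads). The token and label names i1..i4 and l1..l5 are the placeholders from the discussion, not real data:

```python
# Placeholder tokens from the discussion above; illustration only.
inputs = ["i1", "i2", "i3", "i4"]  # positions fed to the heads
labels = ["l1", "l2", "l3", "l4"]  # labels available when sampling sequence_length + 1 tokens

prediction_distance = 1  # the second head

# Current behaviour in this PR: shift the labels and truncate the input.
shifted_labels = labels[prediction_distance:]                   # ["l2", "l3", "l4"]
truncated_inputs = inputs[: len(inputs) - prediction_distance]  # ["i1", "i2", "i3"]

# Alternative discussed above: sample sequence_length + prediction_heads tokens so that
# "l5" exists, every input position keeps a label, and nothing is truncated.
extended_labels = labels + ["l5"]
full_labels = extended_labels[prediction_distance:]             # ["l2", "l3", "l4", "l5"]
assert len(full_labels) == len(inputs)
```

This extra-label sampling is the change deferred to a follow-up PR in the last comment.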