Multi-token prediction #179
Conversation
Looks good, but I have some suggestions to make it behave more like a standard language model with added prediction heads rather than an entirely new thing.
@@ -22,6 +22,10 @@ class LanguageModelLossNames:
    language_model_loss = "language_model_loss"
    z_loss = "z_loss"

    @classmethod
Is this really needed?
Would make more sense as a @staticmethod?
@classmethod is fine (and probably better), I meant having the method at all seems unnecessary...
It's for the same purpose that we have language_model_loss = "language_model_loss" above, no?
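For context, a hypothetical sketch of why a method (rather than another fixed attribute) is needed here: with a variable number of prediction heads, the loss name has to depend on the head index. The method name and format below are assumptions, not the actual code from the PR.

```python
# Hypothetical illustration only; the method added in the PR may differ.
class LanguageModelLossNames:
    language_model_loss = "language_model_loss"
    z_loss = "z_loss"

    @classmethod
    def multi_token_prediction_loss(cls, index: int) -> str:
        # A fixed attribute works for a single loss, but with several prediction
        # heads the name must depend on the head index.
        if index == 0:
            return cls.language_model_loss
        return f"{cls.language_model_loss}_{index}"
```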
losses[self.loss_name].append(language_model_loss)
if self.is_last_head:
    # Last layer should return the loss for backward.
    return language_model_loss
This returns the loss for the most distant predicted token, which doesn't make much sense. How about running the predictions in reversed order so the last one is the next token (index=0)? This would make the return value here more relevant and merge the is_last_head and multi_token_prediction_index > 0 conditions.
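A minimal sketch of the suggested ordering (illustrative only, with made-up names): iterating the heads from the most distant prediction down to distance 0 means the loss returned by the last head processed is always the next-token loss.

```python
# Illustrative only: process heads from the most distant prediction down to
# distance 0, so the loss returned for backward is the standard next-token loss.
def run_heads(shared_hidden_states, heads, labels):
    loss_for_backward = None
    for distance in reversed(range(len(heads))):
        # heads[distance] predicts the token `distance + 1` positions ahead.
        loss_for_backward = heads[distance](shared_hidden_states, labels, distance)
    # distance == 0 was processed last, so "last head processed" and
    # "distance == 0" become the same condition.
    return loss_for_backward
```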
We do want to compute the heads one at a time, but the order in which we do it doesn't matter.
For the conditions: the output weights should be defined in the first layer (per the comment "# The weight should be defined in the first layer in the set."), so I don't think the is_last_head and multi_token_prediction_index > 0 conditions can simply be merged.
I also thought it would be convenient if later we want to support a sequential version of this, where each head uses the output of the previous one as input (this is what DeepSeek-V3 does, for example).
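For reference, a rough sketch of that sequential variant (as in DeepSeek-V3), where each head's transformer layer consumes the previous head's hidden states rather than the shared trunk output. All names here are illustrative, not Fast-LLM classes.

```python
import torch.nn as nn

class SequentialMTPSketch(nn.Module):
    """Illustrative only: each prediction head refines the previous head's hidden states."""

    def __init__(self, trunk: nn.Module, head_layers: nn.ModuleList, lm_heads: nn.ModuleList):
        super().__init__()
        self.trunk = trunk
        self.head_layers = head_layers  # one extra transformer layer per predicted token
        self.lm_heads = lm_heads        # one output projection per predicted token

    def forward(self, embeddings):
        hidden = self.trunk(embeddings)
        logits = []
        for layer, lm_head in zip(self.head_layers, self.lm_heads):
            hidden = layer(hidden)       # head k starts from head k-1's output
            logits.append(lm_head(hidden))
        return logits
```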
I guess you're right. I'll let you decide on the best option here.
return torch.stack((input_, output), dim=0)


class MultiTokenPredictionLanguageModelHead(LanguageModelHead):
Is this really worth a separate class? Seems like we could just add a multi_token_prediction_index/_prediction_distance option to the base class, it would only lead to tiny changes.
Hmm, we could, yes, but there are quite a few changes, especially in _forward_backward, with the shifting/truncating of labels and inputs, and the potential handling of sequence-parallel and cross-entropy-splits (not supported now, but this would need different code). Do you still think we should put this in the base class LanguageModelHead?
Oh, I guess the logic of the base class would just be a special case where multi_token_prediction_index=0. Then indeed I could just add this code to the base class.
I could also remove the MultiTokenPredictionTransformerLayer class by adding a stacked_output parameter to the TransformerLayer class. Wdyt @jlamypoirier @tscholak?
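As a rough illustration of the stacked_output idea (consistent with the torch.stack call quoted above, but not the actual Fast-LLM layer): the layer returns both its untouched input and its output stacked along a new leading dimension, so the following LM head can compute its loss on the head-specific output while passing the shared hidden states along to the next head.

```python
import torch
import torch.nn as nn

class TransformerLayerSketch(nn.Module):
    """Illustrative sketch of a stacked_output option (names and structure assumed)."""

    def __init__(self, block: nn.Module, stacked_output: bool = False):
        super().__init__()
        self.block = block
        self._stacked_output = stacked_output

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        output = self.block(input_)
        if self._stacked_output:
            # Return both the shared hidden states (for the next prediction head)
            # and this head's output (for the following LM head).
            return torch.stack((input_, output), dim=0)
        return output
```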
I removed both MultiTokenPredictionTransformerLayer and MultiTokenPredictionLanguageModelHead and moved the code to the base classes. Things should be a bit simpler now.
Looks good, thanks @RaymondLi0!
I wonder though if it wouldn't be better to avoid adding the additional transformer layer for the first lm head. If that were the case, then the mtp feature could always be on (with a default of 1 head) and would do the same thing as before. Only with more than 1 head would additional transformer layers be added.
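For concreteness, a rough sketch (assumed names, not the Fast-LLM implementation) of the layer layout as merged, where every prediction head, including the first, is preceded by its own extra transformer layer; dropping that extra layer for the first head is what is being suggested here.

```python
# Sketch of the resulting layer stack for num_layers=2 and prediction_heads=2,
# as I read the get_layers() comprehension discussed later. Names are made up.
def layer_names(num_layers: int, prediction_heads: int) -> list[str]:
    layers = ["embedding"] + [f"trunk_layer_{i}" for i in range(num_layers)]
    for head in range(prediction_heads):
        layers.append(f"head_{head}_transformer_layer")
        layers.append(f"head_{head}_lm_head")
    return layers

print(layer_names(num_layers=2, prediction_heads=2))
# ['embedding', 'trunk_layer_0', 'trunk_layer_1',
#  'head_0_transformer_layer', 'head_0_lm_head',
#  'head_1_transformer_layer', 'head_1_lm_head']
```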
@@ -128,6 +128,7 @@ def _sample(self) -> None:
    # Calculate basic stats.
    documents_per_epoch = document_sizes.numel()
    tokens_per_epoch = document_sizes.sum().item()
    # TODO MTP: Produce more labels to provide labels for the multi-token prediction heads?
I'm not following, could you explain what you mean here?
Surely you're shifting the sequence for the additional heads, are you not?
Yes. I shift the labels and truncate the inputs for the additional heads.
We could produce more labels than inputs to avoid truncating the input. Example with sequence-length=4 and mtp=2:

t1 t2 t3 t4 t5 t6 <- document
i1 i2 i3 i4 -- -- <- inputs
-- l1 l2 l3 l4 -- <- labels

Currently, the labels stop at l4, so the second head only processes [i1, i2, i3]. By adding l5, we can have a label for all the input tokens.
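To make the current behavior concrete, here is a small self-contained sketch (plain PyTorch, not the Fast-LLM code) of the per-head shifting/truncation when the sampler yields sequence_length + 1 tokens.

```python
# Minimal sketch: inputs/labels seen by each prediction head when labels stop
# at sequence_length + 1 tokens, as in the current PR.
import torch

sequence_length, mtp = 4, 2
document = torch.arange(1, sequence_length + 3)  # t1 .. t6
tokens = document[: sequence_length + 1]         # sampler currently yields sequence_length + 1 tokens
inputs, labels = tokens[:-1], tokens[1:]         # i1..i4, l1..l4

for distance in range(mtp):
    # The head at `distance` predicts the token `distance + 1` positions ahead,
    # so its labels are shifted and its inputs truncated accordingly.
    head_inputs = inputs[: sequence_length - distance]
    head_labels = labels[distance:]
    print(distance, head_inputs.tolist(), head_labels.tolist())
# 0 [1, 2, 3, 4] [2, 3, 4, 5]
# 1 [1, 2, 3]    [3, 4, 5]
```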
Should we implement that here? It should be quite easy, we'd just need to replace a bunch of sequence_length+1 with sequence_length+prediction_heads.
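Continuing the sketch above, with sequence_length + prediction_heads tokens sampled, every head would get a label for every input position (illustrative only, not the sampler code).

```python
import torch

sequence_length, mtp = 4, 2
document = torch.arange(1, sequence_length + 3)         # t1 .. t6
tokens = document[: sequence_length + mtp]              # sequence_length + prediction_heads tokens
inputs, labels = tokens[:sequence_length], tokens[1:]   # i1..i4, l1..l5

for distance in range(mtp):
    # No truncation of the inputs: every head sees the full sequence.
    head_labels = labels[distance : distance + sequence_length]
    print(distance, inputs.tolist(), head_labels.tolist())
# 0 [1, 2, 3, 4] [2, 3, 4, 5]
# 1 [1, 2, 3, 4] [3, 4, 5, 6]
```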
Let's time-box this. I do not want us to spend more than a couple of hours on adding 4 tokens to the sequence.
Ok I can have a quick look at this this afternoon, and let's try to merge this PR afterwards
Will merge now, and add this to a new PR, since I'd need to test it a little bit more
layer
for i in range(self._config.prediction_heads)
for layer in [
    TransformerLayer(
Wrong when num_layers=0. (Mostly for debug but we do want to support it.)
The whole MTP thing doesn't make sense for num_layers=0: we would have 0 transformer layers for the first head and 1 for every additional head.
@@ -123,4 +127,6 @@ def forward(
        hidden_states = self._bias_dropout_add(hidden_states, bias, input_)
        if self._debug_mode:
            self._debug_log(None, "MLP residual", kwargs, bias=bias)
        if self._stacked_output:
This doesn't match the meta output, which will break pipeline parallelism (same for the LM head). We need to either update get_meta (really easy) or explicitly prevent pipeline parallelism.
Thanks @RaymondLi0, this looks great.
I see some logic in trying to avoid truncation of the inputs when additional heads are present, but I would time box this at this point. I suggest 2 hours max. I'd like us to quickly move on to running actual experiments that validate the value proposition of MTP before over-investing in implementation details. Thanks.
@@ -151,6 +160,7 @@ def _logits_cross_entropy_forward_backward_split(
        return None, None
    else:
        loss = None
        # TODO MTP: allow a _cross_entropy_splits that is not a divisor of the sequence length
So this ends up being an argument for not truncating the sequence, right? I.e., resolving https://github.com/ServiceNow/Fast-LLM/pull/179/files#r2012178184
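Regarding the TODO itself: a minimal sketch of handling a split count that does not divide the sequence length (torch.tensor_split already allows uneven chunks). This is only an illustration, not how Fast-LLM implements the cross-entropy splits.

```python
import torch

def split_logits(logits: torch.Tensor, cross_entropy_splits: int) -> list[torch.Tensor]:
    # torch.tensor_split allows a number of chunks that does not divide the
    # sequence length: the first chunks are one element longer than the rest.
    return list(torch.tensor_split(logits, cross_entropy_splits, dim=0))

chunks = split_logits(torch.randn(7, 32), cross_entropy_splits=3)
print([chunk.shape[0] for chunk in chunks])  # [3, 2, 2]
```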
Minor suggestion, otherwise LGTM
def get_layers(self) -> list[Layer]:
    if self._config.transformer.num_layers == 0:
        Assert.eq(self._config.prediction_heads, 1)
This could be checked in config validation
For a debug case like num_layers=0 that has no practical applications? I think not.
Thanks for the reviews @jlamypoirier @tscholak! I'll move the additional labels to a new PR to respect the time-boxing.
✨ Description
Add a minimally working version of multi-token prediction:
- Add a num_multi_token_prediction_heads argument, None by default. If set, new layers for multi-token prediction are added, replacing the standard lm_head layer. For each token to predict, we add a MultiTokenPredictionTransformerLayer and a MultiTokenPredictionLanguageModelHead.
- tie_word_embeddings
- num_layers + 1 layers.

In a future PR:
Closes #167
Sanity check

A baseline transformer model reaches the same loss as a model with num_multi_token_prediction_heads=1 and one less layer in the shared trunk. Memory usage is very similar.

(loss and memory-usage screenshots omitted)

Performance

The additional layers seem to negatively impact training throughput.

(throughput screenshot omitted)