
Multi-token prediction #179

Merged — 23 commits merged from raymond/mtp into main on Apr 1, 2025

Conversation

@RaymondLi0 (Contributor) commented on Mar 7, 2025:

✨ Description

Add a minimally working version of Multi-token prediction:

  • A new num_multi_token_prediction_heads argument, None by default. If set, new layers are added for multi-token prediction, replacing the standard lm_head layer: for each token to predict, we add a MultiTokenPredictionTransformerLayer and a MultiTokenPredictionLanguageModelHead (see the sketch after this list).
  • All the output weights are shared across the heads. As before, they can also be tied to the input embeddings with tie_word_embeddings.
  • Each multi-token prediction loss is tracked separately, and contributes equally to the final loss being optimized.
  • At export, the additional lm-heads are discarded, resulting in a transformer model with num_layers + 1 layers.
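
To make the resulting layer stack concrete, here is a minimal toy sketch in plain PyTorch (not Fast-LLM code; it omits causal masking, the config plumbing, and the loss handling, and all names except num_multi_token_prediction_heads are illustrative):

```python
import torch
import torch.nn as nn


class MiniMTPModel(nn.Module):
    """Toy layout: shared trunk, then one extra transformer layer + LM head per predicted token."""

    def __init__(self, vocab_size=100, hidden=64, num_layers=2, num_multi_token_prediction_heads=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        # Shared trunk of `num_layers` transformer layers.
        self.trunk = nn.ModuleList(
            nn.TransformerEncoderLayer(hidden, nhead=4, batch_first=True) for _ in range(num_layers)
        )
        # One additional transformer layer per predicted token.
        self.head_layers = nn.ModuleList(
            nn.TransformerEncoderLayer(hidden, nhead=4, batch_first=True)
            for _ in range(num_multi_token_prediction_heads)
        )
        # A single output projection shared by all heads (could also be tied to self.embed.weight).
        self.output = nn.Linear(hidden, vocab_size, bias=False)

    def forward(self, tokens):
        hidden = self.embed(tokens)
        for layer in self.trunk:
            hidden = layer(hidden)
        # Head k produces logits for the token at distance k+1 from each position.
        return [self.output(layer(hidden)) for layer in self.head_layers]
```

With num_multi_token_prediction_heads=1 this reduces to a standard language model with one extra transformer layer, which is what the sanity check below compares against.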

In a future PR:

  • Add more labels in data-sampling to avoid input-truncation
  • Add weights for each future-token prediction loss (e.g. the next token could be more important than the future ones)
  • Handle sequence-parallelism and cross-entropy splits
  • Convert the additional lm-heads (requires a corresponding hf-transformer implementation)
  • Improve throughput

Closes #167

🔍 Type of change

Select all that apply:

  • 🐛 Bug fix (non-breaking change that addresses a specific issue)
  • 🚀 New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
  • 📝 Documentation change (updates documentation, including new content or typo fixes)
  • 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)

Sanity check

A baseline transformer model reaches the same loss as a model with num_multi_token_prediction_heads=1 and one less layer in the shared trunk. Memory usage is very similar.
[Screenshots: loss curves and memory usage for the baseline vs. the 1-head MTP model]

Performance

The additional layers seem to negatively impact training throughput.
[Screenshot: training throughput comparison]



@RaymondLi0 marked this pull request as ready for review on March 24, 2025 at 20:39.
@jlamypoirier (Collaborator) left a comment:
Looks good, but I have some suggestions to make it behave more like a standard language model with added prediction heads rather than an entirely new thing.

@@ -22,6 +22,10 @@ class LanguageModelLossNames:
    language_model_loss = "language_model_loss"
    z_loss = "z_loss"

    @classmethod
Collaborator:
Is this really needed?

Contributor Author:
Would it make more sense as a @staticmethod?

Collaborator:
@classmethod is fine (and probably better), I meant having the method at all seems unnecessary...

Contributor Author:
It's for the same purpose that we have language_model_loss = "language_model_loss" above, no?

losses[self.loss_name].append(language_model_loss)
if self.is_last_head:
    # Last layer should return the loss for backward.
    return language_model_loss
Collaborator:
This returns the loss for the most distant predicted token, which doesn't make much sense. How about running the predictions in reversed order so the last one is the next token (index=0)? This would make the return value here more relevant and merge the is_last_head and multi_token_prediction_index > 0 conditions.

Collaborator:
[screenshot attached]

we want this ^^^ because this way we ensure that memory impact is minimal.

Collaborator:
We do want to compute the heads one at a time, but the order in which we do it doesn't matter.

Contributor Author:
For the conditions: the output weights should be defined in the first layer (per the comment "# The weight should be defined in the first layer in the set."), right? And the last layer should return its loss. So even if we use the reverse order, I don't think we can merge the is_last_head and multi_token_prediction_index > 0 conditions.

I also thought it would be convenient if we later want to support a sequential version of this, where each head uses the output of the previous one as input (this is what DeepSeek-V3 does, for example).
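
For context, a small hedged sketch of the two wirings being discussed (illustrative only, not the PR's implementation; trunk_hidden, head_layers, and lm_heads are hypothetical names):

```python
def parallel_heads(trunk_hidden, head_layers, lm_heads):
    # Non-sequential wiring: every extra layer reads the shared trunk output directly,
    # so the heads can be computed in any order (and one at a time to limit memory).
    return [head(layer(trunk_hidden)) for layer, head in zip(head_layers, lm_heads)]


def sequential_heads(trunk_hidden, head_layers, lm_heads):
    # DeepSeek-V3-style wiring mentioned above: each head consumes the previous head's output,
    # so the computation order is fixed by construction.
    hidden, all_logits = trunk_hidden, []
    for layer, head in zip(head_layers, lm_heads):
        hidden = layer(hidden)
        all_logits.append(head(hidden))
    return all_logits
```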

Collaborator:
I guess you're right. I'll let you decide on the best option here.

return torch.stack((input_, output), dim=0)


class MultiTokenPredictionLanguageModelHead(LanguageModelHead):
Collaborator:
Is this really worth a separate class? Seems like we could just add a multi_token_prediction_index/_prediction_distance option to the base class, it would only lead to tiny changes.

Contributor Author:
Hmm, we could, yes, but there are quite a few changes, especially in _forward_backward, with the shifting/truncating of labels and inputs, and the potential handling of sequence-parallel and cross-entropy-splits (not supported now, but this would need different code).
Do you still think we should put this in the base class LanguageModelHead?

Contributor Author:
Oh, I guess the logic of the base class would just be a special case where multi_token_prediction_index=0.
Then indeed I could just add this code to the base class.

Contributor Author:
I could also remove the MultiTokenPredictionTransformerLayer class by adding a stacked_output parameter to the TransformerLayer class. wdyt @jlamypoirier @tscholak ?

Contributor Author:
I removed both MultiTokenPredictionTransformerLayer and MultiTokenPredictionLanguageModelHead and moved the code to the base classes. Things should be a bit simpler now.

@tscholak (Collaborator) left a comment:
Looks good, thanks @RaymondLi0!
I wonder though if it wouldn't be better to avoid adding the additional transformer layer for the first lm head. If that were the case, then the mtp feature could always be on (with a default of 1 head) and would do the same thing as before. Only for more than 1 head would we see additional transformer layers being added.

@@ -128,6 +128,7 @@ def _sample(self) -> None:
    # Calculate basic stats.
    documents_per_epoch = document_sizes.numel()
    tokens_per_epoch = document_sizes.sum().item()
    # TODO MTP: Produce more labels to provide labels for the multi-token prediction heads?
Collaborator:
I'm not following, could you explain what you mean here?
Surely you're shifting the sequence for the additional heads, are you not?

Contributor Author:
Yes. I shift the labels and truncate the inputs for the additional heads.
We could produce more labels than inputs to avoid truncating the input. Example with sequence-length=4, and mtp=2

t1 t2 t3 t4 t5 t6  <- document
i1 i2 i3 i4 -- --  <- inputs
-- l1 l2 l3 l4 --  <- labels

Currently, the labels stop at l4, so the second head only processes [i1, i2, i3]. By adding l5, we can have a label for all the input tokens.
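
A small illustrative sketch of that shifting/truncation (a hypothetical helper, not the PR's exact code), assuming labels[t] is the next-token label for position t:

```python
def shift_for_head(hidden_states, labels, prediction_distance):
    """Head at distance k predicts the token (k + 1) positions ahead of each input position.

    With sequence_length=4 and labels l1..l4, the head with prediction_distance=1 sees
    inputs [i1, i2, i3] and labels [l2, l3, l4], matching the diagram above.
    """
    if prediction_distance == 0:
        return hidden_states, labels
    return hidden_states[:-prediction_distance], labels[prediction_distance:]
```

Producing prediction_heads - 1 extra labels in data sampling (the follow-up discussed below) would remove the need to truncate hidden_states here.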

Collaborator:
Should we implement it here? It should be quite easy, we'd just need to replace a bunch of sequence_length+1 with sequence_length+prediction_heads.

Collaborator:
Let's time-box this. I do not want us to spend more than a couple of hours on adding 4 tokens to the sequence.

Contributor Author:
Ok, I can have a quick look at this this afternoon, and let's try to merge this PR afterwards.

Contributor Author:
Will merge now, and add this to a new PR, since I'd need to test it a little bit more

layer
for i in range(self._config.prediction_heads)
for layer in [
TransformerLayer(
Collaborator:
Wrong when num_layers=0. (Mostly for debug but we do want to support it.)

Collaborator:
The whole MTP thing doesn't make sense for num_layers=0: we would have 0 transformer layers for the first head and 1 for each additional head.

@@ -123,4 +127,6 @@ def forward(
    hidden_states = self._bias_dropout_add(hidden_states, bias, input_)
    if self._debug_mode:
        self._debug_log(None, "MLP residual", kwargs, bias=bias)
    if self._stacked_output:
Collaborator:
This doesn't match the meta output which will break pipeline parallelism (same for LM head). We need to either update get_meta (really easy) or explicitly prevent pipeline parallelism.
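
To make the mismatch concrete, a toy illustration in plain PyTorch (not Fast-LLM code) of the shape change that the layer's declared meta would need to reflect:

```python
import torch

hidden = torch.randn(8, 512)             # shape the meta currently advertises for the output
stacked = torch.stack((hidden, hidden))  # what the layer actually returns with _stacked_output
assert stacked.shape == (2, 8, 512)      # a pipeline-parallel schedule sized for (8, 512)
                                         # would move the wrong number of elements
```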

@tscholak (Collaborator) left a comment:
Thanks @RaymondLi0, this looks great.
I see some logic in trying to avoid truncation of the inputs when additional heads are present, but I would time-box this at this point. I suggest 2 hours max. I'd like us to quickly move on to running actual experiments that validate the value proposition of MTP before over-investing in implementation details. Thanks.

@@ -151,6 +160,7 @@ def _logits_cross_entropy_forward_backward_split(
        return None, None
    else:
        loss = None
        # TODO MTP: allow a _cross_entropy_splits that is not a divisor of the sequence length
Collaborator:
so this ends up being an argument for not truncating the sequence, right?
i.e. resolving https://github.com/ServiceNow/Fast-LLM/pull/179/files#r2012178184
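
A quick numeric illustration of that TODO (the numbers here are made up for illustration):

```python
sequence_length = 4096
cross_entropy_splits = 4        # first head: 4096 / 4 = 1024 positions per split
prediction_distance = 1         # the second head's inputs/labels are truncated by one token
truncated = sequence_length - prediction_distance
print(truncated % cross_entropy_splits)  # 3 -> 4095 no longer divides evenly into 4 splits
# Producing extra labels instead of truncating would keep every head at 4096 positions.
```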

@RaymondLi0 changed the title from "WIP: Multi-token prediction" to "Multi-token prediction" on Mar 31, 2025.
@jlamypoirier (Collaborator) left a comment:

Minor suggestion, otherwise LGTM

def get_layers(self) -> list[Layer]:
    if self._config.transformer.num_layers == 0:
        Assert.eq(self._config.prediction_heads, 1)
Collaborator:
This could be checked in config validation

Collaborator:
for a debug case like num_layers=0 that has no practical applications?
I think not.

@RaymondLi0 (Contributor Author):
Thanks for the reviews @jlamypoirier @tscholak! I'll move the additional labels to a new PR to respect the time-boxing.

@RaymondLi0 merged commit 9036fd2 into main on Apr 1, 2025, with 4 checks passed, and deleted the raymond/mtp branch on April 1, 2025 at 13:49.