
Linear Cross Entropy #2507

Draft · wants to merge 1 commit into base: main
Conversation

Andrei-Aksionov

Hi there 👋

This PR adds support for Linear Cross Entropy loss (a.k.a. Cut Cross Entropy), proposed by Apple in the paper Cut Your Losses in Large-Vocabulary Language Models, by integrating the official implementation.

So, from now on, we have:

  • CCE: chunked cross entropy
  • LCE: linear cross entropy

The benefit of LCE is that it reduces VRAM usage, especially for models with large vocab sizes.

This is done with three things:

  1. Custom kernels.
    LCE uses custom CUDA kernels to perform the matrix multiplications and the log-sum-exp reduction in shared memory.
    Usually you first calculate logits by multiplying the embeddings (B, T, hidden_dim) with the output layer (hidden_dim, vocab_size) and then calculate the loss from those logits and the target labels.
    LCE doesn't materialize this logits matrix in global memory (with modern LLMs that have large vocab sizes, this matrix can be huge; see the sketch after this list). Instead, it multiplies only a part of the output layer weights with the embeddings at a time in shared (fast) memory and immediately calculates the loss contribution there (the flash attention playbook).
    This avoids materializing the logits matrix (saving memory) and saves memory bandwidth (no need to move data between global and shared memory multiple times).

  2. Gradient filtering.
    LCE leverages the sparsity of the softmax gradient to skip elements of the gradient computation that have a negligible contribution. This improves the throughput of LCE by reducing unnecessary computation.

  3. Vocabulary sorting.
    LCE uses vocabulary sorting to group tokens with similar average logits together. This increases the block-level sparsity and improves the effectiveness of gradient filtering.
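To get a feel for why the materialized logits matrix matters, here is a quick back-of-the-envelope sketch (batch and sequence sizes are illustrative, not taken from the runs below; Gemma 2 uses a roughly 256k-entry vocabulary):

```python
# Rough size of the logits tensor that the "regular" cross-entropy path
# materializes in global memory. Shapes are illustrative.
batch_size, seq_len, vocab_size = 4, 2048, 256_000
bytes_per_element = 2  # bf16

logits_bytes = batch_size * seq_len * vocab_size * bytes_per_element
print(f"logits tensor: {logits_bytes / 2**30:.1f} GiB")  # ~3.9 GiB, before the backward pass
```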


I ran a quick single-device LoRA recipe with the Gemma 2 2B model.
This is the best possible case for LCE, since the model is relatively small but has the same vocabulary size as the larger models in the Gemma 2 family.

I ran 4 experiments:

  • CCE (non-compiled): chunked cross-entropy without compiling the model or the loss
  • CCE: chunked cross-entropy with compilation
  • LCE: linear cross-entropy with compilation
  • LCE (torch_compile impl): an implementation for older GPUs and MPS devices

Note

To do this, one needs to change the loss section of the config:

loss:
    _component_: cut_cross_entropy.LinearCrossEntropy
    impl: torch_compile  # for GPUs older than Ampere, or for MPS devices
| Name | Peak Memory Reserved (GB) | Time |
|---|---|---|
| CCE (non-compiled) | 21.6 | 10:15 |
| CCE | 12.3 | 09:18 |
| LCE | 7.73 | 04:51 |
| LCE (torch_compile impl) | 9.1 | 04:43 |
(Screenshot: loss curves for the four Gemma 2 2B runs.)

As one can see, the loss curve is identical in all cases, yet LCE significantly reduces reserved memory and training time.


Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Mar 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2507

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot

Hi @Andrei-Aksionov!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@Andrei-Aksionov
Author

There are a couple of things that prevent this PR from moving out of the draft stage.

  1. What to do with docs? Do I need to create a wrapper class for LinearCrossEntropy with a docstring, so that this info is automatically populated in the docs?
  2. Due to a lack of compute resources (and no access to multi-GPU machines), I tested only the single-device LoRA recipe.
    I haven't found any mention that LCE doesn't work in multi-GPU setups, but I cannot confirm that it does.
    What should I do in this case? And what about the other recipes?

@facebook-github-bot

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@facebook-github-bot added the CLA Signed label on Mar 17, 2025
@felipemello1
Contributor

felipemello1 commented Mar 17, 2025

hey @Andrei-Aksionov, thanks for the PR! Impressive numbers. Do you have any intuition for why tokens per second would double? From their paper, torch.compile is actually the fastest. I don't know if, when they compared against torchtune in their paper, they used compile=True.

But, in any case, I wouldn't have imagined that changing the output_layer + CE would have such a TPS impact (but my intuition here may be wrong; since it's a 2B model + LoRA, maybe computing the output_layer for a large vocab is the most expensive thing).

https://arxiv.org/pdf/2411.09009
(image from the paper)

If you have bandwidth, could you try Llama 8B for a few steps, CCE+compile vs cut-loss+compile?

@Andrei-Aksionov
Author

Hello @felipemello1

Do you have any intuition for why tokens per second would double?

Well, I guess that's because it's Gemma 2 with only 2B parameters, while the vocab size is the same as for the larger models in the Gemma 2 family, so applying the classification head and calculating the loss has a significant effect on training speed. Plus it's a LoRA variant, which makes it even more pronounced.
I would imagine the effect would be smaller for bigger models with smaller vocab sizes.

As I understand it, in the case of regular cross-entropy (chunked or not), the process is:

  1. Calculate logits by multiplying the embeddings with the output layer. If the vocab size is large, this means a lot of copy operations from shared memory (SM) to global memory (HBM).
  2. We need to convert those logits to probabilities, so we need to:
    a) load the logits block-by-block from HBM into SM to exponentiate the values, and then write them back
    b) calculate the sum of the exponentiated values
    c) load the exponentiated logits back from HBM into SM to divide them by that sum, converting them into probabilities
  3. Load them again to calculate the loss.

As you can see, there are a lot of copy operations relative to the number of floating-point operations (sketched in code below).
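In PyTorch terms, the regular path boils down to roughly this (a minimal sketch; tensor names are illustrative):

```python
import torch
import torch.nn.functional as F

def regular_ce(hidden: torch.Tensor, output_weight: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # hidden: (B, T, hidden_dim); output_weight: (vocab_size, hidden_dim); targets: (B, T)
    # Step 1: the full (B, T, vocab_size) logits tensor is materialized in global memory.
    logits = hidden @ output_weight.T
    # Steps 2-3: the softmax normalization and the loss reduction inside F.cross_entropy
    # read that tensor from HBM again (and its gradient is materialized in the backward pass).
    return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
```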

LCE, on the other hand:

  1. Copies a portion of the embeddings and output weights (as much as fits into a thread block) from HBM to SM.
  2. Multiplies them to get logits, but keeps the result in SM.
  3. Converts the logits into probabilities without needing to know the sum of all exponentiated values upfront, thanks to the online softmax normalization trick (the one used in flash attention).
  4. Probably calculates the loss there too, but I'm not sure.
  5. Copies the result back from SM to HBM.

As you can see, the ratio of floating-point operations to bytes copied is higher than for the regular CE calculation.
That's exactly what you want when writing a custom kernel to make things faster (given that the default path is memory-bound).
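For reference, the online normalization trick from step 3 can be written in plain PyTorch: a running maximum and a running sum of exponentials are updated block by block, so the softmax denominator never requires a full logits row in memory (a sketch of the idea for a single token, not the actual kernel):

```python
import torch

def streaming_logsumexp(hidden_row: torch.Tensor, output_weight: torch.Tensor, block_size: int = 4096) -> torch.Tensor:
    # hidden_row: (hidden_dim,); output_weight: (vocab_size, hidden_dim).
    # Only one block of the logits row exists at a time.
    running_max = torch.tensor(float("-inf"))
    running_sum = torch.tensor(0.0)
    for start in range(0, output_weight.size(0), block_size):
        block_logits = output_weight[start : start + block_size] @ hidden_row
        new_max = torch.maximum(running_max, block_logits.max())
        # Rescale the previously accumulated sum to the new maximum, then add this block.
        running_sum = running_sum * torch.exp(running_max - new_max) + torch.exp(block_logits - new_max).sum()
        running_max = new_max
    # Log-sum-exp over the whole vocabulary; the CE loss for this token is
    # logsumexp - logit[target], so the full logits row never has to be stored.
    return running_max + torch.log(running_sum)
```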


The most interesting part to me is that the torch_compile impl version is actually not that far off.
It's just as fast, but consumes a bit more memory (though still less than CCE).
And it's just regular PyTorch code: a single function that does embeddings @ output_layer plus the CE calculation, all wrapped in the torch.compile decorator. No custom Triton kernels.
I would imagine that torch.compile can fuse these operations together and use the online softmax normalization trick for the softmax, which is already implemented in SDPA anyway, so there's a good chance of that.

That said, it could be the better option, since it requires no additional dependencies and no custom Triton kernels.
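Conceptually, the torch_compile impl is not much more than something like this (a rough sketch under my assumptions; the actual code in the cut_cross_entropy repo handles chunking, ignore indices, and dtypes more carefully):

```python
import torch
import torch.nn.functional as F

@torch.compile
def fused_linear_ce(hidden: torch.Tensor, output_weight: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Two operations in one compiled function: the output projection and the loss.
    # The hope is that torch.compile fuses the matmul with the log-softmax so the
    # full logits tensor doesn't have to round-trip through global memory.
    logits = hidden @ output_weight.T
    return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
```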


I don't know if, when they compared against torchtune in their paper, they used compile=True.

I initially thought it might be unfair to compare a regular, non-compiled model with LCE. To address this, I included a compiled version in the comparison as well.
For consistency, I also decided to compile the LCE variant. I'm planning to add a non-compiled LCE version for completeness, although I expect the loss function to be compiled automatically anyway (only the model itself will stay non-compiled).

If you have bandwidth, could you try Llama 8B for a few steps, CCE+compile vs cut-loss+compile?

I'll try to.

@Andrei-Aksionov
Author

Hey @felipemello1

I've rerun the tests without compilation for both implementations of LCE (for completeness).
(Forgot to mention that it's all done on a single L4.)

| Loss Name | Peak Memory Reserved (GB) | Time | Compiled |
|---|---|---|---|
| CCE | 21.6 | 10:15 | False |
| CCE | 12.3 | 09:18 | True |
| LCE | 8.2 | 05:53 | False |
| LCE | 7.73 | 04:51 | True |
| LCE (torch_compile impl) | 8.7 | 05:40 | False |
| LCE (torch_compile impl) | 9.1 | 04:43 | True |

As you can see, even without compilation it's still ~2x faster than non-compiled CCE.


I wasn't able to run Llama 3.1 8B, since I don't have access to the repo. I tried to request access a couple of times and now I'm blocked foreeeeeeveeer.
So I tried Qwen 2.5 7B instead. It has a larger vocab size than Llama (152k vs 128k) and 1B fewer parameters, but at least it's something somewhat similar (kinda 🙃).

| Loss Name | Peak Memory Reserved (GB) | Time |
|---|---|---|
| CCE | 17.5 | 05:14 |
| LCE | 16.2 | 04:40 |
| LCE (torch_compile impl) | 16.9 | 04:26 |
(Screenshot: loss curves for the Qwen 2.5 7B runs.)

Here the difference is less pronounced, since the vocab size no longer dwarfs everything else and thus doesn't have such a significant impact.

Again, what's interesting is the torch_compile impl:

  1. It's as fast as LCE implemented with custom Triton kernels.
  2. It provides some memory savings, though smaller than the proper Triton variant (due to not supporting gradient filtering).

With larger and larger models the memory savings will become negligible, but the improvement in training speed is a nice thing.
I recommend spending some time researching the impact of a single function that, apparently, makes it easier to fuse the last-layer matmul and the loss calculation. Again, look at the code for the torch_compile impl: it's just a function with two operations in it. The benefit is that you could replace F.cross_entropy with any of the custom loss functions from this repo.
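As a rough illustration of that last point (a hypothetical helper, not existing torchtune code), the loss could simply be passed in as a callable:

```python
import torch
import torch.nn.functional as F

@torch.compile
def fused_output_and_loss(hidden, output_weight, targets, loss_fn=F.cross_entropy):
    # Same two operations as in the torch_compile impl, but with the loss pluggable,
    # so any of torchtune's custom losses could be swapped in for F.cross_entropy.
    logits = hidden @ output_weight.T
    return loss_fn(logits.view(-1, logits.size(-1)), targets.view(-1))
```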

What do you think, @felipemello1?
And @ebsmothers, of course 😊

@felipemello1
Contributor

hey @Andrei-Aksionov, this is amazing. Thanks for all the insightful experiments and comments.

I think that, as a rule of thumb, we would be way more comfortable adding a torch-only implementation to torchtune than taking a dependency on a new repo. I don't know if their license allows copying it, but Horace implemented an early version of a PyTorch-only CCE and shared it on Twitter. Maybe we could use that?

We are a bit swamped right now with some other tasks, but I don't want to keep you waiting. Can I get back to you next week after I meet with the team and check if/how we want to add this loss?

The way you did it works, but a) it adds some extra logic to the recipe (which is already a bit bloated), and b) it modifies the model.
What I originally thought about doing was to add a "skip_output_layer" flag to the transformer, but this is also bad, as it adds yet another flag to the model.

So let me check what others think and we can get back to it. Does that sound like a plan?

@felipemello1
Contributor

Also, food for thought: many other losses in torchtune ended up implementing the chunked version (e.g. knowledge distillation and GRPO). Although the fused version is more efficient, I wonder how easy it would be for someone to try to repurpose it. Maybe we would need to keep the current chunked version around.

@Andrei-Aksionov
Author

What I originally thought about doing was to add a "skip_output_layer" flag to the transformer, but this is also bad, as it adds yet another flag to the model.

That was my first idea, but I decided not to go that route since it's specific to the training regime, so it wasn't clear to me why the model architecture should care about it.


So let me check what others think and we can get back to it. Does that sound like a plan?

Yes, it does. 🙂
Evan mentioned a couple of things that I can take a look at, so I can keep myself busy.
