Linear Cross Entropy #2507
Conversation
There are a couple of things that prevent this from moving out of the draft stage.
hey @Andrei-Aksionov, thanks for the PR! Impressive numbers. Do you have any intuition for why tokens per second would double? From their paper, torch.compile is actually the fastest. I don't know if, in their paper, when they compared vs torchtune, they used compile=True. But, in any case, I wouldn't imagine that changing the output_layer + CE would have such a TPS impact (but my intuition here may be wrong; since it's a 2B model + LoRA, maybe computing the output_layer for a large vocab is the most expensive thing). https://arxiv.org/pdf/2411.09009 If you have bandwidth, could you try Llama 8B for a few steps, CCE+compile vs cutloss+compile?
Hello @felipemello1
Well, I guess that's because it's Gemma 2 with only 2B parameters, but the vocab size is the same as for the larger models of the Gemma 2 family, so applying the classification head and calculating the loss has a significant effect on training speed. Plus it's a LoRA variant, which makes it even more pronounced. As I understand it, in the case of regular cross-entropy (chunked or not), the logits matrix is first materialized in global memory (by multiplying the embeddings with the output layer), and only then is the loss computed on it.
As you can see, that's a lot of copying operations compared to the number of floating point operations. LCE, on the other hand, computes the loss on small slices directly in fast shared memory, without ever materializing the full logits matrix.
As you can see, the proportion of copies to useful compute is much better. The most interesting part to me is the torch_compile implementation: it could be a better option, since no additional dependencies are required and no custom Triton kernels.
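To make the copy-vs-compute point concrete, here is a rough illustrative sketch (not the actual torchtune code; shapes are only roughly Gemma-2-2B-like) of the regular path, where the full logits matrix is materialized before the loss is computed:

```python
import torch
import torch.nn.functional as F

# Illustrative sketch of the "regular" cross-entropy path.
B, T, hidden_dim, vocab_size = 2, 1024, 2304, 256_000

hidden = torch.randn(B, T, hidden_dim, dtype=torch.bfloat16)
output_weight = torch.randn(vocab_size, hidden_dim, dtype=torch.bfloat16)
labels = torch.randint(0, vocab_size, (B, T))

logits = hidden @ output_weight.T   # (B, T, vocab_size): ~1 GB even in bf16
logits = logits.float()             # often upcast for the loss: another full-size copy
loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))
```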
I initially thought it might be unfair to compare a regular, non-compiled model with LCE. To address this, I included a compiled version in the comparison as well.
I'll try to.
Hey @felipemello1, I've rerun the experiments without compilation for both implementations of LCE (for completeness).
As you can see, even without compilation, it's still ~2x faster than non-compiled CCE. I wasn't able to run Llama 3.1 8B, since I don't have access to the repo. I tried to request access a couple of times and now I'm blocked foreeeeeeveeer.
Here the difference is less pronounced, since the vocab size no longer dwarfs everything else, so it doesn't have such a significant impact. Again, what's interesting is the torch_compile implementation.
With larger and larger models the memory savings will become negligible, but the improvement in training speed is a nice thing. What do you think, @felipemello1?
hey @Andrei-Aksionov, this is amazing. Thanks for all the insightful experiments and comments. I think that, as a rule of thumb, we would be way more comfortable adding a torch-only implementation to torchtune than taking a dependency on a new repo. I don't know if their license allows copying it, but Horace implemented an early version of a PyTorch-only CCE and shared it on Twitter. Maybe we could use that? We are a bit swamped now with some other tasks, but I don't want to keep you waiting. Can I get back to you next week after I meet with the team and check if/how we want to add this loss? The way you did it works, but a) it adds some extra logic to the recipe (which is already a bit bloated), b) it modifies the model. So let me check what others think and we can get back to it. Does that sound like a plan?
Also, food for thought: many other losses in torchtune ended up implementing the chunked version (e.g. knowledge distillation and GRPO). Although the fused version is more efficient, I wonder how easy it would be for someone to try to repurpose it. Maybe we would need to keep the current chunked version around.
That was my first idea, but I decided not to go this route since it's training-regime specific, so it wasn't clear to me why the architecture of the model should care about it.
Yes, it does. 🙂
Hi there 👋
This PR adds support for Linear Cross Entropy loss (a.k.a. Cut Cross Entropy), proposed by Apple in the paper "Cut Your Losses in Large-Vocabulary Language Models", by integrating the official implementation.
So, from now on, we have:
- `CCE`: chunked cross entropy
- `LCE`: linear cross entropy

The benefit of LCE is that it reduces VRAM usage, especially for models with large vocab sizes.
This is done with 3 things:
Custom Kernels.
LCE uses custom CUDA kernels to perform matrix multiplications and the log-sum-exp reduction in shared memory.
Usually you first need to calculate logits by multiplying the embeddings (B, T, hidden_dim) with the output layer (hidden_dim, vocab_size) and then calculate the loss from these logits and the target labels.
LCE doesn't materialize this logits matrix in global memory (it can be huge for modern LLMs with large vocab sizes); instead, it multiplies only a part of the output layer weights at a time with the embeddings in shared (fast) memory and immediately calculates the loss there (the FlashAttention playbook).
This avoids materializing the logits matrix (saves memory) and saves memory bandwidth (no need to move data between global and shared memory multiple times).
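For intuition only, here is a rough pure-PyTorch sketch of the same math (chunking over the vocab dimension instead of fusing everything into one kernel); the actual implementation does this block-wise inside a custom kernel and never writes the logits out at all:

```python
import torch

def lce_sketch(hidden, weight, targets, chunk_size=8192):
    # hidden: (N, hidden_dim), weight: (vocab_size, hidden_dim), targets: (N,)
    # Cross entropy per token: logsumexp_v(h_i @ w_v) - h_i @ w_{target_i},
    # computed without ever materializing the full (N, vocab_size) logits.
    N, V = hidden.shape[0], weight.shape[0]
    lse = torch.full((N,), float("-inf"), device=hidden.device)
    for start in range(0, V, chunk_size):
        w_chunk = weight[start:start + chunk_size]        # (C, hidden_dim)
        chunk_logits = (hidden @ w_chunk.T).float()       # only an (N, C) slice at a time
        lse = torch.logaddexp(lse, chunk_logits.logsumexp(dim=-1))
    target_logits = (hidden * weight[targets]).sum(dim=-1).float()  # (N,)
    return (lse - target_logits).mean()
```

Presumably the torch_compile implementation follows a similar chunked pattern and relies on `torch.compile` to fuse the matmul with the log-sum-exp reduction.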
Gradient filtering.
LCE leverages the sparsity of the softmax gradient to skip elements of the gradient computation that have a negligible contribution. This improves the throughput of LCE by reducing unnecessary computation.
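A toy illustration (not the actual filtering code) of why this helps: with a large vocabulary, almost every softmax probability, and therefore almost every entry of the softmax gradient, is vanishingly small:

```python
import torch

torch.manual_seed(0)
vocab_size = 256_000
logits = torch.randn(vocab_size) * 4.0   # stand-in logits for a single position
probs = logits.softmax(dim=-1)
# For non-target tokens the gradient entry is just `probs`, so almost all of it
# falls below a small threshold and can be skipped.
print((probs < 2 ** -12).float().mean())  # close to 1.0
```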
Vocabulary sorting.
LCE uses vocabulary sorting to group tokens with similar average logits together. This increases the block-level sparsity and improves the effectiveness of gradient filtering.
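A hypothetical sketch of the mechanics (names and numbers made up): reorder the output-layer rows by each token's average logit so that "likely" tokens cluster into the same blocks, letting whole blocks be skipped by the filter above:

```python
import torch

vocab_size, hidden_dim = 256_000, 2304
weight = torch.randn(vocab_size, hidden_dim)
avg_logits = torch.randn(vocab_size)   # stand-in for per-token average logits over a batch

order = torch.argsort(avg_logits)      # permutation grouping tokens with similar logits
weight_sorted = weight[order]          # reorder the output-layer rows

# Targets must be remapped with the inverse permutation so the loss is unchanged:
inverse = torch.empty_like(order)
inverse[order] = torch.arange(vocab_size)
# sorted_targets = inverse[targets]
```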
I ran a quick single-device LoRA recipe with the `gemma 2 2b` model. This is the best possible case for LCE, since this model is relatively small, but has the same vocab size as the larger models of the Gemma 2 family.
I ran 4 experiments:
- `CCE (non-compiled)`: chunked cross-entropy without compilation of model and loss
- `CCE`: chunked cross-entropy with compilation
- `LCE`: linear cross entropy with compilation
- `LCE (torch_compile impl)`: an implementation for older GPUs and MPS devices

Note: to do this, one needs to change the loss section of a config (see the sketch below).
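A hypothetical sketch of what that change amounts to; the component path below is made up for illustration, and the real one depends on this PR:

```python
from omegaconf import OmegaConf
from torchtune import config

# Equivalent of editing the `loss:` section of a YAML config
# (`torchtune.modules.loss.LinearCrossEntropyLoss` is a placeholder name).
cfg = OmegaConf.create(
    {"loss": {"_component_": "torchtune.modules.loss.LinearCrossEntropyLoss"}}
)
loss_fn = config.instantiate(cfg.loss)
```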
As one can see, the loss chart is identical in all cases, yet LCE significantly reduced the reserved memory size and the training time.
Context
What is the purpose of this PR? Is it to
- [ ] add a new feature
- [ ] fix a bug
- [ ] update tests and/or documentation
- [ ] other (please add here)

Please link to any issues this PR addresses.
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- [ ] run pre-commit hooks and linters (make sure you've first installed them via `pre-commit install`)
- [ ] run unit tests via `pytest tests`
- [ ] run recipe tests via `pytest tests -m integration_test`
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.