
Refactor losses instantiation and chunked CE #2531


Merged
25 commits merged into pytorch:main from felipemello1:loss_refactor on Apr 30, 2025

Conversation

@felipemello1 (Contributor) commented on Mar 27, 2025

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

IMPORTANT: Recipes do NOT work with the older version of ChunkedCrossEntropy anymore, because we no longer expect the transformer to chunk the outputs.

Problem:

  1. We have seen many chunked losses being added to torchtune. The current setup puts the chunking burden on the model.
  2. Users are interested in using losses that take model.output.weight as an input, e.g. Liger losses.

Solution:

  1. Enable the recipe to call loss(weight, input, targets)
  2. Reimplement ChunkedCE so that chunking and projection happen in the loss.
  3. Add a protocol so that new losses can follow the same pattern (see the sketch below).
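
For readers skimming the PR, a minimal sketch of the call pattern this enables is shown below. The names (get_output_weight, the linear_loss flag) come from hunks quoted later in this thread; the body is illustrative, not the PR's exact implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearCrossEntropySketch(nn.Module):
    """Chunking and the final projection live inside the loss, not the model."""

    linear_loss: bool = True  # recipes check this flag to decide whether to pass the weight

    def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):
        super().__init__()
        self.num_output_chunks = num_output_chunks
        self.ignore_index = ignore_index

    def forward(
        self, weight: torch.Tensor, outputs: torch.Tensor, targets: torch.Tensor
    ) -> torch.Tensor:
        # outputs: (b, s, hidden) hidden states; weight: (vocab, hidden) output projection
        total_loss = 0.0
        total_elements = (targets != self.ignore_index).sum()
        hidden_chunks = outputs.tensor_split(self.num_output_chunks, dim=1)
        target_chunks = targets.tensor_split(self.num_output_chunks, dim=1)
        for hidden, target in zip(hidden_chunks, target_chunks):
            logits = F.linear(hidden, weight)  # projection happens per chunk, inside the loss
            total_loss = total_loss + F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),
                target.reshape(-1),
                ignore_index=self.ignore_index,
                reduction="sum",
            )
        return total_loss / total_elements


# Recipe side (simplified): the model skips its output layer, the recipe fetches
# the output projection weight and hands it to the loss.
# weight = model.get_output_weight()
# loss = loss_fn(weight, hidden_states, labels)
```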

PROFILING: https://drive.google.com/drive/folders/1jHOCuOF74F9lmmJv7wxbcK-i_wtB2stf?usp=sharing

Changelog

  • Updated full_distributed and lora_distributed
  • Tested with lora llama 3.2 distributed (TiedLinear)
  • Implemented new ChunkedCE

TODO: once approved, roll this out to the other recipes/losses and update the configs.

Test

ChunkedCrossEntropyLoss

tune run --nproc_per_node 2 lora_finetune_distributed --config /data/users/felipemello/torchtune/recipes/configs/llama3_2/1B_lora.yaml \
metric_logger=torchtune.training.metric_logging.WandBLogger \
dataset.packed=True \
dataset.split=train[:50%] \
tokenizer.max_seq_len=4096 \
gradient_accumulation_steps=1 \
batch_size=4 \
max_steps_per_epoch=20 \
compile=True \
use_output_weight_in_loss=True \
loss=torchtune.modules.loss.sft_losses.ChunkedCrossEntropyLoss


To reproduce

Fork https://github.com/pytorch/torchtune, then:
git clone https://github.com/<YOUR_GITHUB_USER>/torchtune.git

cd torchtune
conda create -n torchtune python=3.11
conda activate torchtune
pip install --pre --upgrade torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu124
pip install -e ".[dev]"
pre-commit install

git remote add felipemello1 https://github.com/felipemello1/torchtune.git
git checkout -b loss_refactor felipemello1/loss_refactor

pytorch-bot commented on Mar 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2531

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f4b5fc1 with merge base 1be43b6:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Mar 27, 2025
@felipemello1 mentioned this pull request on Mar 27, 2025
@felipemello1 changed the title from "Refactor losses installation and chunked CE" to "Refactor losses instantiation and chunked CE" on Mar 27, 2025
@felipemello1 marked this pull request as draft on March 31, 2025
@felipemello1 marked this pull request as ready for review on March 31, 2025

return total_loss / total_elements


class ChunkedCrossEntropywithAutograd(torch.autograd.Function):

Member: Why did you want to add these Autograd versions? How does this help you test?

Contributor (Author): This version is based on Horace's code from a few months back. In this implementation, the chunks are not held in memory. He wrote it to show that you don't need Triton.

I don't want to keep it in torchtune, because it would be hard to use for KD/RL losses. This is more of a reference for the compile folks. They are working on enabling chunking in compile to match the autograd memory perf.

Contributor (Author): Without autograd: [profiler screenshot]

Contributor (Author): With autograd: [profiler screenshot]

Member: Can you put a comment in the code to that effect?

@SalmanMohammadi (Contributor): @felipemello1 this is awesome. Out of curiosity, did you happen to benchmark against the existing CEWithChunkedOutputLoss?

I wonder if we could simplify the configuration further by removing the need for the user to also specify use_output_weight_in_loss? Could we define a

class BaseLoss(Protocol):
    is_chunked: bool

and do

-                if self.use_output_weight_in_loss:
+                if self.loss_fn.is_chunked:
                    weight = self._model.get_output_weight()
                    current_loss = self._loss_fn(weight, outputs, labels)
                else:
                    labels = labels.reshape(-1)
                    logits = logits.reshape(-1, logits.size(-1))
                    outputs = outputs.reshape(-1, outputs.size(-1))
                    current_loss = self._loss_fn(outputs, labels)

It would require either 1) requiring that all losses use this protocol (which tbh I wouldn't be opposed to as we start to support more custom losses without needing to modify recipes), or 2) doing a hasattr check on self._loss_fn and relying on an identifying field on just the chunked losses.

wdyt?

@felipemello1 (Contributor, Author) commented on Apr 4, 2025

I wonder if we could simplify the configuration further by removing the need for the user to also specify use_output_weight_in_loss?

@SalmanMohammadi, I thought about it and even implemented it, but then realized that it would be hard to support 3rd-party libraries unless we create some sort of loss adapter, which we may need to do anyway, because not all libraries follow the pattern (weight, input, label). They may follow (label, weight, input), for example.

the loss adapter could be something like:

config.yaml

loss:
  _component_: torchtune.loss.lossadapter
  loss: path.to.loss
  requires_weight_input: True
  input_order: ["label", "weight", "input"]
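
A rough Python sketch of what that adapter could look like (hypothetical; LossAdapter and its argument names are illustrative, not part of torchtune):

```python
from typing import List

import torch.nn as nn


class LossAdapter(nn.Module):
    """Wraps a third-party loss so recipes can always call loss(weight, input, label)."""

    def __init__(self, loss: nn.Module, requires_weight_input: bool, input_order: List[str]):
        super().__init__()
        self.loss = loss
        self.linear_loss = requires_weight_input  # flag the recipe checks
        self.input_order = input_order            # e.g. ["label", "weight", "input"]

    def forward(self, weight, input, label):
        named = {"weight": weight, "input": input, "label": label}
        # Reorder arguments into whatever convention the wrapped loss expects
        return self.loss(*(named[name] for name in self.input_order))
```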

# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
# But this way we dont need to slice the logits. We just add an ignore index to labels.
labels = torch.hstack(
    (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
)
if not isinstance(logits, list):

if self.use_output_weight_in_loss:

Contributor: very nice
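
For reference, a toy check of the shift trick quoted above (illustrative; an ignore_index of -100 is assumed, matching F.cross_entropy's default):

```python
import torch
import torch.nn.functional as F

ignore_index = -100
logits = torch.randn(2, 5, 11)          # (batch, seq, vocab)
labels = torch.randint(0, 11, (2, 5))

# Pad the left-shifted labels with ignore_index instead of slicing the logits
shifted = torch.hstack((labels[..., 1:], torch.full((2, 1), ignore_index)))
a = F.cross_entropy(logits.reshape(-1, 11), shifted.reshape(-1), ignore_index=ignore_index)

# Classic formulation: slice both logits and labels
b = F.cross_entropy(logits[:, :-1, :].reshape(-1, 11), labels[:, 1:].reshape(-1))

assert torch.allclose(a, b)  # same loss, no logits slicing needed
```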

# set num_output_chunks for model
self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
# The loss may handle the output projection. If true, the model should skip it.
self.use_output_weight_in_loss = getattr(

Contributor: tangential point: if the contract is that SFT losses follow the protocols defined in loss_protocols, do we need to make this check?

Contributor (Author): Someone may try to use a loss that is not from torchtune, e.g. vanilla F.cross_entropy.
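
The getattr default quoted above is what keeps that case working; roughly (the flag name comes from the diff, the helper is illustrative):

```python
import torch.nn as nn


def uses_output_weight(loss_fn: nn.Module) -> bool:
    """Mirrors the recipe's check: losses without the flag fall back to the logits path."""
    return getattr(loss_fn, "linear_loss", False)


assert uses_output_weight(nn.CrossEntropyLoss()) is False  # vanilla loss keeps working
```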



class SFTLoss(Protocol):
"""Protocol for loss functions in torchtune used in sft recipes."""

Contributor suggested change:
-    """Protocol for loss functions in torchtune used in sft recipes."""
+    """Protocol for loss functions in torchtune used in SFT recipes."""

Contributor (Author): I don't know if I like "SFT" here, since it may not be obvious to a new reader what it means.

@Andrei-Aksionov: Well, I am a new reader, and if I see all capitalized letters I immediately think that it's an abbreviation, not just a word. Actually, it's an initialism, as I've just learned.



class SFTLossWithProjection(Protocol):
"""Protocol for loss functions in torchtune used in Supervised Finetune recipes and that require

Contributor suggested change:
-    """Protocol for loss functions in torchtune used in Supervised Finetune recipes and that require
+    """Protocol for loss functions in torchtune used in SFT recipes and that require

Contributor (Author): I prefer spelling it out; I don't know if I like "SFT" here, since it may not be obvious to a new reader what it means.

@SalmanMohammadi (Contributor) left a comment: real nice

@pbontrager (Contributor) left a comment: Thanks for this big effort. This looks good and I'm happy to approve it now. Please finish going through and resolving the open comments before landing.

Comment on lines 347 to 349
# skip final projection, since the loss takes hidden input instead of logits
self.skip_unembedding = cfg.get("loss_takes_embeddings", False)
self._model.set_skip_unembedding(self.skip_unembedding)

Contributor: nit: skip_output_layer

import torch


class SFTLossWithProjection(Protocol):

Contributor: I agree that this name is confusing. I think we should just standardize on "fused", "linear", or "chunked". All the names have issues, which we've discussed, but if we're consistent, at least people should be able to learn the term quickly.

target_chunks[idx],
)

return total_loss / total_elements

Contributor: nit: it'd be nice to offer the same 'reduction' option as most PyTorch losses to control returning the mean, sum, or no reduction.

@@ -301,9 +301,12 @@ def setup(self, cfg: DictConfig) -> None:
if self._compile:

Contributor: What's the plan for rolling this out to the other SFT recipes?

Contributor (Author):

  1. Recipes NOT being updated should still work with configs NOT being updated
  2. Recipes being updated should NOT work anymore with old ce_with_chunked_outputs_loss
  3. So any recipe that is changed also requires the configs to be updated with the new loss

TODO: need to check if the deprecation warnings work fine. This can be checked by running a recipe/config that has not been updated.

@@ -396,6 +400,7 @@ def __init__(
self.head_dim = head_dim
self.causal_mask = None
self.num_output_chunks = 0
self._skip_output_projection = False

Contributor: You should enforce in __init__ that the output module has the "weight" property.
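
A minimal version of that check could look like this (sketch only; where exactly it lands in __init__ is up to the PR):

```python
import torch.nn as nn


def validate_output_module(output: nn.Module) -> None:
    # Fail fast at construction time rather than on the first loss call
    if not hasattr(output, "weight"):
        raise ValueError(
            "Expected the output module to expose a `weight` attribute so the "
            "loss can perform the final projection."
        )
```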

Comment on lines +360 to +363
elif getattr(self._loss_fn, "linear_loss", False):
raise ValueError(
"Linear losses are not supported yet for KD. Please use the deprecated CEWithChunkedOutputLoss."
)

Contributor: Wrong error message? Also, can we open a high-priority issue for this? I don't like being in a state where half our configs are on a deprecated API.

Comment on lines +37 to +41
msg = (
"'CEWithChunkedOutputLoss' is deprecated and will be removed in future versions. "
"Please use `torchtune.modules.loss.LinearCrossEntropyLoss` instead."
)
log_once(logger=logger, msg=msg, level=logging.WARNING)

Contributor: nit: isn't this what the deprecated decorator is for?

@@ -42,6 +52,13 @@ def compute_cross_entropy(
logits.float(), labels, ignore_index=self.ignore_index, reduction="sum"
)

def apply_compile_strategy(self, *args, **kwargs):
"""Applies compile only to the fkl_loss function."""

Contributor: this isn't fkl?

@Andrei-Aksionov: No, it's not :)

@@ -96,6 +100,12 @@ def __init__(
def set_num_output_chunks(self, num_output_chunks: int) -> None:
"""Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`.
This should be called before the first forward pass, in the recipe."""
msg = (

Contributor: Same here, can we use the deprecated decorator?

"""
# Accessing the weight directly will not trigger FSDP hooks
# to gather the full tensor so we have to unshard manually
if isinstance(self.output, FSDPModule):

Contributor: So this is not relevant for early fusion or deep fusion?

# to gather the full tensor so we have to unshard manually
if isinstance(self.output, FSDPModule):
self.output.unshard()
weight = self.output.weight.clone()

Contributor: So the fix was the removal of detach? Also, are there any memory implications of the clone here?
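
For context, the full pattern under discussion looks roughly like this (sketch based on the quoted hunk; the reshard() call and the import path are assumptions, and on older torch releases FSDPModule lives under torch.distributed._composable.fsdp):

```python
import torch
from torch.distributed.fsdp import FSDPModule  # FSDP2


def get_output_weight(output: torch.nn.Module) -> torch.Tensor:
    """Gather the (possibly sharded) output projection weight for the loss."""
    if isinstance(output, FSDPModule):
        output.unshard()                # all-gather the full parameter
        weight = output.weight.clone()  # clone so the copy survives a later reshard
        output.reshard()                # assumption: free the unsharded storage again
        return weight
    return output.weight
```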

"""Protocol for loss functions in torchtune used in Supervised Finetune recipes that require
model output linear projection weights in loss computation."""

linear_loss: bool = True

Contributor: Can we leave a comment in these classes explaining what this field means? Also, an example usage in the docstrings of both would help a lot imo.
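
One possible shape for that comment plus a docstring example (illustrative wording; the final text in the PR may differ):

```python
from typing import Protocol

import torch


class SFTLinearLoss(Protocol):
    """Protocol for SFT losses that perform the output projection themselves.

    Example:
        >>> loss_fn = LinearCrossEntropyLoss(num_output_chunks=8)
        >>> weight = model.get_output_weight()  # (vocab_size, hidden_dim)
        >>> loss = loss_fn(weight, hidden_states, labels)
    """

    # Recipes check this flag to decide whether to skip the model's output layer
    # and pass hidden states plus the projection weight to the loss instead.
    linear_loss: bool = True

    def forward(
        self, weight: torch.Tensor, outputs: torch.Tensor, targets: torch.Tensor
    ) -> torch.Tensor: ...
```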

class LinearCrossEntropyLoss(nn.Module, SFTLinearLoss):
"""Memory efficient Cross-entropy loss that incrementally computes loss for chunks of tokens
by masking ignored tokens, calculating logits and then applying cross-entropy loss. Combines
the linear projection with the cross-entropy calculation for futher memory savings.

Contributor: nit
Suggested change:
-    the linear projection with the cross-entropy calculation for futher memory savings.
+    the linear projection with the cross-entropy calculation for further memory savings.

Comment on lines +31 to +32
mask_pre_projection (bool): Whether to mask the output tensor before projection, avoiding
computing it for tokens that will be ignored during CE anyway. Default is True.

Contributor: Maybe I'm out of the loop here, but why do we need to expose this? This doesn't seem like an intuitive parameter to me; is there a reason someone would want to modify it?

@Andrei-Aksionov: I'm also curious.

@ebsmothers (Contributor) left a comment: Just a few more comments, but I'm really happy with how this turned out. This addresses a longstanding problem of inflexibility in our losses with a clear UX, and opens us up to other more memory-efficient CE implementations.

@pbontrager merged commit 9c06c8b into pytorch:main on Apr 30, 2025 (14 checks passed).
Darktex pushed a commit to Darktex/torchtune that referenced this pull request Apr 30, 2025
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: salman <[email protected]>
Co-authored-by: Philip Bontrager <[email protected]>
Co-authored-by: joecummings <[email protected]>

if not isinstance(logits, list):
if self.linear_loss:

@Andrei-Aksionov: I have a question: what if we encapsulate the logic specific to linear cross-entropy in the LinearCrossEntropy class itself? When we check whether it's a linear CE or not, we can assign self._loss_fn.output_weights = ... and then just inject it during the forward pass, without needing to provide it explicitly. This way we don't need custom logic for loss calculation, so it could be unified for all losses.

Maybe it can somehow affect compilation? 🤷
Or is there a plan to have a functional version of this loss in the future?
Or is it just plain dumb? 😆
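
A rough sketch of that alternative (hypothetical class, not what the PR implements); the recipe would set loss_fn.output_weight = model.get_output_weight() before computing the loss:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearCrossEntropyWithInjectedWeight(nn.Module):
    """Weight is injected once by the recipe; the call signature stays (outputs, targets)."""

    def __init__(self, ignore_index: int = -100):
        super().__init__()
        self.ignore_index = ignore_index
        self.output_weight = None  # set by the recipe before the forward pass

    def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        logits = F.linear(outputs, self.output_weight)
        return F.cross_entropy(
            logits.reshape(-1, logits.size(-1)).float(),
            targets.reshape(-1),
            ignore_index=self.ignore_index,
        )
```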


# Compute loss
# Loss is normalized by default so we multiply by the number of tokens

@Andrei-Aksionov: Nit: is it indeed normalized? To me it's more like "aggregated". Normalization is a different thing, no?

@@ -63,7 +63,7 @@ optimizer:
lr: 2e-5
optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
_component_: torchtune.modules.loss.LinearCrossEntropyLoss

@Andrei-Aksionov: Idea: turn on compilation for LinearCrossEntropyLoss by default. Without compilation there shouldn't be any benefits like online softmax or computing logits and loss simultaneously, so it won't be, em, linear at all 🤓


total_elements = mask.sum()

# Chunk along sequence dimension
hidden_chunks = outputs.tensor_split(self.num_output_chunks, dim=1)

@Andrei-Aksionov: Out of curiosity: does it make a difference in peak memory consumption in the case of linear CE?

If compilation works as I anticipate for the LinearCrossEntropyLoss class, where logits calculation and loss calculation live in the same method, it can produce basically the same kernels as the custom Cut Cross-Entropy kernels; if that's true, then chunking is no longer needed.

Maybe it only matters when someone uses LinearCrossEntropyLoss without compilation 🤔




@Andrei-Aksionov: Hello @felipemello1, sorry for the late review of this PR. Just left a couple of comments. But overall, great job!

@felipemello1 deleted the loss_refactor branch on May 5, 2025, 18:49.