Refactor losses instantiation and chunked CE #2531
Conversation
🔗 Helpful links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2531
Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit f4b5fc1 with merge base 1be43b6. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
torchtune/modules/loss/sft_losses.py
Outdated
        return total_loss / total_elements


class ChunkedCrossEntropywithAutograd(torch.autograd.Function):
Why did you want to add these Autograd versions? How does this help you test?
This version is based off Horace's code from a few months back. In this implementation, the chunks are not held in memory. He coded it to show that you don't need Triton.
I don't want to keep it in torchtune, because it would be hard to use for KD/RL losses. This is more a reference for the compile folks; they are working on enabling the chunking in compile to match the autograd memory perf.
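For context, here is a minimal sketch of that general pattern, not the PR's actual `ChunkedCrossEntropywithAutograd` (the class name, shapes, and the `(weight, hidden, labels)` argument order below are assumptions): per-chunk gradients are computed during the forward pass, so the full `[tokens, vocab]` logits tensor is never materialized and no chunk activations are kept alive for backward.

```python
import torch
import torch.nn.functional as F


class ChunkedLinearCEFunction(torch.autograd.Function):
    """Sketch: fuse output projection + CE per chunk, keeping only hidden-sized grads."""

    @staticmethod
    def forward(ctx, weight, hidden, labels, num_chunks=8, ignore_index=-100):
        # weight: [vocab, dim], hidden: [num_tokens, dim], labels: [num_tokens]
        grad_hidden = torch.zeros_like(hidden)
        grad_weight = torch.zeros_like(weight)
        total_loss = torch.zeros((), device=hidden.device, dtype=torch.float32)
        total_elements = (labels != ignore_index).sum()

        w = weight.detach().requires_grad_(True)
        for h_chunk, l_chunk, gh_chunk in zip(
            hidden.chunk(num_chunks),
            labels.chunk(num_chunks),
            grad_hidden.chunk(num_chunks),
        ):
            h = h_chunk.detach().requires_grad_(True)
            with torch.enable_grad():
                logits = F.linear(h, w)  # [chunk_tokens, vocab]; freed right after grad()
                loss = F.cross_entropy(
                    logits.float(), l_chunk, ignore_index=ignore_index, reduction="sum"
                )
            gh, gw = torch.autograd.grad(loss, (h, w))
            gh_chunk.copy_(gh)        # write this chunk's grad in place
            grad_weight += gw
            total_loss += loss.detach()

        ctx.save_for_backward(grad_weight, grad_hidden, total_elements)
        return total_loss / total_elements

    @staticmethod
    def backward(ctx, grad_output):
        grad_weight, grad_hidden, total_elements = ctx.saved_tensors
        scale = grad_output / total_elements
        # One grad per forward input: weight, hidden, labels, num_chunks, ignore_index
        return grad_weight * scale, grad_hidden * scale, None, None, None
```

Since `autograd.Function.apply` only takes positional arguments, usage would look like `ChunkedLinearCEFunction.apply(weight, hidden, labels, 8, -100)`.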
Can you put a comment in the code to that effect?
@felipemello1 this is awesome. Out of curiosity, did you happen to benchmark against the existing implementation?

I wonder if we could simplify the configuration further by removing the need for the user to also specify it in the config, e.g.

    class BaseLoss(Protocol):
        is_chunked: bool

and do

    - if self.use_output_weight_in_loss:
    + if self.loss_fn.is_chunked:
          weight = self._model.get_output_weight()
          current_loss = self._loss_fn(weight, outputs, labels)
      else:
          labels = labels.reshape(-1)
          logits = logits.reshape(-1, logits.size(-1))
          outputs = outputs.reshape(-1, outputs.size(-1))
          current_loss = self._loss_fn(outputs, labels)

It would require either 1) requiring that all losses use this protocol (which tbh I wouldn't be opposed to as we start to support more custom losses without needing to modify recipes), or 2) doing a getattr-style check. wdyt?
@SalmanMohammadi, I thought about it and even implemented it, but then realized that it would be hard to support 3rd-party libraries unless we create some sort of loss adapter, which we may need to do anyway, because not all libraries follow the pattern (weight, input, label). They may follow (label, weight, input), for example. The loss adapter could be configured from the yaml, something like the sketch below.
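The concrete config example was trimmed from the comment above; purely as an illustration, a hypothetical adapter could look like this (`LossAdapter`, `arg_order`, and the yaml entry are made-up names, not torchtune APIs):

```python
import torch.nn as nn


class LossAdapter(nn.Module):
    """Hypothetical adapter mapping torchtune's (weight, outputs, labels) call
    onto a third-party loss that expects a different argument order."""

    def __init__(self, loss_fn: nn.Module, arg_order=("weight", "input", "label")):
        super().__init__()
        self.loss_fn = loss_fn
        self.arg_order = arg_order

    def forward(self, weight, outputs, labels):
        named = {"weight": weight, "input": outputs, "label": labels}
        return self.loss_fn(*(named[name] for name in self.arg_order))


# Hypothetical config.yaml entry:
# loss:
#   _component_: torchtune.modules.loss.LossAdapter
#   loss_fn:
#     _component_: some_library.SomeLinearLoss
#   arg_order: ["label", "weight", "input"]
```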
recipes/full_finetune_distributed.py
Outdated
        # Shift labels to compute loss
        # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
        # But this way we dont need to slice the logits. We just add an ignore index to labels.
        labels = torch.hstack(
            (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
        )
        if not isinstance(logits, list):

        if self.use_output_weight_in_loss:
very nice
recipes/full_finetune_distributed.py
Outdated
        # set num_output_chunks for model
        self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
        # The loss may handle the output projection. If true, the model should skip it.
        self.use_output_weight_in_loss = getattr(
tangential point: if the contract is that SFT losses follow the protocols defined in loss_protocols, do we need to make this check?
someone may try to use a loss that is not from torchtune, e.g. vanilla F.cross_entropy
class SFTLoss(Protocol):
    """Protocol for loss functions in torchtune used in sft recipes."""
"""Protocol for loss functions in torchtune used in sft recipes.""" | |
"""Protocol for loss functions in torchtune used in SFT recipes.""" |
I don't know if I like "SFT" here, since it may not be obvious to a new reader what it means
Well, I am a new reader, and if I see all capitalized letters I immediately think that it's an abbreviation, not just a word. Actually, it's an initialism, as I've just learned.
class SFTLossWithProjection(Protocol):
    """Protocol for loss functions in torchtune used in Supervised Finetune recipes and that require
"""Protocol for loss functions in torchtune used in Supervised Finetune recipes and that require | |
"""Protocol for loss functions in torchtune used in SFT recipes and that require |
I prefer "SFTI dont know if i like "SFT" here, since it may not be obvious for a new reader what it means
real nice
Thanks for this big effort. This looks good and I'm happy to approve it now. Please finish going through and resolving the open comments before landing.
recipes/full_finetune_distributed.py
Outdated
        # skip final projection, since the loss takes hidden input instead of logits
        self.skip_unembedding = cfg.get("loss_takes_embeddings", False)
        self._model.set_skip_unembedding(self.skip_unembedding)
nit: skip_output_layer
import torch


class SFTLossWithProjection(Protocol):
I agree that this name is confusing. I think we should just standardize on "fused", "linear", or "chunked". All the names have issues, which we've discussed, but if we're consistent at least people should be able to learn the term quickly.
                target_chunks[idx],
            )

        return total_loss / total_elements
nit: it'd be nice to offer the same 'reduction' option as most pytorch losses to control returning the mean, sum, or no reduction
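A small sketch of what that switch could look like on top of the summed chunk losses (not part of the PR; the function and argument names are illustrative):

```python
import torch


def reduce_loss(
    total_loss: torch.Tensor, total_elements: torch.Tensor, reduction: str = "mean"
) -> torch.Tensor:
    """Sketch of a torch-style `reduction` switch over the summed chunk losses."""
    if reduction == "sum":
        return total_loss
    if reduction == "mean":
        return total_loss / total_elements
    # "none" would require keeping per-token losses instead of a running sum
    raise ValueError(f"Unsupported reduction: {reduction!r}")
```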
@@ -301,9 +301,12 @@ def setup(self, cfg: DictConfig) -> None:
        if self._compile:
What's the plan for rolling this out to the other sft recipes?
- Recipes NOT being updated should still work with configs NOT being updated
- Recipes being updated should NOT work anymore with old ce_with_chunked_outputs_loss
- So any recipe that is changed also requires the configs to be updated with the new loss
TODO: need to check if the deprecation warnings work fine. This can be checked by running a recipe/config that has not been updated.
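One way to check that without running a full recipe is a small unit test along these lines (a sketch only, assuming the warning fires when the deprecated loss is constructed and that torchtune's logger propagates to the root logger):

```python
import logging


def test_chunked_ce_warns_on_deprecation(caplog):
    """Sketch: the deprecated loss should log the deprecation warning when instantiated."""
    from torchtune.modules.loss import CEWithChunkedOutputLoss

    with caplog.at_level(logging.WARNING):
        CEWithChunkedOutputLoss()
    assert any("deprecated" in rec.getMessage() for rec in caplog.records)
```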
@@ -396,6 +400,7 @@ def __init__(
        self.head_dim = head_dim
        self.causal_mask = None
        self.num_output_chunks = 0
        self._skip_output_projection = False
You should enforce in init that the output module has the "weight" property
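A minimal sketch of such a check (assuming it would live in the decoder's __init__; the helper name below is made up):

```python
import torch.nn as nn


def _validate_output_module(output: nn.Module) -> None:
    """Sketch: fail fast at construction time instead of at the first loss step."""
    if not hasattr(output, "weight"):
        raise AttributeError(
            f"{type(output).__name__} has no 'weight' attribute; linear losses "
            "need direct access to the output projection weight."
        )
```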
        elif getattr(self._loss_fn, "linear_loss", False):
            raise ValueError(
                "Linear losses are not supported yet for KD. Please use the deprecated CEWithChunkedOutputLoss."
            )
Wrong error message? Also can we open a high-priority issue for this? I don't like being in a state where half our configs are on a deprecated API
        msg = (
            "'CEWithChunkedOutputLoss' is deprecated and will be removed in future versions. "
            "Please use `torchtune.modules.loss.LinearCrossEntropyLoss` instead."
        )
        log_once(logger=logger, msg=msg, level=logging.WARNING)
nit: isn't this what the deprecated decorator is for?
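If a reusable decorator isn't already available in torchtune, a generic one is small enough to sketch (this is not torchtune's helper, just an illustration of the idea):

```python
import functools
import warnings


def deprecated(msg: str):
    """Generic sketch: warn when a deprecated class is instantiated."""

    def decorator(cls):
        original_init = cls.__init__

        @functools.wraps(original_init)
        def new_init(self, *args, **kwargs):
            warnings.warn(msg, FutureWarning, stacklevel=2)
            original_init(self, *args, **kwargs)

        cls.__init__ = new_init
        return cls

    return decorator


# Usage (illustrative):
# @deprecated("CEWithChunkedOutputLoss is deprecated; use LinearCrossEntropyLoss.")
# class CEWithChunkedOutputLoss(nn.Module): ...
```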
@@ -42,6 +52,13 @@ def compute_cross_entropy(
            logits.float(), labels, ignore_index=self.ignore_index, reduction="sum"
        )

    def apply_compile_strategy(self, *args, **kwargs):
        """Applies compile only to the fkl_loss function."""
this isn't fkl?
No, it's not :)
@@ -96,6 +100,12 @@ def __init__(
    def set_num_output_chunks(self, num_output_chunks: int) -> None:
        """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`.
        This should be called before the first forward pass, in the recipe."""
        msg = (
same here, can we use deprecated decorator?
""" | ||
# Accessing the weight directly will not trigger FSDP hooks | ||
# to gather the full tensor so we have to unshard manually | ||
if isinstance(self.output, FSDPModule): |
So this is not relevant for early fusion or deep fusion?
        # to gather the full tensor so we have to unshard manually
        if isinstance(self.output, FSDPModule):
            self.output.unshard()
        weight = self.output.weight.clone()
so the fix was the removal of detach? also any memory implications of the clone here?
"""Protocol for loss functions in torchtune used in Supervised Finetune recipes that require | ||
model output linear projection weights in loss computation.""" | ||
|
||
linear_loss: bool = True |
Can we leave a comment in these classes explaining what this field means? Also an example usage in the docstrings of both would help a lot imo
class LinearCrossEntropyLoss(nn.Module, SFTLinearLoss):
    """Memory efficient Cross-entropy loss that incrementally computes loss for chunks of tokens
    by masking ignored tokens, calculating logits and then applying cross-entropy loss. Combines
    the linear projection with the cross-entropy calculation for futher memory savings.
nit
-    the linear projection with the cross-entropy calculation for futher memory savings.
+    the linear projection with the cross-entropy calculation for further memory savings.
        mask_pre_projection (bool): Whether to mask the output tensor before projection, avoiding
            computing it for tokens that will be ignored during CE anyway. Default is True.
Maybe I'm out of the loop here, but why do we need to expose this? This doesn't seem like an intuitive parameter to me, is there a reason someone would want to modify it?
I'm also curious.
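For context on what the flag does, a rough sketch of the idea (not the PR's exact code; the function name and shapes are assumptions): with masking enabled, hidden states at ignored positions are dropped before the matmul with the output weight, so neither the projection nor the CE is computed for them.

```python
import torch
import torch.nn.functional as F


def linear_ce(weight, hidden, labels, ignore_index=-100, mask_pre_projection=True):
    # weight: [vocab, dim], hidden: [tokens, dim], labels: [tokens]
    if mask_pre_projection:
        keep = labels != ignore_index
        hidden, labels = hidden[keep], labels[keep]  # project only tokens that count
    logits = F.linear(hidden, weight)                # [kept_tokens, vocab]
    return F.cross_entropy(
        logits.float(), labels, ignore_index=ignore_index, reduction="sum"
    )
```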
Just a few more comments, but really happy with how this turned out. This addresses a longstanding problem of inflexibility in our losses with a clear UX and opens us up to other more memory-efficient CE implementations.
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: salman <[email protected]>
Co-authored-by: Philip Bontrager <[email protected]>
Co-authored-by: joecummings <[email protected]>
        if not isinstance(logits, list):

        if self.linear_loss:
I have a question: what if we encapsulate the logic specific to linear cross entropy in the LinearCrossEntropy class itself? When we check whether it's a linear CE or not, we can assign self._loss_fn.output_weights=... and then just inject it during the forward pass, without needing to provide it explicitly. This way we don't need custom logic for loss calculation, so it could be unified for all losses.
Maybe it can somehow affect compilation? 🤷
Or is there a plan to have a functional version of this loss in the future?
Or is it just plain dumb? 😆
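A self-contained sketch of that idea (the class and attribute names below are hypothetical, not part of this PR): the recipe assigns a weight provider once, and the loss keeps the plain (outputs, labels) signature used by every other loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearCEWithInjectedWeight(nn.Module):
    """Sketch: the recipe injects a weight provider; forward stays (outputs, labels)."""

    def __init__(self, ignore_index: int = -100):
        super().__init__()
        self.ignore_index = ignore_index
        self.get_output_weight = None  # recipe sets this, e.g. model.get_output_weight

    def forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        weight = self.get_output_weight()  # fetched lazily each step (FSDP unshard, etc.)
        logits = F.linear(outputs, weight)
        return F.cross_entropy(
            logits.reshape(-1, logits.size(-1)).float(),
            labels.reshape(-1),
            ignore_index=self.ignore_index,
        )


# In the recipe's setup (sketch): loss_fn.get_output_weight = self._model.get_output_weight
```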
        # Compute loss
        # Loss is normalized by default so we multiply by the number of tokens
Nit: is it indeed normalized?
For me, it's more like "aggregated".
Normalization is a different thing, no?
@@ -63,7 +63,7 @@ optimizer:
   lr: 2e-5
 optimizer_in_bwd: True  # True saves memory. Requires gradient_accumulation_steps=1
 loss:
-  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
+  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
Idea: turn on compilation for LinearCrossEntropyLoss by default. Without compilation there shouldn't be any benefits like online softmax, simultaneous logits and loss calculation, etc. So it won't be, em, linear at all 🤓
        total_elements = mask.sum()

        # Chunk along sequence dimension
        hidden_chunks = outputs.tensor_split(self.num_output_chunks, dim=1)
Out of curiosity: does it make a difference in peak memory consumption in the case of linear CE? If compilation works the way I anticipate for the LinearCrossEntropyLoss class, where logits calculation and loss calculation are within the same method, it can produce basically the same kernels as the custom Cut Cross Entropy kernels; if that's true, then chunking is no longer needed. Maybe it only matters when someone uses LinearCrossEntropyLoss without compilation 🤔
Hello @felipemello1
Context
What is the purpose of this PR? Is it to
IMPORTANT: Recipes do NOT work with the older version of ChunkedCrossEntropy anymore, because we don't expect the transformer to chunk the outputs.
Problem:
Solution:
PROFILING: https://drive.google.com/drive/folders/1jHOCuOF74F9lmmJv7wxbcK-i_wtB2stf?usp=sharing
Changelog
TODO: once approved, will roll this out to the other recipes/losses and update the configs
Test
ChunkedCrossEntropyLoss
To reproduce
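The reproduction details are collapsed above; purely as an illustration (not the author's actual script), a micro-benchmark along these lines can compare peak memory of the new loss against plain CE (the sizes below are arbitrary, and the LinearCrossEntropyLoss call is left out to avoid guessing its exact signature):

```python
import torch
import torch.nn.functional as F


def peak_mem_plain_ce(hidden, labels, weight):
    """Baseline: materialize the full logits tensor, then cross-entropy."""
    torch.cuda.reset_peak_memory_stats()
    logits = F.linear(hidden, weight)
    loss = F.cross_entropy(logits.float().view(-1, logits.size(-1)), labels.view(-1))
    loss.backward()
    return torch.cuda.max_memory_allocated() / 1e9


if __name__ == "__main__":
    vocab, dim, bsz, seq = 128_256, 4096, 2, 2048  # illustrative sizes
    weight = torch.randn(vocab, dim, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    hidden = torch.randn(bsz, seq, dim, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    labels = torch.randint(0, vocab, (bsz, seq), device="cuda")
    print(f"plain CE peak: {peak_mem_plain_ce(hidden, labels, weight):.2f} GB")
    # Swap in torchtune.modules.loss.LinearCrossEntropyLoss here to compare,
    # following the call signature defined in this PR.
```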