Linear Cross Entropy #2507


Draft · wants to merge 1 commit into base: main
2 changes: 1 addition & 1 deletion docs/source/install.rst
@@ -66,7 +66,7 @@ you can also install the package locally with the following command.
    pip install -e .

    # or for a developer installation
-   pip install -e .["dev"]
+   pip install -e ".[dev]"

|

1 change: 1 addition & 0 deletions pyproject.toml
@@ -61,6 +61,7 @@ dev = [
    "urllib3<2.0.0",
    "wandb",
    "expecttest",
+   "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git",
]

[tool.setuptools.dynamic]
40 changes: 38 additions & 2 deletions recipes/lora_finetune_single_device.py
@@ -297,9 +297,38 @@ def setup(self, cfg: DictConfig) -> None:
         if self._compile:
             self._loss_fn = training.compile_loss(self._loss_fn)

-        if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
+        # save the name of the loss class to reuse later
+        self.loss_class_name = (
+            self._loss_fn._orig_mod.__class__.__name__
+            if hasattr(self._loss_fn, "_orig_mod")
+            else self._loss_fn.__class__.__name__
+        )
Contributor:

Noob question: what's the issue with always using self._loss_fn.__class__.__name__?

from omegaconf import OmegaConf

from torchtune import config

# Build configs for the two loss components
short_dict = {
    "loss1": {
        "_component_": "torchtune.modules.loss.CEWithChunkedOutputLoss",
    },
    "loss2": {
        "_component_": "cut_cross_entropy.LinearCrossEntropy",
    },
}

# Convert the dictionary to an OmegaConf DictConfig object
cfg = OmegaConf.create(short_dict)
loss1 = config.instantiate(cfg.loss1)
loss2 = config.instantiate(cfg.loss2)

print(loss1.__class__.__name__)
print(loss2.__class__.__name__)

The above toy example prints

CEWithChunkedOutputLoss
LinearCrossEntropy

Author:

Hello, @nathan-az
Unfortunately, I didn't fully understand the question 🙃.
Are you asking why we need to know the name of the loss function in the first place?
If so, we need to know it because different loss functions require specific changes to the model.
If not, a clarification would be nice 😊.

Contributor:

I guess the question is why we need `_orig_mod`. Is that right? I have no clue.

Btw, we could have a nicer abstraction for this. Maybe have the losses follow a protocol, and we check something like: if hasattr(loss, "module_to_compile"), compile(loss.module_to_compile), else compile(loss).
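
For illustration, a rough sketch of that check (this is not the actual training.compile_loss; `module_to_compile` is a hypothetical attribute a loss could expose):

import torch
import torch.nn as nn


def compile_loss(loss: nn.Module) -> nn.Module:
    # If the loss exposes an inner module to compile, compile only that;
    # otherwise compile the whole loss as before.
    if hasattr(loss, "module_to_compile"):
        loss.module_to_compile = torch.compile(loss.module_to_compile)
        return loss
    return torch.compile(loss)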

Author:

_orig_mod is needed because if I compile the loss with torch.compile, the class gets wrapped in OptimizedModule (for LinearCrossEntropy), so I need to go one level deeper to retrieve the proper name.
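
A standalone toy example (not code from this PR) showing that wrapping:

import torch
import torch.nn as nn

loss = nn.CrossEntropyLoss()
compiled = torch.compile(loss)

print(compiled.__class__.__name__)            # OptimizedModule
print(compiled._orig_mod.__class__.__name__)  # CrossEntropyLoss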

There are many ways to make this PR nicer.
The whole logic could be wrapped in a new loss class, and then only something like the following would be added to the training recipe:

if self.loss_class_name == "LinearCrossEntropy":
    self._loss_fn.prepare_model(model)

and the rest would stay the same, including the _loss_step function.
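
A rough sketch of what such a wrapper could look like (class and method names are only a proposal, and only a plain nn.Linear output layer is handled here):

import torch
import torch.nn as nn
from cut_cross_entropy import LinearCrossEntropy


class LCEWrapperLoss(nn.Module):
    """Wraps cut_cross_entropy.LinearCrossEntropy and keeps a handle to the output weights."""

    def __init__(self):
        super().__init__()
        self.lce = LinearCrossEntropy()
        self._output_weights = None

    def prepare_model(self, model: nn.Module) -> None:
        # The fused kernel does the output projection itself,
        # so keep the weights and turn the output layer into a no-op.
        self._output_weights = model.output.weight
        model.output.forward = lambda x: x

    def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        embeddings = embeddings.type_as(self._output_weights)
        return self.lce(embeddings, self._output_weights, labels)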

Or, if the core team decides to keep a fused variant of the loss (like in the torch.compile implementation), then we could have a FusedLoss class that wraps any loss function and does the trick in its forward call; the yaml file would get a fused argument:

loss:
    _component_: ...
    fused: True
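
A very rough sketch of that FusedLoss idea (all names and the exact call signature are hypothetical):

import torch.nn as nn


class FusedLoss(nn.Module):
    # Wraps an arbitrary loss; `fused=True` switches to the fused call that
    # takes embeddings plus the output weights instead of materialized logits.
    def __init__(self, loss_fn: nn.Module, fused: bool = False):
        super().__init__()
        self.loss_fn = loss_fn
        self.fused = fused

    def forward(self, logits, labels, output_weights=None):
        if self.fused:
            # e.g. cut_cross_entropy.LinearCrossEntropy(embeddings, weights, labels)
            return self.loss_fn(logits.type_as(output_weights), output_weights, labels)
        return self.loss_fn(logits, labels)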

I just need to know what the decision is.

Contributor:

Thanks for clarifying, @felipemello1 and @Andrei-Aksionov - that answers my question.

Agreed that a nicer abstraction is desirable here. The manual handling of the output layer, logits, and losses is an unfortunate side effect too. I wonder if there's a nicer pattern for handling that.


+        if self.loss_class_name == "CEWithChunkedOutputLoss":
             # set num_output_chunks for model
             self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
+        elif self.loss_class_name == "LinearCrossEntropy":
+            # LCE doesn't call the forward method of the output layer, but rather multiplies embeddings with the output weights.
+            # This is why PEFT methods (which use additional weights with custom logic) are not supported with LCE.
+            if isinstance(
+                self._model.output, (modules.peft.LoRALinear, modules.peft.DoRALinear)
+            ):
+                raise ValueError(
+                    "PEFT (LoRA/DoRA) applied to output is not supported with LinearCrossEntropy."
+                )
+
+            # LCE multiplies the embeddings with the output weights by itself in an efficient way.
+            # Thus we need to disable the forward method of the output layer.
+            if isinstance(self._model.output, nn.Linear):
+                self._model.output.forward = lambda x: x
+                self._output_weights = self._model.output.weight
+            elif isinstance(self._model.output, modules.TiedLinear):
+                self._model.output.linear.forward = lambda x, _: x
+                self._output_weights = self._model.output.tied_module.weight
+            else:
+                raise ValueError(
+                    f"`{type(self._model.output)}` for output layer is not supported with LinearCrossEntropy."
+                )

         log.info("Loss is initialized.")

@@ -653,7 +682,14 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
         labels = labels.reshape(-1)
         logits = logits.reshape(-1, logits.size(-1))

-        loss = self._loss_fn(logits, labels)
+        if self.loss_class_name == "LinearCrossEntropy":
+            # LCE doesn't materialize the logits in global memory; instead it multiplies the embeddings with the output weights
+            # in shared memory and computes the loss there. This is done to save memory.
+            # This is why we need to provide the embeddings (named logits here), the output weights, and the labels.
+            logits = logits.type_as(self._output_weights)
+            loss = self._loss_fn(logits, self._output_weights, labels)
+        else:
+            loss = self._loss_fn(logits, labels)

         # free logits otherwise it peaks backward memory
         del logits