Merged
26 commits
14736b8  compile optimizer (IvanKobzarev, Apr 22, 2025)
87fe9f7  scale_grads with foreach + compile (IvanKobzarev, Apr 22, 2025)
742ff81  Update base for Update on "scale_grads with foreach + compile" (IvanKobzarev, Apr 22, 2025)
7a6fb12  Update on "scale_grads with foreach + compile" (IvanKobzarev, Apr 22, 2025)
1dd7eb2  Update base for Update on "scale_grads with foreach + compile" (IvanKobzarev, Apr 22, 2025)
d0d42b8  Update on "scale_grads with foreach + compile" (IvanKobzarev, Apr 22, 2025)
0523776  Update base for Update on "scale_grads with foreach + compile" (IvanKobzarev, Apr 22, 2025)
93c05ce  Update on "scale_grads with foreach + compile" (IvanKobzarev, Apr 22, 2025)
eaa7a90  Update base for Update on "scale_grads with foreach + compile" (IvanKobzarev, Apr 28, 2025)
cfab1c3  Update on "scale_grads with foreach + compile" (IvanKobzarev, Apr 28, 2025)
a32d5c9  Update base for Update on "scale_grads with foreach + compile" (IvanKobzarev, Apr 28, 2025)
43934ab  Update on "scale_grads with foreach + compile" (IvanKobzarev, Apr 28, 2025)
da6ca90  Update base for Update on "scale_grads with foreach + compile" (IvanKobzarev, Apr 28, 2025)
63a8dca  Update on "scale_grads with foreach + compile" (IvanKobzarev, Apr 28, 2025)
69f9e16  Update base for Update on "scale_grads with foreach + compile" (IvanKobzarev, Apr 28, 2025)
8925f87  Update on "scale_grads with foreach + compile" (IvanKobzarev, Apr 28, 2025)
6242893  Update base for Update on "scale_grads with foreach + compile" (IvanKobzarev, Apr 28, 2025)
b0a279c  Update on "scale_grads with foreach + compile" (IvanKobzarev, Apr 28, 2025)
bd7584b  Update base for Update on "scale_grads with foreach + compile" (IvanKobzarev, Apr 28, 2025)
9e1ab29  Update on "scale_grads with foreach + compile" (IvanKobzarev, Apr 28, 2025)
452cd7d  Update base for Update on "scale_grads with foreach + compile" (IvanKobzarev, May 2, 2025)
a7aeb65  Update on "scale_grads with foreach + compile" (IvanKobzarev, May 2, 2025)
06e55ee  Update base for Update on "scale_grads with foreach + compile" (IvanKobzarev, May 2, 2025)
ae2e965  Update on "scale_grads with foreach + compile" (IvanKobzarev, May 2, 2025)
35020b2  Update base for Update on "scale_grads with foreach + compile" (IvanKobzarev, May 6, 2025)
45d150e  Update on "scale_grads with foreach + compile" (IvanKobzarev, May 6, 2025)
1 change: 1 addition & 0 deletions recipes/configs/llama4/scout_17B_16E_full.yaml
@@ -77,6 +77,7 @@ compile: False
# model: True
# loss: True
# optimizer_step: False
# scale_grads: True

# Reduced precision
dtype: bf16
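
For readers following along: the recipe's setup() (next file) accepts `compile` either as a single bool or as a per-component mapping like the commented-out keys above. Below is a minimal, hedged sketch of how the dict form resolves, built with OmegaConf directly rather than read from this YAML file; the key names and defaults mirror the diff, while everything else is illustrative.

from omegaconf import DictConfig, OmegaConf

# Illustration only (not part of the diff): dict-form `compile` config,
# mirroring the commented-out keys in the YAML above.
compile_cfg = OmegaConf.create(
    {"model": True, "loss": True, "optimizer_step": False, "scale_grads": True}
)

if isinstance(compile_cfg, DictConfig):
    # Per-component resolution, with the same defaults the recipe uses below.
    compile_model = compile_cfg.get("model", True)
    compile_loss = compile_cfg.get("loss", True)
    compile_optimizer_step = compile_cfg.get("optimizer_step", False)
    compile_scale_grads = compile_cfg.get("scale_grads", True)
else:
    # Bool form: a single switch covers every component.
    compile_model = compile_loss = compile_optimizer_step = compile_scale_grads = bool(compile_cfg)

print(compile_model, compile_loss, compile_optimizer_step, compile_scale_grads)  # True True False True
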
15 changes: 14 additions & 1 deletion recipes/full_finetune_distributed.py
@@ -313,10 +313,19 @@ def setup(self, cfg: DictConfig) -> None:
self._compile_model = compile_bool
self._compile_loss = compile_bool
self._compile_optimizer_step = compile_bool
self._compile_scale_grads = compile_bool
if isinstance(compile, DictConfig):
self._compile_model = compile.get("model", True)
self._compile_loss = compile.get("loss", True)
self._compile_optimizer_step = compile.get("optimizer_step", False)
self._compile_scale_grads = compile.get("scale_grads", True)
Contributor:
Nit, but I actually wonder whether it's worth having a separate flag for this. I think model, loss, and optimizer step are all pretty clear, but this one may be a bit niche for someone to look at and immediately understand (and it also seems unlikely that someone would actually want to experiment with it). Could even just use some heuristic like self._compile_scale_grads = all([self._compile_model, self._compile_loss, self._compile_optimizer_step]).

# This indirection is needed to apply torch.compile to scale_grads step.
self._grad_scaler = training.scale_grads_
Member:
Can you leave a comment here explaining that we need this indirection for things to work with PT2 compile for some reason?

Contributor Author:
Sure, will add a comment.

if self._compile_scale_grads:
self._grad_scaler = torch.compile(
self._grad_scaler, backend=self._compile_backend
)

self._model = self._setup_model(
cfg_model=cfg.model,
@@ -932,8 +941,12 @@ def train(self) -> None:
torch.distributed.all_reduce(num_tokens)
# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_loss)

# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, self.dp_degree / num_tokens)
self._grad_scaler(
self._model.parameters(), self.dp_degree / num_tokens
)

if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
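
Outside the recipe, the scaling step above can be exercised on its own. The loss is accumulated without normalization and num_tokens is all-reduced (summed) across ranks, while the data-parallel gradient reduction averages over dp_degree ranks, hence the dp_degree / num_tokens factor. A hedged, single-process sketch follows, assuming a torchtune build that already includes scale_grads_ from this PR; dp_degree and num_tokens are stand-ins for the recipe's values.

import torch
from torch import nn
from torchtune import training

# Illustration only (not part of the diff).
model = nn.Linear(8, 8)
# Pretend these gradients came from an unnormalized, sum-reduced loss.
model(torch.randn(4, 8)).sum().backward()

dp_degree = 1                    # single process in this sketch
num_tokens = torch.tensor(32.0)  # total tokens across ranks for this step

grad_scaler = training.scale_grads_
# The recipe optionally wraps this in torch.compile when compile.scale_grads is set, e.g.:
# grad_scaler = torch.compile(grad_scaler, backend="inductor")

grad_scaler(model.parameters(), dp_degree / num_tokens)
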
3 changes: 2 additions & 1 deletion torchtune/training/__init__.py
@@ -25,7 +25,7 @@
shard_model,
validate_no_params_on_meta_device,
)
from torchtune.training._grad_scaler import scale_grads
from torchtune.training._grad_scaler import scale_grads, scale_grads_
Member:
I think this needs to be added to __all__.

from torchtune.training._model_util import disable_dropout
from torchtune.training._profiler import (
DEFAULT_PROFILE_DIR,
@@ -139,6 +139,7 @@
"OffloadActivations",
"FormattedCheckpointFiles",
"scale_grads",
"scale_grads_",
"get_distributed_backend",
"disable_dropout",
"DATALOADER_KEY",
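
A quick way to confirm the re-export above, hedged on an installed torchtune that already contains this change: both the symbol and its __all__ entry should be present.

import torchtune.training as training
from torchtune.training import scale_grads_

# Illustration only (not part of the diff).
assert "scale_grads_" in training.__all__
assert training.scale_grads_ is scale_grads_
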
76 changes: 75 additions & 1 deletion torchtune/training/_grad_scaler.py
@@ -4,10 +4,17 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from typing import Optional

import torch
from torch import nn
from torch import nn, Tensor
from torch.nn.utils.clip_grad import _no_grad, _tensor_or_tensors
from torch.utils._foreach_utils import _device_has_foreach_support, _has_foreach_support
from torchtune.utils._logging import deprecated


@deprecated(msg="Please use `scale_grads_` instead.")
def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
Member:
@ebsmothers Should we add a deprecation warning here then? No need to keep both.

Member:
Talked with Evan, let's add the deprecation label here.

"""
Utility to scale the gradients of a model.
@@ -29,3 +36,70 @@ def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
scaler = scaler.to(device)
if p.grad is not None:
p.grad *= scaler


@_no_grad
def scale_grads_(
parameters: _tensor_or_tensors,
scaler: torch.Tensor,
foreach: Optional[bool] = None,
) -> None:
r"""Scale gradients of iterable parameters.

This function is equivalent to applying :func:`torch.Tensor.mul_` to each parameter's gradient.
Gradients are modified in-place, multiplied by the specified scaler.

Args:
parameters (_tensor_or_tensors): an iterable of Tensors or a
single Tensor that will have gradients scaled
scaler (torch.Tensor): multiplier to scale gradients
foreach (Optional[bool]): use the faster foreach-based implementation.
If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
fall back to the slow implementation for other device types.
Default: ``None``
Returns:
None
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
else:
parameters = list(parameters)
_scale_grad_(parameters, scaler, foreach)


def _group_tensors_by_device(
tensors: list[torch.Tensor],
) -> dict[torch.device, list[Tensor]]:
ret = defaultdict(list)
for i, tensor in enumerate(tensors):
ret[tensor.device].append(tensor)

return ret


@_no_grad
def _scale_grad_(
parameters: _tensor_or_tensors,
scaler: torch.Tensor,
foreach: Optional[bool] = None,
) -> None:
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
grads = [p.grad for p in parameters if p.grad is not None]
if len(grads) == 0:
return
grouped_grads = _group_tensors_by_device(grads)

for device, device_grads in grouped_grads.items():
if (foreach is None and _has_foreach_support(device_grads, device)) or (
foreach and _device_has_foreach_support(device)
):
torch._foreach_mul_(device_grads, scaler.to(device))
elif foreach:
raise RuntimeError(
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
)
else:
scaler_device = scaler.to(device)
for g in device_grads:
g.mul_(scaler_device)
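
Finally, a small sanity check for the new helper (again a hedged sketch assuming a build that includes this PR): on CPU, the grouped foreach path and the per-tensor fallback should scale gradients identically.

import torch
from torch import nn
from torchtune.training import scale_grads_

# Illustration only (not part of the diff).
def make_model_with_grads(seed: int = 0) -> nn.Module:
    torch.manual_seed(seed)
    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
    model(torch.randn(3, 4)).sum().backward()
    return model

m_foreach = make_model_with_grads()
m_slow = make_model_with_grads()

scale = torch.tensor(0.5)
scale_grads_(m_foreach.parameters(), scale, foreach=True)   # grouped torch._foreach_mul_
scale_grads_(m_slow.parameters(), scale, foreach=False)     # per-tensor g.mul_(scale)

for a, b in zip(m_foreach.parameters(), m_slow.parameters()):
    torch.testing.assert_close(a.grad, b.grad)
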