Commit 5a9828b

llama4 distributed: scale_grads with foreach
ghstack-source-id: 961b75c
Pull Request resolved: #2624

1 parent: 6301e01

3 files changed (+80, -3 lines)

recipes/full_finetune_distributed.py (5 additions, 1 deletion)

@@ -926,8 +926,12 @@ def train(self) -> None:
                     torch.distributed.all_reduce(num_tokens)
                     # This will ensure that the logged loss matches what we're optimizing
                     torch.distributed.all_reduce(running_loss)
+
                     # Manually scale the gradients from unnormalized loss by total # of tokens
-                    training.scale_grads(self._model, self.dp_degree / num_tokens)
+                    training.scale_grads_(
+                        self._model.parameters(), self.dp_degree / num_tokens
+                    )
+
                     if self._clip_grad_norm is not None:
                         grad_norm = torch.nn.utils.clip_grad_norm_(
                             self._model.parameters(),
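
For context, this hunk swaps the module-based `training.scale_grads(self._model, ...)` call for the new parameter-iterator form `training.scale_grads_(self._model.parameters(), ...)`. The surrounding normalization pattern (scaling the gradients of an unnormalized, summed loss by `dp_degree / num_tokens` after the token count has been all-reduced) can be sketched outside the recipe roughly as follows, using a toy model and made-up values in place of the recipe's distributed state:

```python
import torch
from torch import nn

# Toy stand-ins (hypothetical values); in the recipe, num_tokens is
# all-reduced across data-parallel ranks before this step.
model = nn.Linear(4, 4)
dp_degree = 2
num_tokens = torch.tensor(128.0)

# Backward on a summed (unnormalized) loss.
loss = model(torch.randn(8, 4)).sum()
loss.backward()

# Scale every gradient in place by dp_degree / num_tokens, mirroring
# training.scale_grads_(self._model.parameters(), self.dp_degree / num_tokens).
scaler = dp_degree / num_tokens
for p in model.parameters():
    if p.grad is not None:
        p.grad.mul_(scaler)
```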

torchtune/training/__init__.py (2 additions, 1 deletion)

@@ -25,7 +25,7 @@
     shard_model,
     validate_no_params_on_meta_device,
 )
-from torchtune.training._grad_scaler import scale_grads
+from torchtune.training._grad_scaler import scale_grads, scale_grads_
 from torchtune.training._model_util import disable_dropout
 from torchtune.training._profiler import (
     DEFAULT_PROFILE_DIR,
@@ -139,6 +139,7 @@
     "OffloadActivations",
     "FormattedCheckpointFiles",
     "scale_grads",
+    "scale_grads_",
     "get_distributed_backend",
     "disable_dropout",
     "DATALOADER_KEY",

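With the re-export above, the in-place variant is available from the top-level training namespace alongside the existing helper. A minimal import check (assuming a torchtune build that includes this commit):

```python
# Assumes torchtune is installed with this change applied.
from torchtune.training import scale_grads, scale_grads_
```
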
torchtune/training/_grad_scaler.py (73 additions, 1 deletion)

@@ -4,8 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from collections import defaultdict
+from typing import Optional
+
 import torch
-from torch import nn
+from torch import nn, Tensor
+from torch.nn.utils.clip_grad import _no_grad, _tensor_or_tensors
+from torch.utils._foreach_utils import _device_has_foreach_support, _has_foreach_support
 
 
 def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
@@ -29,3 +34,70 @@ def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
             scaler = scaler.to(device)
         if p.grad is not None:
             p.grad *= scaler
+
+
+@_no_grad
+def scale_grads_(
+    parameters: _tensor_or_tensors,
+    scaler: torch.Tensor,
+    foreach: Optional[bool] = None,
+) -> None:
+    r"""Scale gradients of iterable parameters.
+
+    This function is equivalent to :func:`torch.mul_` applied to each parameter.
+    Gradients are modified in-place, multiplying by specified scaler.
+
+    Args:
+        parameters (_tensor_or_tensors): an iterable of Tensors or a
+            single Tensor that will have gradients scaled
+        scaler (torch.Tensor): multiplier to scale gradients
+        foreach (Optional[bool]): use the faster foreach-based implementation.
+            If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
+            fall back to the slow implementation for other device types.
+            Default: ``None``
+    Returns:
+        None
+    """
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    else:
+        parameters = list(parameters)
+    _scale_grad_(parameters, scaler, foreach)
+
+
+def _group_tensors_by_device_and_dtype(
+    tensors: list[torch.Tensor],
+) -> dict[tuple[torch.device, torch.dtype], list[Tensor]]:
+    ret = defaultdict(list)
+    for tensor in tensors:
+        ret[(tensor.device, tensor.dtype)].append(tensor)
+
+    return ret
+
+
+@_no_grad
+def _scale_grad_(
+    parameters: _tensor_or_tensors,
+    scaler: torch.Tensor,
+    foreach: Optional[bool] = None,
+) -> None:
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    grads = [p.grad for p in parameters if p.grad is not None]
+    if len(grads) == 0:
+        return
+    grouped_grads = _group_tensors_by_device_and_dtype(grads)
+
+    for (device, _), device_grads in grouped_grads.items():
+        if (foreach is None and _has_foreach_support(device_grads, device)) or (
+            foreach and _device_has_foreach_support(device)
+        ):
+            torch._foreach_mul_(device_grads, scaler.to(device))
+        elif foreach:
+            raise RuntimeError(
+                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+            )
+        else:
+            scaler_device = scaler.to(device)
+            for g in device_grads:
+                g.mul_(scaler_device)
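
As a usage sketch of the new helper: `scale_grads_` accepts a single tensor or any iterable of parameters plus a tensor scaler, groups gradients by device and dtype, and applies `torch._foreach_mul_` per group when foreach support is available (falling back to per-tensor `mul_` otherwise). A minimal smoke test, assuming torchtune is installed with this commit:

```python
import torch
from torch import nn

from torchtune.training import scale_grads_

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
loss = model(torch.randn(4, 8)).sum()
loss.backward()

# Snapshot gradients, then scale them in place by 0.5.
before = [p.grad.clone() for p in model.parameters()]
scale_grads_(model.parameters(), torch.tensor(0.5))

after = [p.grad for p in model.parameters()]
assert all(torch.allclose(a, 0.5 * b) for a, b in zip(after, before))
```

The calling convention resembles `torch.nn.utils.clip_grad_norm_` (a parameters iterable plus an optional `foreach` flag), which is also where the `_no_grad` decorator and `_tensor_or_tensors` alias are imported from in the diff above.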
