
Commit cab831f

scale_grads with foreach + compile

ghstack-source-id: 081a1a9
Pull Request resolved: #2624
1 parent: a1c77bc

4 files changed, +95 -3 lines changed
recipes/full_finetune_distributed.py

Lines changed: 12 additions & 1 deletion
@@ -913,8 +913,19 @@ def train(self) -> None:
                     torch.distributed.all_reduce(num_tokens)
                     # This will ensure that the logged loss matches what we're optimizing
                     torch.distributed.all_reduce(running_loss)
+
                     # Manually scale the gradients from unnormalized loss by total # of tokens
-                    training.scale_grads(self._model, self.dp_degree / num_tokens)
+                    def scale_grads_fn():
+                        training.scale_grads_(
+                            self._model.parameters(), self.dp_degree / num_tokens
+                        )
+
+                    if self._compile:
+                        training.compile_scale_grads(
+                            scale_grads_fn, verbose=self._is_rank_zero
+                        )()
+                    else:
+                        scale_grads_fn()
                     if self._clip_grad_norm is not None:
                         grad_norm = torch.nn.utils.clip_grad_norm_(
                             self._model.parameters(),
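
The recipe now builds a closure around the in-place scaling call so the same code path can run either eagerly or through torch.compile. A minimal sketch of that pattern outside the recipe is below; the toy nn.Linear model and the placeholder values standing in for the recipe's num_tokens and dp_degree are assumptions, not part of the commit.

    # Illustrative sketch only, not from the commit: the recipe's closure-and-compile
    # pattern applied to a toy model, assuming torchtune with this change installed.
    import torch
    from torch import nn
    from torchtune import training

    model = nn.Linear(16, 16)
    model(torch.randn(4, 16)).sum().backward()  # populate .grad on every parameter

    num_tokens = torch.tensor(128.0)  # stand-in for the all-reduced token count
    dp_degree = 2                     # stand-in for the recipe's data-parallel degree

    def scale_grads_fn():
        # In-place multiply of each gradient by dp_degree / num_tokens
        training.scale_grads_(model.parameters(), dp_degree / num_tokens)

    compile_enabled = True  # mirrors the recipe's self._compile flag
    if compile_enabled:
        # compile_scale_grads returns the compiled callable; the trailing () runs it
        training.compile_scale_grads(scale_grads_fn, verbose=False)()
    else:
        scale_grads_fn()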

torchtune/training/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -12,6 +12,7 @@
     compile_loss,
     compile_model,
     compile_optimizer_step,
+    compile_scale_grads,
 )
 from torchtune.training._distributed import (
     gather_cpu_state_dict,
@@ -29,7 +30,7 @@
     shard_model,
     validate_no_params_on_meta_device,
 )
-from torchtune.training._grad_scaler import scale_grads
+from torchtune.training._grad_scaler import scale_grads, scale_grads_
 from torchtune.training._model_util import disable_dropout
 from torchtune.training._profiler import (
     DEFAULT_PROFILE_DIR,
@@ -140,6 +141,7 @@
     "compile_loss",
     "compile_model",
     "compile_optimizer_step",
+    "compile_scale_grads",
     "NoOpManager",
     "OffloadActivations",
     "FormattedCheckpointFiles",

torchtune/training/_compile.py

Lines changed: 7 additions & 0 deletions
@@ -93,3 +93,10 @@ def compile_optimizer_step(optimizer_step_fn, verbose: bool = True):
     if verbose:
         log.info("Compiling optimizer step function with torch.compile...")
     return torch.compile(optimizer_step_fn, backend=backend)
+
+
+def compile_scale_grads(fn, verbose: bool = True):
+    backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
+    if verbose:
+        log.info("Compiling scale_grads function with torch.compile...")
+    return torch.compile(fn, backend=backend)
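
compile_scale_grads follows the same shape as compile_optimizer_step above: it reads the TORCH_COMPILE_BACKEND environment variable, falling back to inductor, and returns the given function wrapped by torch.compile. A hedged usage sketch follows; the aot_eager backend and the toy model are illustrative choices, not anything prescribed by the commit.

    # Illustrative sketch only: overriding the compile backend via the environment
    # variable that compile_scale_grads reads at call time.
    import os
    os.environ["TORCH_COMPILE_BACKEND"] = "aot_eager"  # example backend choice

    import torch
    from torch import nn
    from torchtune import training

    model = nn.Linear(8, 8)
    model(torch.randn(2, 8)).sum().backward()

    compiled = training.compile_scale_grads(
        lambda: training.scale_grads_(model.parameters(), torch.tensor(0.25)),
        verbose=False,
    )
    compiled()  # every gradient is now multiplied by 0.25 in place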

torchtune/training/_grad_scaler.py

Lines changed: 73 additions & 1 deletion
@@ -4,8 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from collections import defaultdict
+from typing import Optional
+
 import torch
-from torch import nn
+from torch import nn, Tensor
+from torch.nn.utils.clip_grad import _no_grad, _tensor_or_tensors
+from torch.utils._foreach_utils import _device_has_foreach_support, _has_foreach_support
 
 
 def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
@@ -29,3 +34,70 @@ def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
             scaler = scaler.to(device)
         if p.grad is not None:
             p.grad *= scaler
+
+
+@_no_grad
+def scale_grads_(
+    parameters: _tensor_or_tensors,
+    scaler: torch.Tensor,
+    foreach: Optional[bool] = None,
+) -> None:
+    r"""Scale gradients of iterable parameters.
+
+    This function is equivalent to :func:`torch.mul_` applied to each parameter.
+    Gradients are modified in-place, multiplying by specified scaler.
+
+    Args:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will have gradients scaled
+        scaler (Tensor): multiplier to scale gradients
+        foreach (bool): use the faster foreach-based implementation.
+            If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
+            fall back to the slow implementation for other device types.
+            Default: ``None``
+    Returns:
+        None
+    """
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    else:
+        parameters = list(parameters)
+    _scale_grad_(parameters, scaler, foreach)
+
+
+def _group_tensors_by_device_and_dtype(
+    tensors: list[torch.Tensor],
+) -> dict[tuple[torch.device, torch.dtype], list[Tensor]]:
+    ret = defaultdict(list)
+    for i, tensor in enumerate(tensors):
+        ret[(tensor.device, tensor.dtype)].append(tensor)
+
+    return ret
+
+
+@_no_grad
+def _scale_grad_(
+    parameters: _tensor_or_tensors,
+    scaler: torch.Tensor,
+    foreach: Optional[bool] = None,
+) -> None:
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    grads = [p.grad for p in parameters if p.grad is not None]
+    if len(grads) == 0:
+        return
+    grouped_grads = _group_tensors_by_device_and_dtype(grads)
+
+    for (device, _), device_grads in grouped_grads.items():
+        if (foreach is None and _has_foreach_support(device_grads, device)) or (
+            foreach and _device_has_foreach_support(device)
+        ):
+            torch._foreach_mul_(device_grads, scaler.to(device))
+        elif foreach:
+            raise RuntimeError(
+                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+            )
+        else:
+            scaler_device = scaler.to(device)
+            for g in device_grads:
+                g.mul_(scaler_device)
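
scale_grads_ collects all non-None gradients, groups them by device and dtype, and applies torch._foreach_mul_ where the device supports it, otherwise falling back to a per-tensor mul_. The sketch below (not part of the commit) checks the in-place result against a manual reference on a small CPU model; forcing the slow path with foreach=False is shown as a commented alternative.

    # Illustrative sketch only: verifying scale_grads_ against a manual reference.
    import torch
    from torch import nn
    from torchtune.training import scale_grads_

    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
    model(torch.randn(3, 4)).sum().backward()

    # Reference gradients scaled outside of scale_grads_
    expected = [p.grad * 0.5 for p in model.parameters()]

    scale_grads_(model.parameters(), torch.tensor(0.5))  # foreach path picked automatically
    # scale_grads_(model.parameters(), torch.tensor(0.5), foreach=False)  # force per-tensor loop

    for p, e in zip(model.parameters(), expected):
        torch.testing.assert_close(p.grad, e)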
