@@ -4,8 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from collections import defaultdict
+from typing import Optional
+
 import torch
-from torch import nn
+from torch import nn, Tensor
+from torch.nn.utils.clip_grad import _no_grad, _tensor_or_tensors
+from torch.utils._foreach_utils import _device_has_foreach_support, _has_foreach_support
 
 
 def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
@@ -29,3 +34,70 @@ def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
             scaler = scaler.to(device)
         if p.grad is not None:
             p.grad *= scaler
+
+
+@_no_grad
+def scale_grads_(
+    parameters: _tensor_or_tensors,
+    scaler: torch.Tensor,
+    foreach: Optional[bool] = None,
+) -> None:
+    r"""Scale gradients of an iterable of parameters.
+
+    This function is equivalent to :meth:`torch.Tensor.mul_` applied to each
+    parameter's gradient. Gradients are modified in-place, multiplied by the
+    specified scaler.
+
+    Args:
+        parameters (_tensor_or_tensors): an iterable of Tensors or a
+            single Tensor that will have gradients scaled
+        scaler (torch.Tensor): multiplier to scale gradients
+        foreach (Optional[bool]): use the faster foreach-based implementation.
+            If ``None``, use the foreach implementation for CUDA and CPU native
+            tensors and silently fall back to the slow implementation for other
+            device types. Default: ``None``
+
+    Returns:
+        None
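+
+    A minimal usage sketch (illustrative only; the model and scale factor below
+    are arbitrary assumptions, not fixtures from this repo)::
+
+        >>> model = nn.Linear(4, 2)
+        >>> model(torch.randn(8, 4)).sum().backward()
+        >>> scale_grads_(model.parameters(), torch.tensor(0.5))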
+    """
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    else:
+        parameters = list(parameters)
+    _scale_grad_(parameters, scaler, foreach)
+
+
+def _group_tensors_by_device(
+    tensors: list[torch.Tensor],
+) -> dict[torch.device, list[Tensor]]:
+    # Bucket gradients by device so each group can be scaled in a single call.
+    ret = defaultdict(list)
+    for tensor in tensors:
+        ret[tensor.device].append(tensor)
+
+    return ret
+
+
+@_no_grad
+def _scale_grad_(
+    parameters: _tensor_or_tensors,
+    scaler: torch.Tensor,
+    foreach: Optional[bool] = None,
+) -> None:
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    grads = [p.grad for p in parameters if p.grad is not None]
+    if len(grads) == 0:
+        return
+    grouped_grads = _group_tensors_by_device(grads)
+
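+    # Prefer the fused foreach kernel per device; an explicit foreach=True on an
+    # unsupported device raises, and the final branch is the per-tensor fallback.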
+    for device, device_grads in grouped_grads.items():
+        if (foreach is None and _has_foreach_support(device_grads, device)) or (
+            foreach and _device_has_foreach_support(device)
+        ):
+            torch._foreach_mul_(device_grads, scaler.to(device))
+        elif foreach:
+            raise RuntimeError(
+                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+            )
+        else:
+            scaler_device = scaler.to(device)
+            for g in device_grads:
+                g.mul_(scaler_device)