@@ -4,8 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from collections import defaultdict
+from typing import Optional
+
 import torch
-from torch import nn
+from torch import nn, Tensor
+from torch.nn.utils.clip_grad import _no_grad, _tensor_or_tensors
+from torch.utils._foreach_utils import _device_has_foreach_support, _has_foreach_support


 def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
@@ -29,3 +34,70 @@ def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
             scaler = scaler.to(device)
         if p.grad is not None:
             p.grad *= scaler
+
+
+@_no_grad
+def scale_grads_(
+    parameters: _tensor_or_tensors,
+    scaler: torch.Tensor,
+    foreach: Optional[bool] = None,
+) -> None:
+    r"""Scale gradients of an iterable of parameters.
+
+    This function is equivalent to :meth:`torch.Tensor.mul_` applied to each
+    gradient. Gradients are modified in place, multiplied by the specified scaler.
+
+    Args:
+        parameters (_tensor_or_tensors): an iterable of Tensors or a
+            single Tensor that will have gradients scaled
+        scaler (torch.Tensor): multiplier to scale gradients
+        foreach (Optional[bool]): use the faster foreach-based implementation.
+            If ``None``, use the foreach implementation for CUDA and CPU native tensors
+            and silently fall back to the slow implementation for other device types.
+            Default: ``None``
+
+    Returns:
+        None
+    """
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    else:
+        parameters = list(parameters)
+    _scale_grad_(parameters, scaler, foreach)
+
+
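+# Simplified local variant of the grouping helper in torch.utils._foreach_utils.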
+def _group_tensors_by_device_and_dtype(
+    tensors: list[torch.Tensor],
+) -> dict[tuple[torch.device, torch.dtype], list[Tensor]]:
+    # Bucket tensors by (device, dtype) so that each bucket can be handled
+    # by a single fused foreach kernel call.
+    ret = defaultdict(list)
+    for tensor in tensors:
+        ret[(tensor.device, tensor.dtype)].append(tensor)
+
+    return ret
+
+
+@_no_grad
+def _scale_grad_(
+    parameters: _tensor_or_tensors,
+    scaler: torch.Tensor,
+    foreach: Optional[bool] = None,
+) -> None:
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    grads = [p.grad for p in parameters if p.grad is not None]
+    if len(grads) == 0:
+        return
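+    # Group gradients by (device, dtype): a foreach kernel call requires all
+    # tensors to share one device and dtype.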
+    grouped_grads = _group_tensors_by_device_and_dtype(grads)
+
+    for (device, _), device_grads in grouped_grads.items():
+        if (foreach is None and _has_foreach_support(device_grads, device)) or (
+            foreach and _device_has_foreach_support(device)
+        ):
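+            # Fast path: a single fused multiply across every gradient in the group.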
+            torch._foreach_mul_(device_grads, scaler.to(device))
+        elif foreach:
+            raise RuntimeError(
+                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+            )
+        else:
+            # Slow path: scale each gradient tensor individually.
+            scaler_device = scaler.to(device)
+            for g in device_grads:
+                g.mul_(scaler_device)
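
For context, a minimal usage sketch of the new API (not part of the diff): it assumes `scale_grads_` is imported from the module this change touches, and the model, input, and `num_tokens` below are illustrative placeholders.

# Hypothetical training-loop excerpt (names are placeholders): after backward,
# rescale all accumulated gradients in place before the optimizer step.
import torch
from torch import nn

model = nn.Linear(8, 8)
loss = model(torch.randn(4, 8)).sum()
loss.backward()

# e.g. undo gradient-accumulation scaling by the token count
num_tokens = 32
scale_grads_(model.parameters(), torch.tensor(1.0 / num_tokens))

Because `foreach` defaults to `None`, this call takes the fused `torch._foreach_mul_` path on CPU and CUDA tensors and quietly falls back to per-tensor `mul_` elsewhere.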