
Commit e35ed5f

Authored by xcnick and zhouzaida

[Feature] Add ApexOptimWrapper (#742)

* add ApexOptimWrapper
* typo fix
* add apex amp.initialize in optim_context
* assert apex_amp
* polish code
* add parameters of apex_amp.initialize
* add docs
* polish code
* polish code
* polish code
* fix calling of apex amp load_state_dict
* polish
* add comments
* Update apex_optimizer_wrapper.py
* Update apex_optimizer_wrapper.py

Co-authored-by: Zaida Zhou <[email protected]>

1 parent: bc49e0c

File tree: 6 files changed, +315 -8 lines

docs/en/api/optim.rst
docs/zh_cn/api/optim.rst
mmengine/optim/__init__.py
mmengine/optim/optimizer/__init__.py
mmengine/optim/optimizer/apex_optimizer_wrapper.py
tests/test_optim/test_optimizer/test_optimizer_wrapper.py


docs/en/api/optim.rst

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ Optimizer
    :template: classtemplate.rst
 
    AmpOptimWrapper
+   ApexOptimWrapper
    OptimWrapper
    OptimWrapperDict
    DefaultOptimWrapperConstructor

docs/zh_cn/api/optim.rst

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ Optimizer
    :template: classtemplate.rst
 
    AmpOptimWrapper
+   ApexOptimWrapper
    OptimWrapper
    OptimWrapperDict
    DefaultOptimWrapperConstructor

mmengine/optim/__init__.py

Lines changed: 7 additions & 6 deletions
@@ -1,7 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .optimizer import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
-                        AmpOptimWrapper, DefaultOptimWrapperConstructor,
-                        OptimWrapper, OptimWrapperDict, build_optim_wrapper)
+                        AmpOptimWrapper, ApexOptimWrapper,
+                        DefaultOptimWrapperConstructor, OptimWrapper,
+                        OptimWrapperDict, build_optim_wrapper)
 # yapf: disable
 from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
                         CosineAnnealingLR, CosineAnnealingMomentum,
@@ -25,8 +26,8 @@
     'MultiStepMomentum', 'StepMomentum', 'ConstantParamScheduler',
     'CosineAnnealingParamScheduler', 'ExponentialParamScheduler',
     'LinearParamScheduler', 'MultiStepParamScheduler', 'StepParamScheduler',
-    '_ParamScheduler', 'OptimWrapper', 'AmpOptimWrapper', 'OptimWrapperDict',
-    'OneCycleParamScheduler', 'OneCycleLR', 'PolyLR', 'PolyMomentum',
-    'PolyParamScheduler', 'ReduceOnPlateauLR', 'ReduceOnPlateauMomentum',
-    'ReduceOnPlateauParamScheduler'
+    '_ParamScheduler', 'OptimWrapper', 'AmpOptimWrapper', 'ApexOptimWrapper',
+    'OptimWrapperDict', 'OneCycleParamScheduler', 'OneCycleLR', 'PolyLR',
+    'PolyMomentum', 'PolyParamScheduler', 'ReduceOnPlateauLR',
+    'ReduceOnPlateauMomentum', 'ReduceOnPlateauParamScheduler'
 ]

mmengine/optim/optimizer/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .amp_optimizer_wrapper import AmpOptimWrapper
+from .apex_optimizer_wrapper import ApexOptimWrapper
 from .builder import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
                       build_optim_wrapper)
 from .default_constructor import DefaultOptimWrapperConstructor
@@ -10,5 +11,6 @@
 __all__ = [
     'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS',
     'DefaultOptimWrapperConstructor', 'build_optim_wrapper', 'OptimWrapper',
-    'AmpOptimWrapper', 'OptimWrapperDict', 'ZeroRedundancyOptimizer'
+    'AmpOptimWrapper', 'ApexOptimWrapper', 'OptimWrapperDict',
+    'ZeroRedundancyOptimizer'
 ]
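
Because ``ApexOptimWrapper`` is registered in ``OPTIM_WRAPPERS`` (see the new module below) and exported from ``mmengine.optim``, it can be built from a config just like ``AmpOptimWrapper``. The snippet below is a minimal, hypothetical sketch that is not part of this commit: it assumes NVIDIA apex is installed and a CUDA device is available, and ``ToyConv`` is a made-up module used only for illustration.

# Hypothetical sketch (not part of this commit): build the registered
# ApexOptimWrapper from a config dict via mmengine.optim.build_optim_wrapper.
# Assumes NVIDIA apex is installed and a CUDA device is available; ToyConv is
# a made-up module for illustration only.
import torch.nn as nn

from mmengine.optim import build_optim_wrapper


class ToyConv(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.conv(x)


model = ToyConv().cuda()
optim_wrapper_cfg = dict(
    type='ApexOptimWrapper',  # looked up in the OPTIM_WRAPPERS registry
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9),
    opt_level='O1',           # mixed precision
    loss_scale='dynamic')
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)

Note that ``apex_amp.initialize`` is not called at build time; as the new file below shows, it runs lazily the first time ``optim_context`` is entered.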
mmengine/optim/optimizer/apex_optimizer_wrapper.py

Lines changed: 200 additions & 0 deletions (new file)

# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from typing import Optional, Union

import torch
import torch.nn as nn

# a circular import would be caused by
# from mmengine.model.wrappers import is_model_wrapper
import mmengine
from mmengine.registry import OPTIM_WRAPPERS
from .optimizer_wrapper import OptimWrapper

try:
    import apex.amp as apex_amp
except ImportError:
    apex_amp = None


@OPTIM_WRAPPERS.register_module()
class ApexOptimWrapper(OptimWrapper):
    """A subclass of :class:`OptimWrapper` that supports automatic mixed
    precision training based on apex.amp.

    ``ApexOptimWrapper`` provides a unified interface with ``OptimWrapper``,
    so it can be used in the same way as ``OptimWrapper``.

    Warning:
        ``ApexOptimWrapper`` requires
        `nvidia apex <https://github.com/NVIDIA/apex>`_.

    Args:
        opt_level (str): Pure or mixed precision optimization level. Accepted
            values are "O0", "O1", "O2", and "O3". Defaults to "O1".
        loss_scale (float or str, optional): If passed as a string, must be a
            string representing a number, e.g., "128.0", or the string
            "dynamic". Defaults to "dynamic".
        enabled (bool): If False, renders all Amp calls no-ops, so the script
            runs as if Amp were not present. Defaults to True.
        cast_model_type (torch.dtype, optional): Casts the model's parameters
            and buffers to the desired type. Defaults to None.
        patch_torch_functions (bool, optional): Patch all Torch functions
            and Tensor methods to perform Tensor Core-friendly ops like GEMMs
            and convolutions in FP16, and any ops that benefit from FP32
            precision in FP32. Defaults to None.
        keep_batchnorm_fp32 (bool or str, optional): To enhance precision
            and enable cudnn batchnorm (which improves performance),
            it's often beneficial to keep batchnorm weights in FP32
            even if the rest of the model is FP16.
            If passed as a string, must be the string "True" or "False".
            Defaults to None.
        master_weights (bool, optional): Maintain FP32 master weights to
            accompany any FP16 model weights. FP32 master weights are stepped
            by the optimizer to enhance precision and capture small gradients.
            Defaults to None.
        cast_model_outputs (torch.dtype, optional): Option to ensure that
            the outputs of your model(s) are always cast to a particular type
            regardless of ``opt_level``. Defaults to None.
        num_losses (int): Option to tell Amp in advance how many
            losses/backward passes you plan to use. Defaults to 1.
        verbosity (int): Set to 0 to suppress Amp-related output.
            Defaults to 1.
        min_loss_scale (float, optional): Sets a floor for the loss scale
            values that can be chosen by dynamic loss scaling.
            The default value of None means that no floor is imposed.
            If dynamic loss scaling is not used, `min_loss_scale` is ignored.
            Defaults to None.
        max_loss_scale (float, optional): Sets a ceiling for the loss scale
            values that can be chosen by dynamic loss scaling. If dynamic
            loss scaling is not used, `max_loss_scale` is ignored.
            Defaults to 2.**24.
        **kwargs: Keyword arguments passed to OptimWrapper.

    Note:
        If you use ``IterBasedRunner`` and enable gradient accumulation,
        the original ``max_iters`` should be multiplied by
        ``accumulative_counts``.

    Note:
        `New in version 0.6.0.`
    """  # noqa: E501

    def __init__(self,
                 opt_level: str = 'O1',
                 loss_scale: Union[float, str, None] = 'dynamic',
                 enabled: Optional[bool] = True,
                 cast_model_type: Optional[torch.dtype] = None,
                 patch_torch_functions: Optional[bool] = None,
                 keep_batchnorm_fp32: Union[bool, str, None] = None,
                 master_weights: Optional[bool] = None,
                 cast_model_outputs: Optional[torch.dtype] = None,
                 num_losses: int = 1,
                 verbosity: int = 1,
                 min_loss_scale: Optional[float] = None,
                 max_loss_scale: Optional[float] = 2.**24,
                 **kwargs):
        assert apex_amp is not None, \
            'Apex is not installed. Please check ' \
            'https://github.com/NVIDIA/apex#linux.'
        super().__init__(**kwargs)
        self.opt_level = opt_level
        self.loss_scale = loss_scale
        self.enabled = enabled
        self.cast_model_type = cast_model_type
        self.patch_torch_functions = patch_torch_functions
        self.keep_batchnorm_fp32 = keep_batchnorm_fp32
        self.master_weights = master_weights
        self.cast_model_outputs = cast_model_outputs
        self.num_losses = num_losses
        self.verbosity = verbosity
        self.min_loss_scale = min_loss_scale
        self.max_loss_scale = max_loss_scale
        self._apex_amp_state_dict = None

    def backward(self, loss: torch.Tensor, **kwargs) -> None:
        """Perform gradient back propagation with loss scaling handled by
        ``apex_amp.scale_loss``.

        Args:
            loss (torch.Tensor): The loss of the current iteration.
            kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward`.
        """
        with apex_amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward(**kwargs)
        self._inner_count += 1

    def state_dict(self) -> dict:
        """Get the state dictionary of :attr:`optimizer` and
        :attr:`apex_amp`.

        Based on the state dictionary of the optimizer, the returned state
        dictionary adds a key named "apex_amp".

        Returns:
            dict: The merged state dict of :attr:`apex_amp` and
            :attr:`optimizer`.
        """
        state_dict = self.optimizer.state_dict()
        state_dict['apex_amp'] = apex_amp.state_dict()
        return state_dict

    def load_state_dict(self, state_dict: dict) -> None:
        """Load and parse the state dictionary of :attr:`optimizer` and
        :attr:`apex_amp`.

        If ``state_dict`` contains the key "apex_amp", :attr:`apex_amp` will
        load the corresponding entries. Otherwise, only the :attr:`optimizer`
        will load the state dictionary.

        Note:
            :meth:`load_state_dict` should be called after
            `apex_amp.initialize` is called.

        Args:
            state_dict (dict): The state dict of :attr:`optimizer` and
                :attr:`apex_amp`.
        """
        if 'apex_amp' in state_dict:
            # when `apex_amp` is not initialized, calling `load_state_dict`
            # will raise an error, so we temporarily cache the apex_amp
            # part, and then load it into `apex_amp` after completing
            # the `apex_amp` initialization in the `optim_context` method
            if hasattr(self.optimizer, '_amp_stash'):
                apex_amp.load_state_dict(state_dict.pop('apex_amp'))
            else:
                self._apex_amp_state_dict = state_dict.pop('apex_amp')
        self.optimizer.load_state_dict(state_dict)

    @contextmanager
    def optim_context(self, model: nn.Module):
        """Enable the context for mixed precision training and the context
        for disabling gradient synchronization during gradient accumulation.

        Args:
            model (nn.Module): The training model.
        """
        with super().optim_context(model):
            # when a given optimizer is passed through `apex_amp.initialize`,
            # the "_amp_stash" attribute is added to it
            if not hasattr(self.optimizer, '_amp_stash'):
                if mmengine.model.wrappers.is_model_wrapper(model):
                    model = model.module
                model, self.optimizer = apex_amp.initialize(
                    model,
                    self.optimizer,
                    opt_level=self.opt_level,
                    loss_scale=self.loss_scale,
                    enabled=self.enabled,
                    cast_model_type=self.cast_model_type,
                    patch_torch_functions=self.patch_torch_functions,
                    keep_batchnorm_fp32=self.keep_batchnorm_fp32,
                    master_weights=self.master_weights,
                    cast_model_outputs=self.cast_model_outputs,
                    num_losses=self.num_losses,
                    verbosity=self.verbosity,
                    min_loss_scale=self.min_loss_scale,
                    max_loss_scale=self.max_loss_scale)
            # load the cached apex_amp state dict after apex_amp has been
            # initialized
            if self._apex_amp_state_dict is not None:
                apex_amp.load_state_dict(self._apex_amp_state_dict)
                self._apex_amp_state_dict = None
            yield
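
Putting the pieces together, a rough training-loop sketch (again not part of the commit, and mirroring the unit tests below) might look like this; apex and a CUDA device are assumed, and the model and data are placeholders.

# Rough sketch (not from the commit): ApexOptimWrapper used like a plain
# OptimWrapper. The first entry into optim_context triggers
# apex_amp.initialize; update_params then runs the scaled backward pass and,
# unless gradients are being accumulated, the optimizer step and zero_grad.
# Assumes apex and a CUDA device; the model and data are placeholders.
import torch
import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import ApexOptimWrapper

model = nn.Conv2d(1, 1, 1).cuda()
optim_wrapper = ApexOptimWrapper(
    optimizer=SGD(model.parameters(), lr=0.1),
    opt_level='O1',
    loss_scale='dynamic')

for _ in range(10):
    inputs = torch.randn(4, 1, 8, 8).cuda()
    with optim_wrapper.optim_context(model):
        loss = model(inputs).mean()
        optim_wrapper.update_params(loss)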

tests/test_optim/test_optimizer/test_optimizer_wrapper.py

Lines changed: 103 additions & 1 deletion
@@ -14,12 +14,19 @@
 
 from mmengine.dist import all_gather
 from mmengine.logging import MessageHub, MMLogger
-from mmengine.optim import AmpOptimWrapper, OptimWrapper
+from mmengine.optim import AmpOptimWrapper, ApexOptimWrapper, OptimWrapper
 from mmengine.testing import assert_allclose
 from mmengine.testing._internal import MultiProcessTestCase
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
 
+is_apex_available = False
+try:
+    import apex.amp as apex_amp
+    is_apex_available = True
+except ImportError:
+    pass
+
 
 class ToyModel(nn.Module):
 
@@ -283,6 +290,101 @@ def mock_methd(loss):
         optim_wrapper.zero_grad = MagicMock()
 
 
+@unittest.skipIf(not torch.cuda.is_available(), reason='need gpu to test Apex')
+class TestApexOptimWrapper(TestCase):
+
+    def setUp(self) -> None:
+        self.model = ToyModel().cuda()
+        self.optimizer = SGD(self.model.parameters(), lr=0.1)
+
+    @unittest.skipIf(
+        not is_apex_available,
+        reason='`apex` is not available, Please install apex from '
+        'https://www.github.com/nvidia/apex')
+    def test_init(self):
+        apex_optim_wrapper = ApexOptimWrapper(
+            optimizer=self.optimizer, opt_level='O1', loss_scale=1)
+        with apex_optim_wrapper.optim_context(self.model):
+            pass
+
+    @unittest.skipIf(
+        not is_apex_available,
+        reason='`apex` is not available, Please install apex from '
+        'https://www.github.com/nvidia/apex')
+    def test_step(self):
+        optimizer = MagicMock(spec=Optimizer)
+        apex_optim_wrapper = ApexOptimWrapper(
+            optimizer=optimizer, opt_level='O1', loss_scale=1)
+        with apex_optim_wrapper.optim_context(self.model):
+            loss = self.model(torch.Tensor(1, 1, 1, 1).cuda())
+            apex_optim_wrapper.backward(loss)
+            apex_optim_wrapper.step()
+
+    @unittest.skipIf(
+        not is_apex_available,
+        reason='`apex` is not available, Please install apex from '
+        'https://www.github.com/nvidia/apex')
+    def test_backward(self):
+        apex_optim_wrapper = ApexOptimWrapper(
+            optimizer=self.optimizer, opt_level='O1', loss_scale=1)
+        with apex_optim_wrapper.optim_context(self.model):
+            loss = self.model(torch.Tensor(1, 1, 1, 1).cuda())
+            apex_optim_wrapper.backward(loss)
+
+    @unittest.skipIf(
+        not is_apex_available,
+        reason='`apex` is not available, Please install apex from '
+        'https://www.github.com/nvidia/apex')
+    def test_state_dict(self):
+        apex_optim_wrapper = ApexOptimWrapper(
+            optimizer=self.optimizer, opt_level='O1', loss_scale=1)
+        with apex_optim_wrapper.optim_context(self.model):
+            loss = self.model(torch.Tensor(1, 1, 1, 1).cuda())
+            apex_optim_wrapper.update_params(loss)
+            state_dict = apex_optim_wrapper.state_dict()
+            amp_state_dict = state_dict.pop('apex_amp')
+            optim_state_dict = state_dict
+
+            self.assertDictEqual(optim_state_dict,
+                                 apex_optim_wrapper.optimizer.state_dict())
+            self.assertDictEqual(amp_state_dict, apex_amp.state_dict())
+
+    @unittest.skipIf(
+        not is_apex_available,
+        reason='`apex` is not available, Please install apex from '
+        'https://www.github.com/nvidia/apex')
+    def test_load_state_dict(self):
+        apex_optim_wrapper = ApexOptimWrapper(
+            optimizer=self.optimizer, opt_level='O1', loss_scale=1)
+        with apex_optim_wrapper.optim_context(self.model):
+            # Test load from optimizer
+            optimizer = SGD(self.model.parameters(), lr=0.1)
+            apex_optim_wrapper.load_state_dict(optimizer.state_dict())
+
+            self.assertDictEqual(optimizer.state_dict(),
+                                 apex_optim_wrapper.optimizer.state_dict())
+            # Test load from optim_wrapper
+            apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer)
+            apex_optim_wrapper_ = ApexOptimWrapper(
+                optimizer=SGD(self.model.parameters(), lr=0.1))
+            apex_optim_wrapper_.load_state_dict(
+                apex_optim_wrapper.state_dict())
+            self.assertDictEqual(apex_optim_wrapper.optimizer.state_dict(),
+                                 apex_optim_wrapper_.optimizer.state_dict())
+
+    @unittest.skipIf(
+        not is_apex_available,
+        reason='`apex` is not available, Please install apex from '
+        'https://www.github.com/nvidia/apex')
+    def test_optim_context(self):
+        apex_optim_wrapper = ApexOptimWrapper(
+            optimizer=self.optimizer, opt_level='O1', loss_scale=1)
+        with apex_optim_wrapper.optim_context(self.model):
+            x = torch.randn(1, 1, 1, 1).cuda()
+            y = nn.Conv2d(1, 1, 1).cuda()(x)
+            self.assertEqual(y.dtype, torch.float16)
+
+
 class TestAmpOptimWrapper(TestCase):
 
     def setUp(self) -> None:
