# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from typing import Optional, Union

import torch
import torch.nn as nn

# a circular import will be caused by
# from mmengine.model.wrappers import is_model_wrapper
import mmengine
from mmengine.registry import OPTIM_WRAPPERS
from .optimizer_wrapper import OptimWrapper

try:
    import apex.amp as apex_amp
except ImportError:
    apex_amp = None


@OPTIM_WRAPPERS.register_module()
class ApexOptimWrapper(OptimWrapper):
    """A subclass of :class:`OptimWrapper` that supports automatic mixed
    precision training based on apex.amp.

    ``ApexOptimWrapper`` provides a unified interface with
    ``OptimWrapper``, so it can be used in the same way as ``OptimWrapper``.

    Warning:
        ``ApexOptimWrapper`` requires `nvidia apex <https://github.com/NVIDIA/apex>`_
        to be installed.

    Args:
        opt_level (str): Pure or mixed precision optimization level. Accepted
            values are "O0", "O1", "O2", and "O3". Defaults to "O1".
        loss_scale (float or str, optional): If passed as a string, must be a
            string representing a number, e.g., "128.0", or the string
            "dynamic". Defaults to "dynamic".
        enabled (bool): If False, renders all Amp calls no-ops, so your script
            should run as if Amp were not present. Defaults to True.
        cast_model_type (torch.dtype, optional): Casts the model's parameters
            and buffers to the desired type. Defaults to None.
        patch_torch_functions (bool, optional): Patch all Torch functions
            and Tensor methods to perform Tensor Core-friendly ops like GEMMs
            and convolutions in FP16, and any ops that benefit from FP32
            precision in FP32. Defaults to None.
        keep_batchnorm_fp32 (bool or str, optional): To enhance precision
            and enable cudnn batchnorm (which improves performance),
            it's often beneficial to keep batchnorm weights in FP32
            even if the rest of the model is FP16.
            If passed as a string, must be the string "True" or "False".
            Defaults to None.
        master_weights (bool, optional): Maintain FP32 master weights to
            accompany any FP16 model weights. FP32 master weights are stepped
            by the optimizer to enhance precision and capture small gradients.
            Defaults to None.
        cast_model_outputs (torch.dtype, optional): Option to ensure that
            the outputs of your model(s) are always cast to a particular type
            regardless of ``opt_level``. Defaults to None.
        num_losses (int): Option to tell Amp in advance how many
            losses/backward passes you plan to use. Defaults to 1.
        verbosity (int): Set to 0 to suppress Amp-related output.
            Defaults to 1.
        min_loss_scale (float, optional): Sets a floor for the loss scale
            values that can be chosen by dynamic loss scaling.
            The default value of None means that no floor is imposed.
            If dynamic loss scaling is not used, `min_loss_scale` is ignored.
            Defaults to None.
        max_loss_scale (float, optional): Sets a ceiling for the loss scale
            values that can be chosen by dynamic loss scaling. If dynamic
            loss scaling is not used, `max_loss_scale` is ignored.
            Defaults to 2.**24.
        **kwargs: Keyword arguments passed to OptimWrapper.

    Note:
        If you use ``IterBasedRunner`` and enable gradient accumulation,
        the original `max_iters` should be multiplied by
        ``accumulative_counts``.

    Note:
        `New in version 0.6.0.`
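
    Examples:
        >>> # A minimal usage sketch, not the full training loop driven by a
        >>> # runner. It assumes apex is installed and a CUDA device is
        >>> # available; the model and optimizer below are placeholders.
        >>> import torch
        >>> import torch.nn as nn
        >>> from torch.optim import SGD
        >>> model = nn.Linear(1, 1).cuda()
        >>> optimizer = SGD(model.parameters(), lr=0.01)
        >>> optim_wrapper = ApexOptimWrapper(
        ...     optimizer=optimizer, opt_level='O1', loss_scale='dynamic')
        >>> with optim_wrapper.optim_context(model):
        ...     loss = model(torch.randn(4, 1).cuda()).sum()
        >>> optim_wrapper.update_params(loss)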
    """  # noqa: E501

    def __init__(self,
                 opt_level: str = 'O1',
                 loss_scale: Union[float, str, None] = 'dynamic',
                 enabled: Optional[bool] = True,
                 cast_model_type: Optional[torch.dtype] = None,
                 patch_torch_functions: Optional[bool] = None,
                 keep_batchnorm_fp32: Union[bool, str, None] = None,
                 master_weights: Optional[bool] = None,
                 cast_model_outputs: Optional[torch.dtype] = None,
                 num_losses: int = 1,
                 verbosity: int = 1,
                 min_loss_scale: Optional[float] = None,
                 max_loss_scale: Optional[float] = 2.**24,
                 **kwargs):
        assert apex_amp is not None, \
            'Apex is not installed. Please check ' \
            'https://github.com/NVIDIA/apex#linux.'
        super().__init__(**kwargs)
        self.opt_level = opt_level
        self.loss_scale = loss_scale
        self.enabled = enabled
        self.cast_model_type = cast_model_type
        self.patch_torch_functions = patch_torch_functions
        self.keep_batchnorm_fp32 = keep_batchnorm_fp32
        self.master_weights = master_weights
        self.cast_model_outputs = cast_model_outputs
        self.num_losses = num_losses
        self.verbosity = verbosity
        self.min_loss_scale = min_loss_scale
        self.max_loss_scale = max_loss_scale
        self._apex_amp_state_dict = None

    def backward(self, loss: torch.Tensor, **kwargs) -> None:
        """Perform gradient backpropagation with the loss scaled by
        ``apex_amp.scale_loss``.

        Args:
            loss (torch.Tensor): The loss of current iteration.
            kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward`.
        """
        with apex_amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward(**kwargs)
        self._inner_count += 1

    def state_dict(self) -> dict:
        """Get the state dictionary of :attr:`optimizer` and
        :attr:`apex_amp`.

        The returned state dictionary is the optimizer's state dictionary
        with an additional key named "apex_amp".

        Returns:
            dict: The merged state dict of :attr:`apex_amp` and
                :attr:`optimizer`.
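
        Examples:
            >>> # A sketch; it assumes apex has already been initialized by
            >>> # entering :meth:`optim_context` at least once.
            >>> saved = optim_wrapper.state_dict()
            >>> 'apex_amp' in saved
            True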
        """
        state_dict = self.optimizer.state_dict()
        state_dict['apex_amp'] = apex_amp.state_dict()
        return state_dict

    def load_state_dict(self, state_dict: dict) -> None:
        """Load and parse the state dictionary of :attr:`optimizer` and
        :attr:`apex_amp`.

        If state_dict contains the key "apex_amp", :attr:`apex_amp` will
        load the corresponding entries. Otherwise, only the :attr:`optimizer`
        will load the state dictionary.

        Note:
            :meth:`load_state_dict` should be called after
            `apex_amp.initialize` is called.

        Args:
            state_dict (dict): The state dict of :attr:`optimizer` and
                :attr:`apex_amp`.
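
        Examples:
            >>> # A resume sketch; ``ckpt`` is assumed to be a checkpoint
            >>> # produced by :meth:`state_dict`. If apex has not been
            >>> # initialized yet, the "apex_amp" entry is cached and applied
            >>> # later inside :meth:`optim_context`.
            >>> optim_wrapper.load_state_dict(ckpt)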
        """
        if 'apex_amp' in state_dict:
            # when `apex_amp` is not initialized, calling `load_state_dict`
            # would raise an error, so we temporarily cache the apex_amp
            # part and load it into `apex_amp` after the `apex_amp`
            # initialization is completed in the `optim_context` method
            if hasattr(self.optimizer, '_amp_stash'):
                apex_amp.load_state_dict(state_dict.pop('apex_amp'))
            else:
                self._apex_amp_state_dict = state_dict.pop('apex_amp')
        self.optimizer.load_state_dict(state_dict)

    @contextmanager
    def optim_context(self, model: nn.Module):
        """Enable the context for mixed precision training, and disable
        gradient synchronization during gradient accumulation.

        Args:
            model (nn.Module): The training model.
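
        Examples:
            >>> # A sketch; ``model`` and ``optim_wrapper`` are assumed to be
            >>> # built as in the class-level example, and ``inputs`` is a
            >>> # placeholder batch. The first entry into this context also
            >>> # calls ``apex_amp.initialize`` on the model and optimizer.
            >>> with optim_wrapper.optim_context(model):
            ...     loss = model(inputs).sum()
            >>> optim_wrapper.update_params(loss)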
        """
        with super().optim_context(model):
            # when an optimizer is passed to `apex_amp.initialize`,
            # the `_amp_stash` attribute is added to it, so its absence
            # means apex has not been initialized for this optimizer yet
            if not hasattr(self.optimizer, '_amp_stash'):
                if mmengine.model.wrappers.is_model_wrapper(model):
                    model = model.module
                model, self.optimizer = apex_amp.initialize(
                    model,
                    self.optimizer,
                    opt_level=self.opt_level,
                    loss_scale=self.loss_scale,
                    enabled=self.enabled,
                    cast_model_type=self.cast_model_type,
                    patch_torch_functions=self.patch_torch_functions,
                    keep_batchnorm_fp32=self.keep_batchnorm_fp32,
                    master_weights=self.master_weights,
                    cast_model_outputs=self.cast_model_outputs,
                    num_losses=self.num_losses,
                    verbosity=self.verbosity,
                    min_loss_scale=self.min_loss_scale,
                    max_loss_scale=self.max_loss_scale)
                # load the cached apex_amp state dict (if any) now that
                # apex_amp has been initialized
                if self._apex_amp_state_dict is not None:
                    apex_amp.load_state_dict(self._apex_amp_state_dict)
                    self._apex_amp_state_dict = None
            yield