diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py
index 8557f4d34c..013fa2f3af 100644
--- a/mmengine/optim/optimizer/builder.py
+++ b/mmengine/optim/optimizer/builder.py
@@ -25,7 +25,11 @@ def register_torch_optimizers() -> List[str]:
         _optim = getattr(torch.optim, module_name)
         if inspect.isclass(_optim) and issubclass(_optim,
                                                   torch.optim.Optimizer):
-            OPTIMIZERS.register_module(module=_optim)
+            if module_name == 'Adafactor':
+                OPTIMIZERS.register_module(
+                    name='torch_Adafactor', module=_optim)
+            else:
+                OPTIMIZERS.register_module(module=_optim)
             torch_optimizers.append(module_name)
     return torch_optimizers
 
diff --git a/mmengine/registry/build_functions.py b/mmengine/registry/build_functions.py
index 9d195162bc..4fee07569a 100644
--- a/mmengine/registry/build_functions.py
+++ b/mmengine/registry/build_functions.py
@@ -3,8 +3,10 @@
 import logging
 from typing import TYPE_CHECKING, Any, Optional, Union
 
+import torch
+
 from mmengine.config import Config, ConfigDict
-from mmengine.utils import ManagerMixin
+from mmengine.utils import ManagerMixin, digit_version
 from .registry import Registry
 
 if TYPE_CHECKING:
@@ -232,6 +234,18 @@ def build_model_from_cfg(
         return build_from_cfg(cfg, registry, default_args)
 
 
+def build_optimizer_from_cfg(
+        cfg: Union[dict, ConfigDict, Config],
+        registry: Registry,
+        default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
+    if 'Adafactor' == cfg['type'] and digit_version(
+            torch.__version__) >= digit_version('2.5.0'):
+        from ..logging import print_log
+        print_log(
+            'the torch version of Adafactor is registered as torch_Adafactor')
+    return build_from_cfg(cfg, registry, default_args)
+
+
 def build_scheduler_from_cfg(
     cfg: Union[dict, ConfigDict, Config],
     registry: Registry,
diff --git a/mmengine/registry/root.py b/mmengine/registry/root.py
index 2663dffcd9..eb9a225a91 100644
--- a/mmengine/registry/root.py
+++ b/mmengine/registry/root.py
@@ -6,8 +6,8 @@
 https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html.
 """
 
-from .build_functions import (build_model_from_cfg, build_runner_from_cfg,
-                              build_scheduler_from_cfg)
+from .build_functions import (build_model_from_cfg, build_optimizer_from_cfg,
+                              build_runner_from_cfg, build_scheduler_from_cfg)
 from .registry import Registry
 
 # manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
@@ -35,7 +35,7 @@
 WEIGHT_INITIALIZERS = Registry('weight initializer')
 
 # mangage all kinds of optimizers like `SGD` and `Adam`
-OPTIMIZERS = Registry('optimizer')
+OPTIMIZERS = Registry('optimizer', build_func=build_optimizer_from_cfg)
 # manage optimizer wrapper
 OPTIM_WRAPPERS = Registry('optim_wrapper')
 # manage constructors that customize the optimization hyperparameters.
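A minimal usage sketch (not part of the patch) of what the alias enables, assuming torch >= 2.5.0 is installed so that `torch.optim.Adafactor` exists and gets registered as `torch_Adafactor` by `register_torch_optimizers()`; the model and hyperparameters below are placeholders:

```python
import torch.nn as nn

from mmengine.registry import OPTIMIZERS

model = nn.Linear(4, 2)

# Build the native torch Adafactor through the registry alias added in this
# patch; build_optimizer_from_cfg pops 'type' and forwards the remaining keys
# as keyword arguments to the optimizer constructor.
optimizer = OPTIMIZERS.build(
    dict(type='torch_Adafactor', params=model.parameters(), lr=1e-2))
print(type(optimizer))  # the native torch Adafactor class
```

Configs that request `type='Adafactor'` keep resolving to whatever other implementation is registered under that name; on torch >= 2.5 the custom build function only logs a hint that the built-in one is available as `torch_Adafactor`.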