diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py index 013fa2f3af..7b4090ba7a 100644 --- a/mmengine/optim/optimizer/builder.py +++ b/mmengine/optim/optimizer/builder.py @@ -27,7 +27,7 @@ def register_torch_optimizers() -> List[str]: torch.optim.Optimizer): if module_name == 'Adafactor': OPTIMIZERS.register_module( - name='torch_Adafactor', module=_optim) + name='TorchAdafactor', module=_optim) else: OPTIMIZERS.register_module(module=_optim) torch_optimizers.append(module_name) diff --git a/mmengine/registry/build_functions.py b/mmengine/registry/build_functions.py index 4fee07569a..fe907c80db 100644 --- a/mmengine/registry/build_functions.py +++ b/mmengine/registry/build_functions.py @@ -238,11 +238,12 @@ def build_optimizer_from_cfg( cfg: Union[dict, ConfigDict, Config], registry: Registry, default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any: - if 'Adafactor' == cfg['type'] and digit_version( - torch.__version__) >= digit_version('2.5.0'): - from ..logging import print_log + from ..logging import print_log + if 'type' in cfg \ + and 'Adafactor' == cfg['type'] \ + and digit_version(torch.__version__) >= digit_version('2.5.0'): print_log( - 'the torch version of Adafactor is registered as torch_Adafactor') + 'the torch version of Adafactor is registered as TorchAdafactor') return build_from_cfg(cfg, registry, default_args)