From 5b27f93f82ef42217fb5fe7bc0f535b4e7ca9bad Mon Sep 17 00:00:00 2001 From: tenacious Date: Wed, 6 Nov 2024 15:03:53 +0800 Subject: [PATCH 1/2] ensure type in cfg --- mmengine/optim/optimizer/builder.py | 2 +- mmengine/registry/build_functions.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py index 013fa2f3af..7b4090ba7a 100644 --- a/mmengine/optim/optimizer/builder.py +++ b/mmengine/optim/optimizer/builder.py @@ -27,7 +27,7 @@ def register_torch_optimizers() -> List[str]: torch.optim.Optimizer): if module_name == 'Adafactor': OPTIMIZERS.register_module( - name='torch_Adafactor', module=_optim) + name='TorchAdafactor', module=_optim) else: OPTIMIZERS.register_module(module=_optim) torch_optimizers.append(module_name) diff --git a/mmengine/registry/build_functions.py b/mmengine/registry/build_functions.py index 4fee07569a..b73f5c3d8c 100644 --- a/mmengine/registry/build_functions.py +++ b/mmengine/registry/build_functions.py @@ -238,11 +238,12 @@ def build_optimizer_from_cfg( cfg: Union[dict, ConfigDict, Config], registry: Registry, default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any: - if 'Adafactor' == cfg['type'] and digit_version( - torch.__version__) >= digit_version('2.5.0'): + if 'type' in cfg \ + and 'Adafactor' == cfg['type'] \ + and digit_version(torch.__version__) >= digit_version('2.5.0'): from ..logging import print_log print_log( - 'the torch version of Adafactor is registered as torch_Adafactor') + 'the torch version of Adafactor is registered as TorchAdafactor') return build_from_cfg(cfg, registry, default_args) From 7595f500e6f76bff509bbac015156e7e7148baf0 Mon Sep 17 00:00:00 2001 From: tenacious Date: Wed, 6 Nov 2024 15:07:41 +0800 Subject: [PATCH 2/2] change import level --- mmengine/registry/build_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmengine/registry/build_functions.py b/mmengine/registry/build_functions.py index b73f5c3d8c..fe907c80db 100644 --- a/mmengine/registry/build_functions.py +++ b/mmengine/registry/build_functions.py @@ -238,10 +238,10 @@ def build_optimizer_from_cfg( cfg: Union[dict, ConfigDict, Config], registry: Registry, default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any: + from ..logging import print_log if 'type' in cfg \ and 'Adafactor' == cfg['type'] \ and digit_version(torch.__version__) >= digit_version('2.5.0'): - from ..logging import print_log print_log( 'the torch version of Adafactor is registered as TorchAdafactor') return build_from_cfg(cfg, registry, default_args)