Skip to content

Commit

Permalink
[Fix] fix Adafactor optim on torch2.5 and fix compatibility (#1600)
Browse files Browse the repository at this point in the history
* fix Adafactor opptim on torch2.5 and fix compatibility

* fix runtest error
  • Loading branch information
tenacioustommy authored Nov 5, 2024
1 parent fc59364 commit 2e0ab7a
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 5 deletions.
6 changes: 5 additions & 1 deletion mmengine/optim/optimizer/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ def register_torch_optimizers() -> List[str]:
_optim = getattr(torch.optim, module_name)
if inspect.isclass(_optim) and issubclass(_optim,
torch.optim.Optimizer):
OPTIMIZERS.register_module(module=_optim)
if module_name == 'Adafactor':
OPTIMIZERS.register_module(
name='torch_Adafactor', module=_optim)
else:
OPTIMIZERS.register_module(module=_optim)
torch_optimizers.append(module_name)
return torch_optimizers

Expand Down
16 changes: 15 additions & 1 deletion mmengine/registry/build_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import logging
from typing import TYPE_CHECKING, Any, Optional, Union

import torch

from mmengine.config import Config, ConfigDict
from mmengine.utils import ManagerMixin
from mmengine.utils import ManagerMixin, digit_version
from .registry import Registry

if TYPE_CHECKING:
Expand Down Expand Up @@ -232,6 +234,18 @@ def build_model_from_cfg(
return build_from_cfg(cfg, registry, default_args)


def build_optimizer_from_cfg(
cfg: Union[dict, ConfigDict, Config],
registry: Registry,
default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
if 'Adafactor' == cfg['type'] and digit_version(
torch.__version__) >= digit_version('2.5.0'):
from ..logging import print_log
print_log(
'the torch version of Adafactor is registered as torch_Adafactor')
return build_from_cfg(cfg, registry, default_args)


def build_scheduler_from_cfg(
cfg: Union[dict, ConfigDict, Config],
registry: Registry,
Expand Down
6 changes: 3 additions & 3 deletions mmengine/registry/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html.
"""

from .build_functions import (build_model_from_cfg, build_runner_from_cfg,
build_scheduler_from_cfg)
from .build_functions import (build_model_from_cfg, build_optimizer_from_cfg,
build_runner_from_cfg, build_scheduler_from_cfg)
from .registry import Registry

# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
Expand Down Expand Up @@ -35,7 +35,7 @@
WEIGHT_INITIALIZERS = Registry('weight initializer')

# mangage all kinds of optimizers like `SGD` and `Adam`
OPTIMIZERS = Registry('optimizer')
OPTIMIZERS = Registry('optimizer', build_func=build_optimizer_from_cfg)
# manage optimizer wrapper
OPTIM_WRAPPERS = Registry('optim_wrapper')
# manage constructors that customize the optimization hyperparameters.
Expand Down

0 comments on commit 2e0ab7a

Please sign in to comment.