
Adding a simple optimizer registry. #876

Open
balancap wants to merge 1 commit into main from adding-optimizer-registry

Conversation

balancap (Contributor)

At the moment, `_create_optimizer` hardcodes the list of available optimizers.

Adding a simple registry allows users to fully customize their optimizer choice and configuration (when used in combination with a custom `build_optimizers`).
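
For context, a minimal sketch of what such a registry could look like; the names `_OPTIMIZER_REGISTRY` and `register_optimizer` are illustrative here, not necessarily the PR's actual API:

from typing import Dict, Type

from torch.optim import Adam, AdamW, Optimizer

# Hypothetical module-level registry mapping config names to optimizer classes.
_OPTIMIZER_REGISTRY: Dict[str, Type[Optimizer]] = {}


def register_optimizer(name: str, optimizer_cls: Type[Optimizer]) -> None:
    """Register an optimizer class under a config-friendly name."""
    if name in _OPTIMIZER_REGISTRY:
        raise ValueError(f"Optimizer '{name}' is already registered.")
    _OPTIMIZER_REGISTRY[name] = optimizer_cls


# Built-in entries; users can register their own before building optimizers.
register_optimizer("Adam", Adam)
register_optimizer("AdamW", AdamW)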

facebook-github-bot added the CLA Signed label on Feb 21, 2025
balancap force-pushed the adding-optimizer-registry branch from 1c47613 to 8aae457 on February 21, 2025
balancap (Contributor, Author) commented Feb 21, 2025

I am introducing the minimal change needed to support custom optimizers in the training loop. But there is probably a case for extending this work to avoid hardcoding

optimizer_kwargs = {
    "lr": lr,
    "betas": (0.9, 0.95),
    "weight_decay": 0.1,
    "fused": fused,
    "foreach": not fused,
}

in the current `build_optimizers` implementation.

Additionally, it could also be useful to have a registry for `LRSchedulerLambda` functions. It's not strictly necessary, since users can customize `build_lr_schedulers`, but it could simplify things a bit for them.
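
As an illustration of the first point, a registry entry could also carry a default-kwargs factory so that `build_optimizers` no longer hardcodes these values. This is only a sketch, and the `OptimizerEntry` shape is hypothetical:

from typing import Any, Callable, Dict, NamedTuple, Type

from torch.optim import AdamW, Optimizer


class OptimizerEntry(NamedTuple):
    optimizer_cls: Type[Optimizer]
    # Factory building the kwargs dict from config values instead of hardcoding it.
    default_kwargs: Callable[[float, bool], Dict[str, Any]]


_OPTIMIZER_REGISTRY: Dict[str, OptimizerEntry] = {
    "AdamW": OptimizerEntry(
        AdamW,
        lambda lr, fused: {
            "lr": lr,
            "betas": (0.9, 0.95),
            "weight_decay": 0.1,
            "fused": fused,
            "foreach": not fused,
        },
    ),
}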

fegin (Contributor) commented Feb 21, 2025

If users would like to extend to more cases beyond the current optimizers, they should just use TrainSpec and replace the optimizer implementation with their own. IMHO, that's the granularity TorchTitan provides for users to customize.
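
For readers unfamiliar with that mechanism, the idea is roughly the following sketch; the import path and the `build_optimizers_fn` field name are assumptions and may not match the current `TrainSpec` definition:

from dataclasses import replace

# NOTE: module path and TrainSpec field names may differ across TorchTitan versions.
from torchtitan.train_spec import get_train_spec, register_train_spec


def build_my_optimizers(model_parts, job_config):
    # User-defined builder returning a custom optimizers container.
    ...


# Start from an existing spec and swap in the custom optimizer builder.
llama3_spec = get_train_spec("llama3")
register_train_spec(
    replace(llama3_spec, name="llama3_custom_opt", build_optimizers_fn=build_my_optimizers)
)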

balancap (Contributor, Author)

Agreed, that's what we want to do: have our own TrainSpec. But we would also like to be able to reuse the `OptimizersContainer` class, which is generic and not tied to Adam optimizers.

One simpler alternative, keeping TorchTitan as lean as it is, would be to pass `optimizer_cls` to `OptimizersContainer` instead of the optimizer name, i.e.:

from typing import Any, Dict, List, Type

import torch.nn as nn
from torch.optim import Optimizer

class OptimizersContainer(Optimizer):
    """Container holding one optimizer instance per model part."""

    optimizers: List[Optimizer]
    model_parts: List[nn.Module]

    def __init__(
        self, model_parts: List[nn.Module], optimizer_cls: Type[Optimizer], optimizer_kwargs: Dict[str, Any]
    ) -> None:
        ...
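
With that signature, a caller could instantiate the container with any `torch.optim` class, for example (usage sketch; `model_parts` is assumed to be the trainer's list of model chunks):

from torch.optim import AdamW

optimizers = OptimizersContainer(
    model_parts=model_parts,
    optimizer_cls=AdamW,
    optimizer_kwargs={"lr": 3e-4, "betas": (0.9, 0.95), "weight_decay": 0.1},
)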

How does that sound to you, @fegin @tianyu-l?

Independent of my particular need, I believe it is also a better design: a container should be agnostic to the element type (like `std::vector`, ...), and `OptimizersContainer` could then be more strongly typed as `Generic[T]`, reflecting that the class expects all optimizers to be of the same type.
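
To illustrate the typing point, a minimal sketch of a generically typed container; it drops the `Optimizer` base class for brevity and only shows the construction logic:

from typing import Any, Dict, Generic, List, Type, TypeVar

import torch.nn as nn
from torch.optim import Optimizer

TOptimizer = TypeVar("TOptimizer", bound=Optimizer)


class OptimizersContainer(Generic[TOptimizer]):
    """Holds optimizers of a single concrete type, one per model part."""

    def __init__(
        self,
        model_parts: List[nn.Module],
        optimizer_cls: Type[TOptimizer],
        optimizer_kwargs: Dict[str, Any],
    ) -> None:
        self.model_parts = model_parts
        self.optimizers: List[TOptimizer] = [
            optimizer_cls(m.parameters(), **optimizer_kwargs) for m in model_parts
        ]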

fegin (Contributor) commented Feb 21, 2025

Yes, using `cls` is a good idea, just like for the model. I didn't change it because it was originally coded that way and I didn't want to change too much in one PR. I vote for changing from name to cls.
