
Allow customized parameter grouping for automatic optimizer configuration #20743

@SeanZhang99

Description & Motivation

I rewrote the CLI's _add_configure_optimizers_to_model: instead of passing self.model.parameters() directly, the optimizer is now instantiated with the result of a new method, _get_parameters, whose default behavior is simply to return self.model.parameters(). With this change I can subclass LightningCLI and override _get_parameters, so that the optimizer receives parameter names together with the parameter values.

This lets me write an optimizer wrapper that applies weight decay to different groups of parameters, without manually writing a configure_optimizers method in my LightningModule and without touching the CLI config files or anything else.

What I'm currently doing is:

  1. Write a new optimizer interface (a standalone usage sketch follows this list):
from collections.abc import Iterator

from torch.nn import Parameter
from torch.optim import Optimizer


class AutoDetachBiasDecay(Optimizer):
    """Factory that splits named parameters into decay / no-decay groups
    and instantiates the wrapped optimizer with those groups."""

    def __new__(
        cls,
        /,
        named_params: Iterator[tuple[str, Parameter]],
        optimizer_class: type[Optimizer],
        lr: float = 1.0e-3,
        weight_decay: float = 1.0e-2,
        **kwargs,
    ):
        decay, no_decay = [], []
        for name, param in named_params:
            if not param.requires_grad:
                continue
            # Biases and normalization-layer parameters go into the no-decay group.
            if "bias" in name or "Norm" in name:
                no_decay.append(param)
            else:
                decay.append(param)

        grouped_params = [
            {"params": decay, "weight_decay": weight_decay, "lr": lr},
            {"params": no_decay, "weight_decay": 0.0, "lr": 5 * lr},
        ]

        # Returning an instance of a different class means __init__ of this class is skipped.
        return optimizer_class(grouped_params, **kwargs)
  2. Change the cli.py file:
    Before:
        optimizer = instantiate_class(self.model.parameters(), optimizer_init)
        lr_scheduler = instantiate_class(optimizer, lr_scheduler_init) if lr_scheduler_init else None
        fn = partial(self.configure_optimizers, optimizer=optimizer, lr_scheduler=lr_scheduler)
        update_wrapper(fn, self.configure_optimizers)  # necessary for `is_overridden`
        # override the existing method
        self.model.configure_optimizers = MethodType(fn, self.model)

    After:

        optimizer = instantiate_class(self._get_parameters(), optimizer_init)
        lr_scheduler = instantiate_class(optimizer, lr_scheduler_init) if lr_scheduler_init else None
        fn = partial(self.configure_optimizers, optimizer=optimizer, lr_scheduler=lr_scheduler)
        update_wrapper(fn, self.configure_optimizers)  # necessary for `is_overridden`
        # override the existing method
        self.model.configure_optimizers = MethodType(fn, self.model)

    def _get_parameters(self) -> Iterator[Parameter]:
        # Default behavior is unchanged: return the plain parameter iterator.
        return self.model.parameters()
  3. Write a subclass of the CLI:
class NamedParamsCLI(LightningCLI):
    model: torch.nn.Module

    def _get_parameters(self):
        # Override the hook from step 2 so the optimizer receives (name, parameter) pairs.
        return self.model.named_parameters()
  4. Replace the optimizer config YAML from
class_path: AdamW
init_args: something

to

class_path: path.to.AutoDetachBiasDecay
init_args:
  optimizer_class: torch.optim.AdamW
  lr: 1.e-4
  weight_decay: 1.e-2
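
As mentioned in step 1, here is a standalone sketch of how the grouping behaves outside the CLI, with AutoDetachBiasDecay from step 1 in scope. The toy module is made up for illustration; the attribute is deliberately named LayerNorm so that "Norm" appears in the parameter names and matches the check in the wrapper:

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 16)
        self.LayerNorm = nn.LayerNorm(16)  # attribute named so "Norm" shows up in parameter names

model = Toy()
opt = AutoDetachBiasDecay(
    named_params=model.named_parameters(),
    optimizer_class=torch.optim.AdamW,
    lr=1.0e-4,
    weight_decay=1.0e-2,
)
# Expected: one group (weights) with weight_decay=1e-2, one group (bias + norm) with weight_decay=0.0 and 5x lr.
print([(len(g["params"]), g["weight_decay"], g["lr"]) for g in opt.param_groups])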

I'm not entirely sure whether this proposal conflicts with what configure_optimizers is supposed to do. But from my point of view, this change makes it easy to put different parameters into separate groups and apply different optimization settings to each group, without writing a lot of code in configure_optimizers and without adding extra constructor arguments to my LightningModule.

I admit there are some problems. One is that different optimizers want different initialization arguments, and jsonargparse validates argument names, so with my simple wrapper users cannot pass optimizer-specific arguments through **kwargs. Since I don't currently need any extra arguments, I haven't spent much time on a solution, and I'd be happy to hear feedback or suggestions on this point.
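
One possible direction for the **kwargs limitation, purely as a sketch and not part of the change above (the class name and the optimizer_kwargs field are hypothetical), is to accept the wrapped optimizer's extra arguments as an explicit dict, so jsonargparse only has to validate a plain dict field rather than resolve **kwargs against the wrapped optimizer's signature:

from collections.abc import Iterator

from torch.nn import Parameter
from torch.optim import Optimizer


class AutoDetachBiasDecayExplicitKwargs(Optimizer):
    """Hypothetical variant of step 1: wrapped-optimizer arguments are passed as a dict."""

    def __new__(
        cls,
        /,
        named_params: Iterator[tuple[str, Parameter]],
        optimizer_class: type[Optimizer],
        lr: float = 1.0e-3,
        weight_decay: float = 1.0e-2,
        optimizer_kwargs: dict | None = None,  # e.g. {"betas": [0.9, 0.95]} for AdamW
    ):
        decay, no_decay = [], []
        for name, param in named_params:
            if not param.requires_grad:
                continue
            if "bias" in name or "Norm" in name:
                no_decay.append(param)
            else:
                decay.append(param)

        grouped_params = [
            {"params": decay, "weight_decay": weight_decay, "lr": lr},
            {"params": no_decay, "weight_decay": 0.0, "lr": 5 * lr},
        ]

        return optimizer_class(grouped_params, **(optimizer_kwargs or {}))

In the YAML, the extra arguments would then sit under init_args as optimizer_kwargs: {betas: [0.9, 0.95]} instead of being free-form keys.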

Pitch

After this change to cli.py, users can write a very simple .py file defining their own optimizer interface that splits parameters into groups and applies different optimization settings to each group, e.g., in my case, different weight_decay and learning rate values.
The only things users need to do are start training from NamedParamsCLI instead of LightningCLI (needing no subclass at all would be even better), copy the optimizer-interface code to their machine, point the optimizer's class_path at the new class, and specify in the config file which underlying optimizer they want to use.
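
For concreteness, a minimal entry point under this proposal could look like the following (this assumes the patched cli.py from step 2; the config file name is arbitrary):

# main.py -- model, data, trainer, and optimizer all come from the YAML config.
from lightning.pytorch.cli import LightningCLI


class NamedParamsCLI(LightningCLI):
    def _get_parameters(self):
        # Hand (name, parameter) pairs to the optimizer interface from step 1.
        return self.model.named_parameters()


if __name__ == "__main__":
    # e.g.: python main.py fit --config config.yaml
    # with the optimizer section of config.yaml pointing at path.to.AutoDetachBiasDecay as in step 4.
    NamedParamsCLI()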

Alternatives

No response

Additional context

No response

cc @lantiga @Borda
