
Specifying normalization layers. #31

Open
jakob-schloer opened this issue Nov 22, 2024 · 7 comments · May be fixed by ecmwf/anemoi-models#95
Labels: enhancement, models

Comments

@jakob-schloer
Collaborator

Is your feature request related to a problem? Please describe.

Currently, the processor is implemented with nn.LayerNorm. I would like to use other normalization layers (https://pytorch.org/docs/stable/nn.html#normalization-layers), including custom normalization layers.

Describe the solution you'd like

I would like to specify the normalization layer of the processor in the config, e.g. transformer.yaml:

layer_norm:  # This needs to be a partial instantiation since it is used in multiple places
  _target_: torch.nn.LayerNorm 
  _partial_: True
  normalized_shape: ${model.num_channels}

processor:
  _target_: anemoi.models.processor.TransformerProcessor
  _convert_: all
  activation: ${model.activation}
  num_layers: 16
  num_chunks: 2
  mlp_hidden_ratio: 4 # GraphTransformer or Transformer only
  num_heads: 16 # GraphTransformer or Transformer only
  window_size: 512
  dropout_p: 0.0 # GraphTransformer
  layer_norm: ${model.layer_norm} # (Optional) Default nn.LayerNorm
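
For reference, with _partial_: True Hydra resolves the layer_norm node into a functools.partial, so the processor can call it as many times as it needs norm layers. A minimal sketch of that behaviour (the hard-coded 1024 just stands in for ${model.num_channels} and is purely illustrative):

from functools import partial

import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Illustrative config mirroring the transformer.yaml snippet above
cfg = OmegaConf.create(
    {
        "layer_norm": {
            "_target_": "torch.nn.LayerNorm",
            "_partial_": True,
            "normalized_shape": 1024,
        }
    }
)

# With _partial_: True, instantiate() returns a functools.partial ...
layer_norm_factory = instantiate(cfg.layer_norm)
assert isinstance(layer_norm_factory, partial)

# ... which can be called repeatedly to build independent norm layers
norm1 = layer_norm_factory()
norm2 = layer_norm_factory()
print(norm1(torch.randn(2, 1024)).shape)  # torch.Size([2, 1024])

Passing the partial down means the processor never needs to know which norm class it is constructing.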

Describe alternatives you've considered

No response

Additional context

No response

Organisation

No response

@jakob-schloer jakob-schloer added the enhancement New feature or request label Nov 22, 2024
@jakob-schloer jakob-schloer self-assigned this Nov 22, 2024
@clessig

clessig commented Nov 22, 2024

Cathal has already done experiments with RMSNorm (from TransformerEngine, I think). It might have been hard-coded, but it would be good to coordinate.

CC: @cathalobrien

@cathalobrien
Contributor

cathalobrien commented Nov 22, 2024

Hey, yeah, I have this PR: ecmwf/anemoi-models#35. I put it on ice a while back because I thought it would cause problems in inference if we have arbitrary functions in the checkpoint file.

But now that the checkpoints are weights-only, it should be fine. I can refresh it next week.

@jakob-schloer
Collaborator Author

jakob-schloer commented Nov 22, 2024

I see; this is related, but I was thinking of something more general. I would like to be able to write custom normalization layers, e.g.

import logging
from functools import partial
from typing import Callable, Optional

from torch import Tensor, nn
from torch.distributed.distributed_c10d import ProcessGroup

# BaseBlock, the attention module and the MLP come from anemoi-models (imports omitted here)

LOGGER = logging.getLogger(__name__)


class TransformerProcessorBlock(BaseBlock):
    """Transformer block with MultiHeadSelfAttention and MLPs."""

    def __init__(
        self,
        num_channels: int,
        hidden_dim: int,
        num_heads: int,
        activation: str,
        window_size: int,
        dropout_p: float = 0.0,
        layer_norm: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()

        try:
            act_func = getattr(nn, activation)
        except AttributeError as ae:
            LOGGER.error("Activation function %s not supported", activation)
            raise RuntimeError from ae

        # Instantiate the normalization layers from the Hydra partial,
        # defaulting to nn.LayerNorm when none is configured
        if layer_norm is None:
            layer_norm = partial(nn.LayerNorm, normalized_shape=num_channels)
        self.layer_norm1 = layer_norm()
        self.layer_norm2 = layer_norm()
        ...

    def forward(
        self,
        x: Tensor,
        shapes: list,
        batch_size: int,
        model_comm_group: Optional[ProcessGroup] = None,
        **layer_kwargs,
    ) -> Tensor:
        # Needs to be out of place for gradient propagation
        x = x + self.attention(
            self.layer_norm1(x, **layer_kwargs), shapes, batch_size, model_comm_group=model_comm_group
        )
        x = x + self.mlp(self.layer_norm2(x, **layer_kwargs))
        return x
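
A hand-written norm would then only need to be constructible by the partial and to accept the hidden tensor in its forward, like nn.LayerNorm does. A rough sketch (the class name is made up; the config's _target_ would simply point at wherever it lives, e.g. my_project.norms.SimpleRMSNorm):

import torch
from torch import Tensor, nn


class SimpleRMSNorm(nn.Module):
    """Hypothetical custom normalization layer usable as `layer_norm`."""

    def __init__(self, normalized_shape: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(normalized_shape))

    def forward(self, x: Tensor) -> Tensor:
        # Scale by the reciprocal root-mean-square over the channel dimension
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight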

Do you think this could be combined with your PR @cathalobrien?

@cathalobrien
Contributor

Ah I see, yeah I think this should work.

I already have this implemented:

    LayerNorm:
      #_target_: "torch.nn.LayerNorm" #the default PyTorch implementation
      _target_: "liger_kernel.transformers.rms_norm.LigerRMSNorm" # my desired layernorm
      _partial_: True

I haven't tried with a handwritten layernorm, but I assume that as long as the import in _target_ points to the right place, it should be fine.

I like your idea of passing **layer_kwargs directly to the instantiated layer_norm; I was wondering how to handle arbitrary parameters at the time.

@jakob-schloer
Collaborator Author

> I like your idea of passing **layer_kwargs directly to the instantiated layer_norm; I was wondering how to handle arbitrary parameters at the time.

On second thought, I believe it should be just **kwargs, in case someone wants to do something else in the forward function in the future.
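
For example (purely hypothetical, nothing like this exists in anemoi-models), a conditional norm whose forward takes an extra tensor would simply receive it through those generic kwargs:

import torch
from torch import Tensor, nn


class ConditionalLayerNorm(nn.Module):
    """Hypothetical norm with an extra forward argument, to illustrate why
    the block should forward generic **kwargs to the norm layers."""

    def __init__(self, normalized_shape: int) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(normalized_shape, elementwise_affine=False)
        self.scale = nn.Linear(normalized_shape, normalized_shape)

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
        # Condition-dependent scaling on top of a plain, affine-free LayerNorm
        return self.norm(x) * (1.0 + self.scale(cond))


x, cond = torch.randn(2, 8, 64), torch.randn(2, 8, 64)
out = ConditionalLayerNorm(64)(x, cond=cond)  # cond would arrive via **kwargs in the block
print(out.shape)  # torch.Size([2, 8, 64])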

@clessig

clessig commented Nov 22, 2024

Yes, e.g. cross attention or some fancy bias terms for the attention could also be passed.

@jakob-schloer
Collaborator Author

I'm closing this, since PR ecmwf/anemoi-models#35 already covers this.

@jakob-schloer jakob-schloer reopened this Nov 22, 2024
@jakob-schloer jakob-schloer closed this as not planned Nov 22, 2024
@jakob-schloer jakob-schloer reopened this Dec 4, 2024
@jakob-schloer jakob-schloer linked a pull request Dec 5, 2024 that will close this issue
@JesperDramsch JesperDramsch transferred this issue from ecmwf/anemoi-models Dec 19, 2024