This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Flexible normalization layers #95

Open
wants to merge 12 commits into
base: develop

Conversation

@jakob-schloer (Collaborator) commented Dec 5, 2024

This PR combines issue ecmwf/anemoi-core#31 with PR #35 by @cathalobrien.

Describe your changes

This PR makes it possible to switch the implementation of the Linear and LayerNorm kernels via the config.

At the moment we use the torch.nn implementation for many layers in the Anemoi models, e.g. torch.nn.LayerNorm and torch.nn.Linear. This has the advantage of being available out of the box with torch and portable to many systems (CPU, AMD and NVIDIA GPUs). However, other layer implementations might be more efficient for certain hardware, or we might want to use a custom layer.

This PR adds the following block to config/model/.yaml:

  layer_kernels:
    LayerNorm:
      #_target_: "transformer_engine.pytorch.LayerNorm"
      _target_: "liger_kernel.transformers.rms_norm.LigerRMSNorm"
      #_target_: "torch.nn.LayerNorm" #the default PyTorch implementation
      _partial_: True
      #Any arguments to your chosen function go here e.g.
      #bias: False
    Linear:
      #_target_: "transformer_engine.pytorch.Linear"
      _target_: "torch.nn.Linear"
      _partial_: True

You can pass any parameters to your new kernels in the config file, after "_partial_: True". Hydra tries to load the desired kernel in "models/encoder_processor_decoder.py". If the desired library isn't available, the code currently falls back to torch.nn.
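For illustration, here is a minimal sketch of how the layer_kernels config block could be turned into a dict of partially-instantiated classes with Hydra. The helper name load_layer_kernels and the fallback logic are assumptions for this example, not necessarily the exact code added in this PR.

import torch
from hydra.utils import instantiate
from omegaconf import DictConfig


def load_layer_kernels(kernels_config: DictConfig) -> dict:
    """Hypothetical helper: returns e.g. {"LayerNorm": <partial>, "Linear": <partial>}."""
    layer_kernels = {}
    for name, kernel_config in kernels_config.items():
        try:
            # With _partial_: True, instantiate() returns a functools.partial of the
            # target class, so it can be called later with layer-specific arguments
            # such as normalized_shape.
            layer_kernels[name] = instantiate(kernel_config)
        except Exception:
            # If the requested library is not installed, fall back to the default
            # torch.nn implementation of the same name.
            layer_kernels[name] = getattr(torch.nn, name)
    return layer_kernels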

The calls to torch.nn are then replaced with

- self.layer_norm = nn.LayerNorm(normalized_shape=num_channels)
+ self.layer_norm = layer_kernels['LayerNorm'](normalized_shape=num_channels)
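As a usage illustration (the module and parameter names below are hypothetical, not taken from this PR), a building block would then construct its layers from the kernels dict instead of hard-coded torch.nn classes:

import torch
from torch import nn


class ExampleBlock(nn.Module):
    """Hypothetical block showing how the layer_kernels dict could be consumed."""

    def __init__(self, num_channels: int, layer_kernels: dict) -> None:
        super().__init__()
        # Both layers come from the config-selected implementations.
        self.layer_norm = layer_kernels["LayerNorm"](normalized_shape=num_channels)
        self.linear = layer_kernels["Linear"](num_channels, num_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(self.layer_norm(x))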

In the future, this syntax could be extended to replace other layers if required. More specifically, a follow-up PR could add a conditional layer norm.

Type of change

  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Checklist before requesting a review

  • I have performed a self-review of my code
  • My code follows the style guidelines of this project
  • I have commented my code, particularly in hard-to-understand areas
  • I have updated the documentation and docstrings to reflect the changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have ensured that the code is still pip-installable after the changes and runs
  • I have not introduced new dependencies in the inference portion of the model
  • I have run this on a single GPU
  • I have run this on multi-GPU or multi-node
  • I have run the Benchmark Profiler against the old version of the code
  • I've tested the changes with the Transformer, GraphTransformer and GNN

@FussyDuck commented Dec 5, 2024

CLA assistant check
All committers have signed the CLA.

Development

Successfully merging this pull request may close these issues.

Specifying normalization layers.
3 participants