The official PyTorch implementation of the paper "Fourier Head: Helping Large Language Models Learn Complex Probability Distributions". Please visit our webpage for more details.
The Fourier head is a neural network layer which learns a continuous probability density function using Fourier series, and returns a discrete approximation of it.
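To make the idea concrete, here is a minimal pure-Python sketch of one way to turn a truncated Fourier series into a discrete distribution. It assumes the density on [-1, 1] is written as the squared magnitude of a trigonometric polynomial, which is a common way to keep a Fourier series nonnegative. The function name `fourier_pmf` and the coefficient values are hypothetical. In a full layer the coefficients would be predicted from the input by a learned map (also an assumption here), and this is not the code in `fourier_head.py`.

```python
import cmath
import math


def fourier_pmf(coeffs, num_bins):
    """Discretize a Fourier-series PDF on [-1, 1] into a PMF over num_bins bins.

    Illustrative assumption (not the repo's code): the density is
        f(x) = |sum_k a_k exp(i*pi*k*x)|^2 / (2 * sum_k |a_k|^2),
    which is nonnegative everywhere and integrates to 1 over [-1, 1].
    """
    norm = 2.0 * sum(abs(a) ** 2 for a in coeffs)

    def pdf(x):
        # Evaluate the trigonometric polynomial, then take its squared magnitude.
        s = sum(a * cmath.exp(1j * math.pi * k * x) for k, a in enumerate(coeffs))
        return abs(s) ** 2 / norm

    # Evaluate the PDF at the center of each equal-width bin.
    width = 2.0 / num_bins
    centers = [-1.0 + width * (i + 0.5) for i in range(num_bins)]
    densities = [pdf(c) for c in centers]
    # Renormalize so the discrete approximation sums exactly to 1.
    total = sum(densities)
    return [d / total for d in densities]


# Hypothetical coefficients for frequencies k = 0, 1, 2.
pmf = fourier_pmf([1.0, 0.5 - 0.25j, 0.1j], num_bins=18)
print(len(pmf), round(sum(pmf), 6))
```

Because the PMF comes from sampling one smooth function, neighboring bins receive similar probability mass. The number of coefficients controls how sharp that function can be.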
Large language models are often adapted to model non-linguistic tokens. If these tokens have an underlying continuous structure (e.g., in time series forecasting or RL control tasks), then replacing the linear classification head with the Fourier head can significantly boost downstream performance.
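"Continuous structure" here means that neighboring token IDs stand for neighboring numeric values. The following is a minimal sketch of a uniform quantizer that produces such tokens. It assumes values are rescaled into [-1, 1] and split into equal-width bins, and the helper names are hypothetical rather than taken from this repo. A linear classification head has no built-in notion that tokens 8 and 9 below are numerically close, whereas a smooth density over [-1, 1] naturally assigns them similar probabilities.

```python
def bin_centers(num_bins):
    """Centers of num_bins equal-width bins partitioning [-1, 1]."""
    width = 2.0 / num_bins
    return [-1.0 + width * (i + 0.5) for i in range(num_bins)]


def value_to_token(value, num_bins):
    """Map a continuous value to the index of the bin containing it."""
    clipped = max(-1.0, min(1.0, value))
    index = int((clipped + 1.0) / 2.0 * num_bins)
    # The right endpoint 1.0 would land in bin num_bins; fold it into the last bin.
    return min(index, num_bins - 1)


def token_to_value(token, num_bins):
    """Decode a token back to the center of its bin."""
    return bin_centers(num_bins)[token]


num_bins = 18
tokens = [value_to_token(v, num_bins) for v in (-1.0, -0.02, 0.02, 1.0)]
print(tokens)
```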
Tip: If you're looking for an example to cannibalize for your own experiments, we recommend starting with the synthetic toy example, since its implementation is minimalist and self-contained.
In fourier_head.py we implement the Fourier head. This is a self-contained file, so you can copy it into your codebase and import it like you would any neural network layer. Example usage:
```python
import torch
from fourier_head import Fourier_Head


class MyNetwork(torch.nn.Module):
    def __init__(self, input_dim=16, output_dim=18, num_frequencies=42, device="cpu"):
        super(MyNetwork, self).__init__()
        # Initialize Fourier head with specified input and output dimensions
        self.classification_head = Fourier_Head(
            input_dim, output_dim, num_frequencies, device=device
        )

    def forward(self, x):
        # Fourier head maps (batch_size, input_dim) --> (batch_size, output_dim)
        return self.classification_head(x)


device = "cuda" if torch.cuda.is_available() else "cpu"

tensor_input = torch.randn(32, 16).to(device)  # Batch size 32, input dimension 16
model = MyNetwork(device=device).to(device)
tensor_output = model(tensor_input)
print(tensor_output.shape)  # Expected shape: (32, 18)
```
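Each row of `tensor_output` holds one score per bin, so a downstream consumer usually needs to turn a row into probabilities or a point estimate. Below is a minimal pure-Python sketch for a single row. It assumes the scores are log-probabilities over equal-width bins on [-1, 1]. The bin layout and the helper names `log_pmf_to_pmf` and `expected_value` are assumptions, not part of `Fourier_Head`'s API. On the actual tensor, the softmax step corresponds to `torch.softmax(tensor_output, dim=-1)`.

```python
import math


def log_pmf_to_pmf(log_probs):
    """Apply a numerically stable softmax to one row of head outputs."""
    m = max(log_probs)
    exps = [math.exp(v - m) for v in log_probs]
    total = sum(exps)
    return [e / total for e in exps]


def expected_value(pmf):
    """Mean of the quantized distribution, using bin centers on [-1, 1] (assumed layout)."""
    n = len(pmf)
    width = 2.0 / n
    return sum(p * (-1.0 + width * (i + 0.5)) for i, p in enumerate(pmf))


# Hypothetical log-probabilities for a single example with 4 bins.
row = [math.log(0.1), math.log(0.2), math.log(0.3), math.log(0.4)]
pmf = log_pmf_to_pmf(row)
print([round(p, 3) for p in pmf], round(expected_value(pmf), 3))
```

Subtracting the row maximum before exponentiating leaves the softmax unchanged but avoids overflow for large scores.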
Note:
- In the paper, the main use case for the Fourier head is as a drop-in replacement for the linear classification head. Accordingly, our implementation of the Fourier head outputs the inverse softmax (i.e., the log) of the categorical distribution obtained by quantizing the learned continuous PDF. In other words, to obtain the continuous-looking Fourier head PMFs shown in the paper, you need to apply `softmax` to the output of our `Fourier_Head`.
- Some older versions of PyTorch can't execute `torch.nn.functional.conv1d` on complex-valued tensors. We provide an implementation that works in this case inside `imitation-learning/mingpt/_fourier_head.py`.
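One general way to sidestep complex-valued convolutions is to split each complex sequence into its real and imaginary parts. Every complex product then becomes a combination of purely real cross-correlations, which is the operation `torch.nn.functional.conv1d` computes. Below is a minimal pure-Python sketch of that identity. It assumes the convolution is used to compute an autocorrelation r_k = Σ_j a_{j+k} conj(a_j) of complex Fourier coefficients. This is an illustration only, not the contents of `_fourier_head.py`, and the function names are hypothetical.

```python
def real_xcorr(p, q, k):
    """sum_j p[j + k] * q[j] for real sequences p, q of equal length and lag k >= 0."""
    return sum(p[j + k] * q[j] for j in range(len(q) - k))


def autocorrelation_real_parts(re, im):
    """Autocorrelation r_k = sum_j a[j+k] * conj(a[j]) of a = re + i*im,
    computed with real-valued arithmetic only; returns (real, imag) lists."""
    n = len(re)
    # (u1 + i v1)(u2 - i v2) = (u1 u2 + v1 v2) + i (v1 u2 - u1 v2)
    real = [real_xcorr(re, re, k) + real_xcorr(im, im, k) for k in range(n)]
    imag = [real_xcorr(im, re, k) - real_xcorr(re, im, k) for k in range(n)]
    return real, imag


# Compare against the direct complex computation on a small example.
a = [1.0 + 0.5j, -0.3 + 0.2j, 0.25 - 0.1j]
real, imag = autocorrelation_real_parts([z.real for z in a], [z.imag for z in a])
direct = [sum(a[j + k] * a[j].conjugate() for j in range(len(a) - k)) for k in range(len(a))]
matches = all(abs(complex(r, i) - d) < 1e-12 for r, i, d in zip(real, imag, direct))
print(matches)
```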
Our paper contains four sets of experiments with the Fourier head. Look inside the corresponding subdirectory for the code to recreate the results and figures from that section of the paper.
We thank the authors of the works we build upon:
If you find this code useful in your research, please cite:
```bibtex
@misc{gillman2024fourierheadhelpinglarge,
  title={Fourier Head: Helping Large Language Models Learn Complex Probability Distributions},
  author={Nate Gillman and Daksh Aggarwal and Michael Freeman and Saurabh Singh and Chen Sun},
  year={2024},
  eprint={2410.22269},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2410.22269},
}
```