Merged
71 commits
5a1d899
demo
loubbrad Dec 31, 2024
cfae8ee
demo fix
loubbrad Jan 4, 2025
ea68e75
mess it all up agian
loubbrad Jan 4, 2025
d6a865b
demo finished
loubbrad Jan 5, 2025
977f54b
undo mistake
loubbrad Jan 6, 2025
877d6e0
update demo
loubbrad Jan 7, 2025
9a5c011
add prefill compile
loubbrad Jan 7, 2025
3dd44bf
add class finetuning
loubbrad Feb 20, 2025
7d99bad
add seq sep option to PretrainingDataset
loubbrad Feb 21, 2025
014f0b9
change from genre to composer
loubbrad Feb 21, 2025
6b53ba4
update emb eval scripts
loubbrad Feb 24, 2025
e6c1e2a
add explore script
loubbrad Feb 24, 2025
1a94cb4
add contrastive ft
loubbrad Feb 24, 2025
3dc6953
add missing changes
loubbrad Feb 24, 2025
32b832d
add loop
loubbrad Feb 24, 2025
68fe378
fix arg bug
loubbrad Feb 24, 2025
dae0b03
update eval
loubbrad Feb 27, 2025
06ef338
fix eval hang
loubbrad Feb 28, 2025
19f9795
add data aug
loubbrad Mar 5, 2025
2c086e1
fix data aug
loubbrad Mar 5, 2025
b524ab3
formalize eval
loubbrad Mar 8, 2025
87c82ac
eval scripts
loubbrad Mar 10, 2025
58d439f
fix range bug
loubbrad Mar 10, 2025
b238edc
add m3 only embeddings
loubbrad Mar 11, 2025
b485c04
update script for m3 embeddings
loubbrad Mar 11, 2025
bda70ac
update for pianist eval
loubbrad Mar 11, 2025
3fa2b29
add pianist8 dataset script
loubbrad Mar 11, 2025
4a7427e
adjust per file emb logic and update scripts
loubbrad Mar 12, 2025
c8cc7b8
update datasets/training/model scripts to support embedding conditioning
loubbrad Mar 14, 2025
626b5b4
add ft-dataset script
loubbrad Mar 14, 2025
93bcb22
change use embeddings train logic
loubbrad Mar 15, 2025
72160f7
fix model ft loading
loubbrad Mar 15, 2025
50b27d3
fix arg
loubbrad Mar 15, 2025
f9e15de
fix ddp model error
loubbrad Mar 15, 2025
43689b4
add pca
loubbrad Mar 18, 2025
46b9daf
keshav
loubbrad Mar 20, 2025
9aeafd2
keshav add args
loubbrad Mar 20, 2025
9ef4db9
fix keshav
loubbrad Mar 21, 2025
67d86bb
update sampling and demo
loubbrad May 22, 2025
654d9de
add looping and ending to demo
loubbrad May 23, 2025
77d27b5
push mlx imp for test
loubbrad May 25, 2025
445d484
fix sample script
loubbrad May 26, 2025
dc9fdcb
add continuous prefill and speculative duration calculation
loubbrad May 27, 2025
4f54e41
add off-msg streaming and fix timing alignment
loubbrad May 28, 2025
b16394c
fix early-off logic with dumb hack
loubbrad May 28, 2025
fc43f70
fix stream_midi logic
loubbrad May 29, 2025
ba835ff
port demo to mlx
loubbrad May 29, 2025
571c0a6
add script
loubbrad May 29, 2025
a0944a4
update mlx demo
loubbrad Jun 2, 2025
6e8aeab
partial tree refactor for release
loubbrad Jun 3, 2025
a37ba9c
add resid dropout to model
loubbrad Jun 3, 2025
91802df
import fix
loubbrad Jun 3, 2025
f689029
inference tree skeleton
loubbrad Jun 3, 2025
3491ae4
fix tree
loubbrad Jun 3, 2025
3614a8b
rm scripts
loubbrad Jun 3, 2025
97e2a5c
refactor entrypoint for generate
loubbrad Jun 3, 2025
1daac44
cfg conditioned generation refactored for torch_cuda
loubbrad Jun 3, 2025
479edc1
add mlx backend for conditioned generation
loubbrad Jun 4, 2025
9ba3a00
fix mlx backend for conditioned gen
loubbrad Jun 4, 2025
2252039
update cli flags to standard unix format
loubbrad Jun 4, 2025
5c0a435
migrate to pyproject.toml
loubbrad Jun 4, 2025
0878402
add toml
loubbrad Jun 4, 2025
f86435f
remove old plan
loubbrad Jun 4, 2025
d5f46b9
add README draft
loubbrad Jun 9, 2025
23343cc
update README
loubbrad Jun 9, 2025
e5b8777
rmv test_dataset
loubbrad Jun 10, 2025
38d0ff2
update README
loubbrad Jun 10, 2025
442cc7a
demo adjustments
loubbrad Jun 16, 2025
893495d
add input delay correction
loubbrad Jun 16, 2025
6832c02
update README
loubbrad Jun 19, 2025
4474f6a
Merge branch 'dev' of github.com:loubbrad/aria into dev
loubbrad Jun 19, 2025
3 changes: 3 additions & 0 deletions .gitignore
@@ -167,3 +167,6 @@ fluidsynth/
tests/test_results
lightning_logs/
.vscode/
paper
hf
_scripts
9 changes: 0 additions & 9 deletions Makefile

This file was deleted.

143 changes: 105 additions & 38 deletions README.md
@@ -1,64 +1,131 @@
# gpt-aria
# Aria

[Discord](https://discord.com/invite/zBGx3azzUn)
This repository contains training, inference, and evaluation code for the paper [*Scaling Self-Supervised Representation Learning for Symbolic Piano Performance (ISMIR 2025)*](https://example.com/), as well as implementations of our real-time piano-continuation demo. *Aria* is a pretrained autoregressive generative model for symbolic music, based on the LLaMA 3.2 (1B) architecture and trained on ~60k hours of MIDI transcriptions of expressive solo-piano recordings. Alongside the base model, we are releasing a checkpoint finetuned to improve generative quality, as well as a checkpoint finetuned to produce general-purpose piano MIDI embeddings using a SimCSE-style contrastive training objective.

A repository containing resources for pre-training, fine-tuning, and evaluating musical (MIDI) transformer models.
📖 Read our [release blog post](https://example.com/) and [paper](https://example.com/)
🤗 Access our models via the [HuggingFace page](https://huggingface.co/loubb/aria-medium-base)
📊 Get access to our training dataset [Aria-MIDI](https://huggingface.co/datasets/loubb/aria-midi) and train your own models

***Note that this project is under active development***
## Installation

## Description
Installation requires Python 3.11+. To install the package and all dependencies with pip:

The main goal of the gpt-aria project is to create a suite of powerful pre-trained generative (symbolic) music models. We want to investigate how modern training (pre-training & fine-tuning) techniques can be used to improve the quality/usefulness of such models. Alongside this we are building various data (MIDI) preprocessing tools, allowing **you** to easily fine-tune our models on your own data.
```bash
git clone https://github.com/EleutherAI/aria
cd aria
pip install -e ".[all]"
```

If you are new to symbolic music models, a good place to start are the following projects/blogposts by Google Magenta and OpenAI:
## Quickstart

- [Music Transformer](https://magenta.tensorflow.org/music-transformer)
- [MuseNet](https://openai.com/research/musenet)
Download model weights for our pretrained model, as well as for checkpoints finetuned for piano continuation and for generating MIDI embeddings, from the official HuggingFace pages (a programmatic download sketch follows this list):

Long story short: Transformer + MIDI + GPUs = 🎵 x ∞
- `aria-medium-base` ([huggingface](https://huggingface.co/loubb/aria-medium-base), [direct-download](https://huggingface.co/loubb/aria-medium-base/resolve/main/model.safetensors?download=true))
- `aria-medium-gen` ([huggingface](https://huggingface.co/loubb/aria-medium-gen), [direct-download](https://huggingface.co/loubb/aria-medium-gen/resolve/main/model.safetensors?download=true))
- `aria-medium-embedding` ([huggingface](https://huggingface.co/loubb/aria-medium-embedding), [direct-download](https://huggingface.co/loubb/aria-medium-embedding/resolve/main/model.safetensors?download=true))
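
If you prefer to fetch checkpoints programmatically, here is a minimal sketch using `huggingface_hub` (an assumption on our part; any download method works, and the repo/file names follow the links above):

```python
# Sketch: fetch a checkpoint programmatically (assumes `pip install huggingface_hub`).
from huggingface_hub import hf_hub_download

# Download the generation-finetuned checkpoint; swap repo_id for the base or
# embedding model listed above.
checkpoint_path = hf_hub_download(
    repo_id="loubb/aria-medium-gen",
    filename="model.safetensors",
)
print(checkpoint_path)  # local cache path, e.g., to pass as --checkpoint_path
```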

## Installation
### Inference (Prompt Continuation)

Make sure you are using Python 3.10+. Note that I haven't explicitly developed this project for anything other than Linux. If you are using Windows, things might not work properly. In this case I suggest installing using WSL.
We provide optimized model implementations for PyTorch (CUDA) and MLX (Apple Silicon). You can generate continuations of a MIDI file using the CLI, e.g., using CUDA (Linux):

```
git clone https://github.com/eleutherai/aria
cd aria
pip install -e .
```

```bash
aria generate \
  --backend torch_cuda \
  --checkpoint_path <path-to-model-weights> \
  --prompt_midi_path <path-to-midi-file-to-continue> \
  --prompt_duration <length-in-seconds-for-prompt> \
  --variations <number-of-variations-to-generate> \
  --temp 0.98 \
  --min_p 0.035 \
  --length 2048 \
  --save_dir <dir-to-save-results>
```

## Inference
Since the model has not been post-trained with instruction tuning or RLHF (similar to pre-instruct GPT models), it is very sensitive to input quality and performs best when prompted with well-played music. To get prompt MIDI files, see the `example-prompts/` directory, explore the [Aria-MIDI](https://huggingface.co/datasets/loubb/aria-midi) dataset, or transcribe your own files using our [piano-transcription model](https://github.com/EleutherAI/aria-amt). For a full list of sampling options: `aria generate -h`. If you wish to do inference on the CPU, please see the platform-agnostic implementation on our HuggingFace page [link].

You can find preliminary checkpoints at the following locations
### Intended Use and Limitations

Finetuned piano-only checkpoints (improved robustness):
Aria performs best when **continuing existing piano MIDI files** rather than generating music from scratch. While multi-track tokenization and generation are supported, the model was trained primarily on **single-track expressive piano performances**, and we recommend using single-track inputs for optimal results.

```
large - https://storage.googleapis.com/aria-checkpoints/large-abs-inst.safetensors
```
Due to the high representation of popular classical works (e.g., Chopin) in the training data and the difficulty of complete deduplication, the model may **memorize or closely reproduce** such pieces. For more original outputs, we suggest prompting Aria with **lesser-known works or your own compositions**.

Pretrained checkpoints:
### Inference (MIDI embeddings)

You can generate embeddings from MIDI files using the `aria.embeddings` module. This functionality is primarily exposed via the `get_global_embedding_from_midi` function, for example:

```python
from aria.embeddings import get_global_embedding_from_midi
from aria.model import TransformerEMB, ModelConfig
from aria.config import load_model_config
from ariautils.tokenizer import AbsTokenizer
from safetensors.torch import load_file

# Load model
model_config = ModelConfig(**load_model_config(name="medium-emb"))
model_config.set_vocab_size(AbsTokenizer().vocab_size)
model = TransformerEMB(model_config)
state_dict = load_file(filename=CHECKPOINT_PATH)
model.load_state_dict(state_dict=state_dict, strict=True)

# Generate embedding
embedding = get_global_embedding_from_midi(
    model=model,
    midi_path=MIDI_PATH,
    device="cpu",
)
```
```
large - https://storage.googleapis.com/aria-checkpoints/large-abs-pt.bin
medium - https://storage.googleapis.com/aria-checkpoints/medium-abs-pt.bin
small - https://storage.googleapis.com/aria-checkpoints/small-abs-pt.bin
```

You can then sample using the cli:
Our embedding model was trained to capture composition-level and performance-level attributes, and therefore might not be appropriate for every use case.
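
As a rough illustration of a downstream use, here is a minimal sketch comparing two performances by cosine similarity (assuming the loading snippet above has been run and that embeddings are returned as 1-D torch tensors):

```python
# Sketch: compare two MIDI files via cosine similarity of their global embeddings.
# Assumes `model` is the TransformerEMB loaded as in the snippet above.
import torch.nn.functional as F

from aria.embeddings import get_global_embedding_from_midi

emb_a = get_global_embedding_from_midi(model=model, midi_path="a.mid", device="cpu")
emb_b = get_global_embedding_from_midi(model=model, midi_path="b.mid", device="cpu")

similarity = F.cosine_similarity(emb_a.unsqueeze(0), emb_b.unsqueeze(0)).item()
print(f"cosine similarity: {similarity:.3f}")
```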

## Real-time demo

In `demo/` we provide CUDA (Linux/PyTorch) and MLX (Apple Silicon) implementations of the real-time interactive piano-continuation demo showcased in our release blog post. For the demo we used an acoustic Yamaha Disklavier piano with simultaneous MIDI input and output ports connected via a standard MIDI interface.

❗**NOTE**: Responsiveness of the real-time demo is dependent on your system configuration, e.g., GPU FLOPS and memory bandwidth.

A MIDI input device is not strictly required to play around with the demo: by using the `--midi_path` and `--midi_through` arguments, you can mock real-time input by playing from a MIDI file. All that is required are MIDI drivers (e.g., CoreMIDI, ALSA) and a virtual software instrument (e.g., FluidSynth, Pianoteq) to render the output.
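
The port names passed to `--midi_through` and `--midi_out` depend on your OS and MIDI drivers. A minimal sketch for listing them, assuming the `mido` package (not necessarily what the demo scripts use internally):

```python
# Sketch: list available MIDI ports to find names for --midi_through / --midi_out.
# Assumes mido is installed with a working backend (e.g., python-rtmidi).
import mido

print("input ports: ", mido.get_input_names())
print("output ports:", mido.get_output_names())
```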

Example usage (MLX):

```bash
MIDI_PATH="example-prompts/pokey_jazz.mid"

python demo/demo_mlx.py \
  --checkpoint <checkpoint-path> \
  --midi_path ${MIDI_PATH} \
  --midi_through <port-to-stream-midi-file-through> \
  --midi_out <port-to-stream-generation-over> \
  --save_path <path-to-save-result> \
  --temp 0.98 \
  --min_p 0.035
```
```
aria sample \
  -m large \
  -c <path-to-checkpoint> \
  -p <path-to-midifile> \
  -var <num-variations-to-generate> \
  -trunc <seconds-in-to-truncate-prompt> \
  -l <number-of-tokens-to-generate> \
  -temp 0.95 \
  -e
```

## Evaluation

We provide the specific files/splits used for our Aria-MIDI-derived linear-probe and classification evaluations. These can be downloaded from HuggingFace ([direct-download](https://huggingface.co/loubb/aria-medium-base/resolve/main/eval-splits.tar.gz?download=true)). Class labels are provided in `metadata.json` with the following schema:

```json
{
  "<category>": {
    "<split-name>": {
      "<relative/path/to/file.mid>": "<metadata_value_for_that_category>",
      ...
    },
    ...
  },
  ...
}
```
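
As a rough sketch of how this metadata might be consumed (the category and split keys above are placeholders; inspect the file for the actual names):

```python
# Sketch: walk the eval-split metadata described by the schema above.
import json

with open("metadata.json") as f:
    metadata = json.load(f)

for category, splits in metadata.items():
    for split_name, files in splits.items():
        print(f"{category}/{split_name}: {len(files)} files")
        # `files` maps a relative MIDI path to its label for this category,
        # e.g., for pairing inputs with labels in a linear-probe evaluation.
```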

You can use `aria sample -h` to see a full list of options. If you wish to sample from a pretrained checkpoint, please use the `-pt` flag.
## License and Attribution

The Aria project has been kindly supported by EleutherAI and Stability AI, as well as by a compute grant from the Ministry of Science and ICT of Korea. Our models and MIDI tooling are released under the Apache-2.0 license. If you use the models or tooling in follow-up work, please cite the paper in which they were introduced:

```bibtex
@inproceedings{bradshawscaling,
  title={Scaling Self-Supervised Representation Learning for Symbolic Piano Performance},
  author={Bradshaw, Louis and Fan, Honglu and Spangher, Alex and Biderman, Stella and Colton, Simon},
  booktitle={arXiv preprint},
  year={2025},
  url={https://arxiv.org/abs/2504.15071}
}
```