This project is a clean, modular, and scalable implementation of audio embedding models using PyTorch Lightning and Hydra in a self-supervised learning (SSL) regime. It is based on the lightning-hydra-template, designed to be extensible and runnable in local or cluster environments, and currently supports SSL training for the Audio-JEPA, RQA-JEPA, BEST-RQ (ViT-based), and BEST-RQ-2 architectures, with more planned.
The goal of this project is to provide a robust codebase for training and experimenting with audio embedding models. Key features include:
- Modular Architecture: Components like Spectrogram, Masking, and ViT are decoupled.
- Configurable Positional Embeddings: Support for RoPE (2D Rotary Embeddings), SinCos (2D Sinusoidal), and Learnable embeddings.
- Hydra Configuration: Flexible experiment management via hierarchical config files.
- Lightning Trainer: Simplified training loop, logging, and checkpointing.
- Modern Tooling: Uses `uv` for fast and reliable dependency management.
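As a mental model for the Hydra-style dotted overrides used throughout this README (e.g. `data.batch_size=64`), they behave like nested-dictionary updates. The following is a simplified sketch, not Hydra's actual implementation — Hydra additionally handles config groups, interpolation, and object instantiation:

```python
def apply_override(cfg: dict, override: str) -> None:
    """Apply one 'a.b.c=value' style override in place (simplified sketch)."""
    keys, value = override.split("=", 1)
    *path, last = keys.split(".")
    node = cfg
    for key in path:
        node = node.setdefault(key, {})
    # Hydra does typed conversion from the config schema; we only handle ints.
    node[last] = int(value) if value.isdigit() else value

cfg = {"data": {"batch_size": 32}, "trainer": {"max_epochs": 10}}
for ov in ["data.batch_size=64", "trainer.max_epochs=50"]:
    apply_override(cfg, ov)
print(cfg)  # → {'data': {'batch_size': 64}, 'trainer': {'max_epochs': 50}}
```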
This project uses uv for dependency management.
- Install `uv` (if not already installed):

  ```bash
  curl -LsSf https://astral.sh/uv/install.sh | sh
  ```

- Clone the repository:

  ```bash
  git clone <repository_url>
  cd audio-embeddings
  ```
- Install dependencies:

  ```bash
  uv sync
  ```

  For development/testing tools (including `pytest`), sync all groups:

  ```bash
  uv sync --all-groups
  ```
- Enable shared git hooks (runs `uv sync` after merge/checkout/rewrite):

  ```bash
  git config core.hooksPath .githooks
  ```
To start training with the default configuration:

```bash
uv run src/train.py
```

Run on GPU with Weights & Biases logging:

```bash
uv run src/train.py trainer=gpu logger=wandb
```

Override hyperparameters on the command line:

```bash
uv run src/train.py data.batch_size=64 trainer.max_epochs=50
```

Train directly from local UPS tar shards on a cluster filesystem:
```bash
uv run src/train.py \
  data=ups_webdataset \
  trainer=cpu \
  +trainer.fast_dev_run=True \
  data.shard_globs='[${oc.env:UPS_DATA_ROOT,/path/to/ups}/audio/*.tar,${oc.env:UPS_DATA_ROOT,/path/to/ups}/audio2/*.tar]'
```

The loader expects UPS shard samples under the `mp3` key.
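In WebDataset-style tar shards, files sharing a basename form one sample, and here the audio bytes sit under the `mp3` extension. Below is a stdlib-only sketch of that grouping — the real loader uses the webdataset/torchcodec stack, and the filenames are made up:

```python
import io
import json
import tarfile
from collections import defaultdict

def iter_samples(tar_bytes: bytes):
    """Group tar members by basename; keys are file extensions (WebDataset-style)."""
    samples = defaultdict(dict)
    with tarfile.open(fileobj=io.BytesIO(tar_bytes)) as tf:
        for member in tf.getmembers():
            stem, _, ext = member.name.partition(".")
            samples[stem][ext] = tf.extractfile(member).read()
    return list(samples.values())

# Build a tiny in-memory shard with one sample: 0001.mp3 + 0001.json.
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tf:
    for name, payload in [
        ("0001.mp3", b"fake-mp3-bytes"),
        ("0001.json", json.dumps({"id": "0001"}).encode()),
    ]:
        info = tarfile.TarInfo(name)
        info.size = len(payload)
        tf.addfile(info, io.BytesIO(payload))

sample = iter_samples(buf.getvalue())[0]
print(sorted(sample))  # → ['json', 'mp3']
```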
Train from local PeoplesSpeech parquet splits (`train*.parquet`, `valid*.parquet`, `test*.parquet`) under a subset folder (default `clean`):

```bash
uv run src/train.py \
  data=peoples_speech \
  trainer=cpu \
  +trainer.fast_dev_run=True \
  data.data_root='${oc.env:DSDIR,/path/to/datasets}/HuggingFace/MLCommons/peoples_speech' \
  data.cache_dir='${oc.env:SCRATCH,/tmp}'
```

Run the cluster preset:

```bash
uv run src/train.py experiment=best_rq_2/peoples_speech
```

If your local layout differs, override `data.split_file_patterns.*` and/or the column names from the CLI/config.
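The `split_file_patterns` idea can be pictured with stdlib glob matching. The pattern defaults follow the split names described above; the file listing is illustrative:

```python
import fnmatch

# Hypothetical contents of the subset folder (default `clean`).
files = ["train-00000.parquet", "train-00001.parquet",
         "valid-00000.parquet", "test-00000.parquet"]

# Default patterns; override data.split_file_patterns.* to change them.
split_file_patterns = {"train": "train*.parquet",
                       "valid": "valid*.parquet",
                       "test": "test*.parquet"}

# Assign each file to the split whose pattern it matches.
splits = {name: [f for f in files if fnmatch.fnmatch(f, pattern)]
          for name, pattern in split_file_patterns.items()}
print(splits["train"])  # → ['train-00000.parquet', 'train-00001.parquet']
```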
You can switch between different positional embedding strategies easily.

RoPE:

```bash
uv run src/train.py model.net.encoder.pos_embed_type=rope
```

2D SinCos:

```bash
uv run src/train.py ++model.net.encoder.pos_embed_type=sincos ++model.net.predictor.pos_embed_type=sincos
```

To run training offline while still staging model checkpoints for upload (which standard WandB offline mode restricts):

```bash
uv run src/train.py \
  logger=wandb \
  logger.wandb.offline=True \
  logger.wandb.log_model=False \
  +callbacks.wandb_offline_checkpoint._target_=src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback \
  trainer=gpu trainer.devices=1 \
  data.batch_size=128 trainer.max_epochs=100
```

These checkpoints are uploaded when you later run `wandb sync`.

Learnable:
```bash
uv run src/train.py ++model.net.encoder.pos_embed_type=learnable ++model.net.predictor.pos_embed_type=learnable
```

```
├── configs/                        # Hydra configuration files
│   ├── callbacks/                  # Callback configs (checkpoints, early stopping)
│   ├── data/                       # Data configs (AudioSet, etc.)
│   ├── logger/                     # Logger configs (WandB, Tensorboard)
│   ├── model/                      # Model configs (AudioJEPA parameters)
│   ├── trainer/                    # Trainer configs (CPU, GPU, strategies)
│   └── train.yaml                  # Main configuration entry point
├── src/
│   ├── data/                       # Data loading logic
│   │   └── audioset_datamodule.py  # AudioSet DataModule & Dataset
│   ├── models/                     # Model architectures
│   │   ├── components/             # Reusable blocks
│   │   │   ├── masking.py          # Masking generators
│   │   │   ├── patch_embed.py      # Patchification
│   │   │   ├── rope.py             # 2D Rotary Embeddings
│   │   │   ├── spectrogram.py      # Audio preprocessing
│   │   │   └── vit.py              # Vision Transformer (Student/Teacher/Predictor)
│   │   └── audio_jepa_module.py    # Main LightningModule
│   ├── utils/                      # Utility functions
│   └── train.py                    # Training entry point
├── scripts/                        # Helper scripts
├── tests/                          # Verification tests
├── pyproject.toml                  # Project dependencies
└── README.md                       # This file
```
- Create your model components in `src/models/components/`.
- Create a new LightningModule in `src/models/` (or update `AudioJEPAModule`).
- Create a new config file in `configs/model/my_new_model.yaml`.
- Run with `uv run src/train.py model=my_new_model`.
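For reference, a hypothetical `configs/model/my_new_model.yaml` following the usual Hydra `_target_` convention. Every field name below is an assumption — mirror an existing config (e.g. the AudioJEPA one) rather than this sketch:

```yaml
# Hypothetical module path and fields; adapt to your actual class.
_target_: src.models.my_new_module.MyNewLitModule
net:
  _target_: src.models.components.my_net.MyNet
  embed_dim: 768
  depth: 12
optimizer:
  _target_: torch.optim.AdamW
  _partial_: true
  lr: 1.0e-4
```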
- Create a new DataModule in `src/data/`.
- Create a new config file in `configs/data/my_dataset.yaml`.
- Run with `uv run src/train.py data=my_dataset`.
- Callbacks: Add custom callbacks in `src/callbacks/` (if needed) or use existing Lightning callbacks, and configure them in `configs/callbacks/`.
- Metrics: Add metrics logging in `training_step` or `validation_step` inside `src/models/audio_jepa_module.py`.
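The metrics pattern boils down to calling `self.log(...)` inside the step. Here is a dependency-free sketch with a stub `log` — the real LightningModule's `self.log` additionally handles epoch aggregation and device sync, and the loss shown is a stand-in, not the actual JEPA loss:

```python
class LoggingSketch:
    """Stand-in for a LightningModule; only the logging pattern is real."""

    def __init__(self):
        self.logged = {}

    def log(self, name, value, prog_bar=False):
        # Lightning would aggregate/sync across steps and devices;
        # here we simply record the latest value.
        self.logged[name] = value

    def training_step(self, batch, batch_idx):
        loss = sum(batch) / len(batch)  # placeholder for the actual loss
        self.log("train/loss", loss, prog_bar=True)
        return loss

module = LoggingSketch()
module.training_step([1.0, 3.0], batch_idx=0)
print(module.logged)  # → {'train/loss': 2.0}
```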
The project uses a two-tier testing workflow:

- Fast `pytest` checks by default.
- Heavier `integration`/`data` checks on demand.
`pytest` is defined in the `dev` dependency group in `pyproject.toml`, so the examples below use `--group dev`.
Run the default fast `pytest` suite:

```bash
uv run --group dev pytest
```

Run a single test file:

```bash
uv run --group dev pytest tests/test_audio_utils.py -q
```

Run the slower integration/data checks:

```bash
uv run --group dev pytest -m "integration or data"
```

`torchcodec` is used by the integration/data tests and by dataset decoding code paths (for example the YT1B/UPS loaders). On macOS, these workflows may fail when FFmpeg shared libraries installed via Homebrew are not found at runtime.
Run with a fallback library path:

Fish:

```fish
env DYLD_FALLBACK_LIBRARY_PATH=/opt/homebrew/lib:/opt/homebrew/opt/ffmpeg/lib uv run --group dev pytest -m "integration or data"
```

Bash/Zsh:

```bash
DYLD_FALLBACK_LIBRARY_PATH=/opt/homebrew/lib:/opt/homebrew/opt/ffmpeg/lib uv run --group dev pytest -m "integration or data"
```

Keep script-based verifications for manual/component checks:
```bash
uv run tests/verify_rope.py
uv run tests/verify_custom_rope.py
uv run tests/verify_data.py
```

This project supports a private-first workflow:

- `origin` is the private canonical repo.
- `public` is the public mirror.
- Public updates are release-gated from `release/<version>` branches.
The publication pipeline uses deterministic sanitization rules in `.public-sanitize.yml` and release tooling in:

- `scripts/sanitize_for_public.py`
- `scripts/publish_public.sh`
- `docs/RELEASING_PUBLIC.md`
- Public Issues/PRs stay enabled.
- Accepted public PRs are ported into the private `master`.
- Ported changes are included in the next public release.
This repository is licensed under the MIT License. See LICENSE.
This repository also includes vendored third-party code under Apache-2.0. See `THIRD_PARTY_LICENSES.md` and `licenses/APACHE-2.0-LIGHTNING.txt` for details.
```bibtex
@inproceedings{tuncay2025audio,
  title={Audio-JEPA: Joint-Embedding Predictive Architecture for Audio Representation Learning},
  author={Tuncay, Ludovic and Labb{\'e}, Etienne and Benetos, Emmanouil and Pellegrini, Thomas},
  booktitle={ICME 2025},
  organization={IEEE},
  year={2025},
}
```

Citation coming soon.