Official Repository of the paper: Brain-JEPA: Brain Dynamics Foundation Model with Gradient Positioning and Spatiotemporal Masking (NeurIPS 2024, Spotlight)
Schematic overview of Brain-JEPA.
Brain-JEPA has three main contributions:
- Brain gradient positioning for ROI locations, and sine and cosine functions for temporal positioning.
- Spatiotemporal masking: Brain-JEPA employs a single observation block to predict the representations of target blocks sampled from three distinct regions: Cross-ROI, Cross-Time, and Double-Cross.
- The use of the Joint-Embedding Predictive Architecture (JEPA).
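The three target regions can be illustrated with a minimal NumPy sketch on an (ROI × time-patch) grid. It assumes the observation block is a random subset of ROIs crossed with a contiguous window of time patches. The fractions, the helper names `spatiotemporal_masks` and `sample_target`, and the rule for drawing a target inside each region are hypothetical placeholders. This is not the mask collator shipped in `src/masks`.

```python
import numpy as np

def spatiotemporal_masks(n_rois, n_patches, obs_roi_frac=0.5, obs_time_frac=0.5, rng=None):
    """Return boolean (n_rois, n_patches) masks: observation, Cross-ROI, Cross-Time, Double-Cross.

    Hypothetical sketch: fractions and sampling rules are illustrative only.
    """
    rng = np.random.default_rng(rng)
    n_obs_r = max(1, int(n_rois * obs_roi_frac))
    n_obs_t = max(1, int(n_patches * obs_time_frac))
    # Observation block: random subset of ROIs x contiguous window of time patches.
    obs_rois = np.zeros(n_rois, dtype=bool)
    obs_rois[rng.choice(n_rois, n_obs_r, replace=False)] = True
    t0 = rng.integers(0, n_patches - n_obs_t + 1)
    obs_time = np.zeros(n_patches, dtype=bool)
    obs_time[t0:t0 + n_obs_t] = True
    observation = np.outer(obs_rois, obs_time)
    cross_roi = np.outer(~obs_rois, obs_time)      # other ROIs, same time window
    cross_time = np.outer(obs_rois, ~obs_time)     # same ROIs, other time patches
    double_cross = np.outer(~obs_rois, ~obs_time)  # other ROIs and other time patches
    return observation, cross_roi, cross_time, double_cross

def sample_target(region, roi_frac=0.5, time_frac=0.5, rng=None):
    """Draw one target block (a subset of the region's ROIs x time patches) inside a region."""
    rng = np.random.default_rng(rng)
    rows = np.flatnonzero(region.any(axis=1))
    cols = np.flatnonzero(region.any(axis=0))
    target = np.zeros_like(region)
    if rows.size == 0 or cols.size == 0:
        return target
    r = rng.choice(rows, max(1, int(rows.size * roi_frac)), replace=False)
    c = rng.choice(cols, max(1, int(cols.size * time_frac)), replace=False)
    target[np.ix_(r, c)] = True
    # Keep the target strictly inside its region.
    return target & region

obs, cross_roi, cross_time, double_cross = spatiotemporal_masks(450, 10, rng=0)
targets = [sample_target(m, rng=i) for i, m in enumerate((cross_roi, cross_time, double_cross))]
```

In this sketch the four masks partition the grid, so a target drawn from Cross-ROI, Cross-Time, or Double-Cross never overlaps the observation block, and the predictor has to infer the target representations rather than copy them.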
Brain gradient positioning. Cortical regions are placed in the space spanned by the top 3 gradient axes and colored according to their positions; these colors are then projected back onto the brain surface.
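Both kinds of positioning can be sketched as follows. The sketch assumes each ROI has a k-dimensional gradient coordinate (for example, the top 3 axes shown above) and each time patch has an integer index. The random matrix `w` is a hypothetical stand-in for a learned linear projection. Concatenating the spatial and temporal halves is our assumption about how the two are combined. Only the sine/cosine formula is the standard Transformer encoding.

```python
import numpy as np

def sincos_1d(positions, dim):
    """Standard sine/cosine encoding: (len(positions),) -> (len(positions), dim)."""
    assert dim % 2 == 0, "dim must be even"
    omega = 1.0 / 10000 ** (np.arange(dim // 2) / (dim // 2))
    angles = np.outer(positions, omega)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)

def brain_pos_embed(gradients, n_patches, dim, rng=None):
    """gradients: (n_rois, k) gradient coordinates -> (n_rois, n_patches, dim) embedding.

    Hypothetical sketch: `dim // 2` must itself be even for the temporal half.
    """
    rng = np.random.default_rng(rng)
    n_rois, k = gradients.shape
    half = dim // 2
    # Stand-in for a learned linear projection of the gradient coordinates (assumption).
    w = rng.normal(scale=k ** -0.5, size=(k, half))
    spatial = gradients @ w                          # (n_rois, half)
    temporal = sincos_1d(np.arange(n_patches), half)  # (n_patches, half)
    # Assumption: concatenate spatial and temporal halves for every (ROI, time) cell.
    return np.concatenate([
        np.broadcast_to(spatial[:, None, :], (n_rois, n_patches, half)),
        np.broadcast_to(temporal[None, :, :], (n_rois, n_patches, half)),
    ], axis=-1)
```

The point of gradient positioning is visible in the spatial half: two ROIs that sit close together on the gradient axes get similar embeddings, whereas an arbitrary ROI index would carry no such functional information.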
The fMRI data was parcellated into 450 ROIs, comprising 50 subcortical ROIs from the Tian Scale III atlas, followed by 400 cortical ROIs from the Schaefer atlas.
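A minimal sketch of that ROI ordering follows. It assumes each subject's data arrive as two NumPy arrays of shape (ROIs, time points), one per atlas. The helper name `stack_rois` and this array layout are hypothetical; the repository's own loading code lives in `src/datasets`.

```python
import numpy as np

N_SUBCORTICAL = 50  # Tian Scale III subcortical ROIs
N_CORTICAL = 400    # Schaefer cortical ROIs

def stack_rois(subcortical_ts, cortical_ts):
    """Stack ROI time series with subcortical ROIs first, then cortical ROIs.

    subcortical_ts: (50, T), cortical_ts: (400, T) -> (450, T). Hypothetical helper.
    """
    if subcortical_ts.shape[0] != N_SUBCORTICAL or cortical_ts.shape[0] != N_CORTICAL:
        raise ValueError("expected 50 subcortical and 400 cortical ROIs")
    if subcortical_ts.shape[1] != cortical_ts.shape[1]:
        raise ValueError("both atlases must have the same number of time points")
    return np.concatenate([subcortical_ts, cortical_ts], axis=0)
```

Keeping a fixed order matters because the gradient coordinates used for positioning are indexed by ROI, so row i of the time series and row i of the gradient file must refer to the same region.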
Three publicly available datasets were used in this paper: UKB, HCP-Aging, and ADNI.
The in-house dataset for NC/MCI classification is from the Memory, Ageing and Cognition Centre (MACC).
The UKB data were downloaded directly from Mansour et al.
We followed the preprocessing pipelines of Wu et al. and Kong et al. for HCP-Aging and ADNI, respectively.
The population-level brain gradients were derived from UKB data using the BrainSpace toolbox.
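For intuition about what a population-level gradient is, here is a simplified NumPy approximation of the recipe BrainSpace offers: functional connectivity, a row-wise sparsified normalized-angle affinity, then a diffusion-map embedding. It is a stand-in written for illustration, not the BrainSpace implementation and not the authors' script. The helper names, the 90% sparsity, and `alpha = 0.5` are assumptions.

```python
import numpy as np

def functional_connectivity(ts):
    """ts: (n_rois, n_timepoints) -> Pearson correlation matrix (n_rois, n_rois)."""
    return np.corrcoef(ts)

def diffusion_map_gradients(conn, n_components=3, sparsity=0.9, alpha=0.5):
    """Approximate connectivity gradients via a diffusion-map embedding (hypothetical helper)."""
    conn = conn.copy()
    # Keep only the strongest (1 - sparsity) fraction of connections in each row.
    thresh = np.percentile(conn, sparsity * 100, axis=1, keepdims=True)
    conn[conn < thresh] = 0.0
    # Normalized-angle affinity between rows: 1 - arccos(cosine similarity) / pi.
    norms = np.linalg.norm(conn, axis=1, keepdims=True)
    cos = np.clip((conn @ conn.T) / (norms * norms.T), -1.0, 1.0)
    aff = 1.0 - np.arccos(cos) / np.pi
    # Diffusion-map density normalization controlled by alpha.
    d = aff.sum(axis=1)
    aff = aff / np.outer(d ** alpha, d ** alpha)
    # Symmetric conjugate of the Markov matrix D^-1 A, so eigh can be used.
    d = aff.sum(axis=1)
    sym = aff / np.sqrt(np.outer(d, d))
    vals, vecs = np.linalg.eigh(sym)
    order = np.argsort(vals)[::-1]
    vals, vecs = vals[order], vecs[:, order]
    # Map back to right eigenvectors of the Markov matrix; the first one is trivial (constant).
    psi = vecs / vecs[:, [0]]
    return psi[:, 1:n_components + 1] * vals[1:n_components + 1]

rng = np.random.default_rng(0)
ts = rng.normal(size=(30, 200))  # toy data: 30 ROIs x 200 time points
gradients = diffusion_map_gradients(functional_connectivity(ts))
```

Averaging connectivity over many UKB subjects before embedding is what would make the result population-level; the toy data above only check shapes.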
```
.
├── configs                 # directory in which all experiment '.yaml' configs are stored
├── downstream_tasks        # the downstream package
│   ├── utils               # shared downstream utilities
│   ├── engine_finetune.py  # train_one_epoch and evaluation
│   ├── main_finetune.py    # fine-tuning training loop
│   ├── main_linearprobe.py # linear-probing training loop
│   └── models_vit.py       # model for downstream tasks
├── data                    # put the dataset and gradient file here
├── logs                    # put the downloaded pre-trained checkpoints here
├── output_dirs             # put the downloaded example downstream checkpoints here
├── src                     # the package
│   ├── helper.py           # helper functions for model initialization
│   ├── train.py            # pretraining
│   ├── datasets            # datasets, data loaders, ...
│   ├── models              # model definitions
│   ├── masks               # mask collators, masking utilities, ...
│   └── utils               # shared utilities
└── downstream_eval.py      # entrypoint for launching downstream tasks
```
Checkpoints of the pre-trained model and an example fine-tuned model can be downloaded from here.
```bash
conda create -n brain-jepa python=3.8
conda activate brain-jepa
pip install -r requirement.txt
```
This implementation starts from main.py, which parses the experiment config file and runs pretraining locally on a multi-GPU (or single-GPU) machine. For example, to run Brain-JEPA pretraining on GPUs 0, 1, 2, and 3 of a local machine using the config configs/ukb_vitb_ep300.yaml, run:
```bash
python main.py \
  --fname configs/ukb_vitb_ep300.yaml \
  --devices cuda:0 cuda:1 cuda:2 cuda:3
```
Note: Our pretraining was performed on four A100 (40 GB) GPUs.
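The sketch below shows the one-process-per-device launch pattern used by I-JEPA, on which this codebase builds. That main.py follows the same pattern is an assumption. `process_main` and `visible_device` are illustrative names, and the print call stands in for what a real worker would do: load the YAML config and start pretraining.

```python
import argparse
import multiprocessing as mp
import os

def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="pretraining launcher (sketch)")
    parser.add_argument("--fname", type=str, default="configs/ukb_vitb_ep300.yaml",
                        help="experiment config file")
    parser.add_argument("--devices", type=str, nargs="+", default=["cuda:0"],
                        help="devices to use; one worker process per device")
    return parser.parse_args(argv)

def visible_device(device):
    """Map e.g. 'cuda:2' to '2' for CUDA_VISIBLE_DEVICES."""
    return device.split(":")[-1]

def process_main(rank, fname, world_size, devices):
    # Pin this worker to a single GPU before any CUDA initialization.
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_device(devices[rank])
    # Hypothetical: a real worker would load `fname` and run pretraining with rank/world_size.
    print(f"rank {rank}/{world_size} on {devices[rank]} with {fname}")

def launch(argv=None):
    args = parse_args(argv)
    world_size = len(args.devices)
    procs = [mp.Process(target=process_main, args=(rank, args.fname, world_size, args.devices))
             for rank in range(world_size)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

if __name__ == "__main__":
    launch()
```

One process per GPU keeps each worker's device selection isolated, which is why a single `--devices` list is enough to switch between single-GPU and multi-GPU runs.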
Example of downstream fine-tuning (sex classification on HCP-Aging):
```bash
sh scripts/classification/run_downstream_FT_hca_sex.sh
```
Config files: All experiment parameters are specified in config files (as opposed to command-line arguments). See the configs/ directory for example config files.
- Release code for representation-conditional fMRI signal reconstruction
Our codebase builds heavily on I-JEPA and MAE.
Thanks to their authors for open-sourcing their code!
If you find this repository useful in your research, please consider giving it a star ⭐ and citing our paper:
```bibtex
@article{BrainJEPA,
  title={Brain-JEPA: Brain Dynamics Foundation Model with Gradient Positioning and Spatiotemporal Masking},
  author={Zijian Dong and Ruilin Li and Yilei Wu and Thuan Tinh Nguyen and Joanna Su Xian Chong and Fang Ji and Nathanael Ren Jie Tong and Christopher Li Hsian Chen and Juan Helen Zhou},
  journal={NeurIPS 2024},
  year={2024}
}
```