Official implementation of Beyond Scores: Proximal Diffusion Models, to appear at NeurIPS 2025.
By Zhenghan Fang, Mateo Díaz, Sam Buchanan, and Jeremias Sulam.
We introduce a new framework for diffusion models based on backward discretization and proximal operators. Our Proximal Diffusion Models (ProxDM) achieve a faster theoretical convergence rate and superior empirical sample quality under a limited number of sampling steps.
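For background, the proximal operator of a function f maps a point v to the minimizer of f(x) + (1/2λ)‖x − v‖². As a minimal illustration (not the paper's learned proximal network, just the classic closed form for f = ‖·‖₁, i.e. soft-thresholding):

```python
import numpy as np

def prox_l1(v, lam):
    """Proximal operator of f(x) = ||x||_1 with step size lam.

    prox_{lam*f}(v) = argmin_x ||x||_1 + (1/(2*lam)) * ||x - v||^2,
    which has the closed form sign(v) * max(|v| - lam, 0) (soft-thresholding).
    """
    return np.sign(v) * np.maximum(np.abs(v) - lam, 0.0)

# A proximal step shrinks v toward low values of f while staying near v.
v = np.array([3.0, -0.5, 1.5])
out = prox_l1(v, 1.0)
```

In ProxDM the analogous operator is defined by the (learned) smoothed data distribution rather than an analytic penalty; this snippet only illustrates the mathematical object.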
Figure: FID on CIFAR-10. ProxDM achieves lower FID with fewer steps.
- 2025/12/01: Added the CelebA-HQ (256x256) experiment, using a latent-space variant of ProxDM.
Install the dependencies via conda:

```shell
conda env create -f environment.yml
conda activate prox_diffusion
```

or via uv:

```shell
uv sync
source .venv/bin/activate
```
All pretrained models are available on HuggingFace.
Download the models manually to assets/pretrained_models/, or run scripts/download_models_and_fid_stats.sh.
Download the Datasaurus CSV file to data/datasaurus/datasaurus.csv.
You can download the datasets using torchvision by running:

```shell
python scripts/download_datasets.py
```

- MNIST: stored in data/mnist/
- CIFAR-10: stored in data/cifar10/
Download from Kaggle - see: scripts/download_celeba_hq_256.sh.
Reference FID statistics for MNIST, CIFAR-10, and CelebA-HQ are available on HuggingFace.
Download the stat files manually to assets/fid_stats/, or run scripts/download_models_and_fid_stats.sh.
To compute the reference statistics yourself, see scripts/prepare_fid_stats.py.
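For context, FID reference statistics are simply the mean and covariance of feature activations (typically Inception-v3 features) over a dataset. A minimal numpy sketch of the computation, with random stand-in features (the actual extractor and file layout used by scripts/prepare_fid_stats.py may differ):

```python
import numpy as np

def compute_fid_stats(features):
    """Reference FID statistics: mean and covariance of feature activations.

    features: array of shape (N, D), e.g. Inception-v3 pool3 activations
    computed over every image in the dataset.
    """
    mu = features.mean(axis=0)                 # (D,)
    sigma = np.cov(features, rowvar=False)     # (D, D)
    return mu, sigma

# Random stand-in features for illustration; real stats use network activations.
feats = np.random.default_rng(0).normal(size=(1000, 64))
mu, sigma = compute_fid_stats(feats)
```

FID then compares the (mu, sigma) of generated samples against these reference statistics via the Fréchet distance between the two Gaussians.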
This experiment evaluates score-based and proximal samplers on a 2D mixture of Dirac deltas (the dino dataset from Datasaurus), using the exact score and proximal operators, and computes Wasserstein distances to the target distribution.
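For a mixture of Dirac deltas, the Gaussian-smoothed density p_σ(x) = (1/n) Σᵢ N(x; xᵢ, σ²I) admits a closed-form score, which is what "exact score" refers to. A minimal numpy sketch (the repository's own implementation may differ in interface and vectorization):

```python
import numpy as np

def exact_score(x, points, sigma):
    """Score grad_x log p_sigma(x) for a Dirac mixture smoothed by N(0, sigma^2 I).

    x: query point of shape (d,); points: Dirac locations of shape (n, d).
    The score equals (posterior mean of x_i given x, minus x) / sigma^2.
    """
    diffs = points - x                                  # (n, d)
    logw = -np.sum(diffs ** 2, axis=1) / (2 * sigma ** 2)
    w = np.exp(logw - logw.max())                       # numerically stable weights
    w /= w.sum()
    return (w @ diffs) / sigma ** 2
```

The exact proximal operator of −log p_σ can be evaluated similarly by solving the corresponding one-sample optimization numerically; this sketch covers only the score.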
- Generate samples and compute Wasserstein distances:

```shell
python scripts/low_dim_experiment/sample.py
```

Results will be saved in output/low_dim/dino/.

- Run scripts/low_dim_experiment/plot.ipynb to produce the figures in the paper.
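The Wasserstein distance between two equal-size empirical 2D distributions can be computed exactly as an optimal assignment problem. A short sketch using scipy (the repository's metric computation may use a different solver or formulation):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def w2_empirical(a, b):
    """Exact 2-Wasserstein distance between equal-size point clouds of shape (n, d).

    For uniform empirical measures, optimal transport reduces to a minimum-cost
    perfect matching on the pairwise squared-distance matrix.
    """
    cost = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)  # (n, n)
    rows, cols = linear_sum_assignment(cost)
    return np.sqrt(cost[rows, cols].mean())
```

The Hungarian-style solver runs in O(n³), which is fine at the sample sizes of a 2D toy experiment.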
We provide commands to reproduce all ProxDM training experiments in the paper:

```shell
# MNIST Score
accelerate launch --multi_gpu --num_processes 4 train.py \
    --config configs/mnist/score.py --ckpt_dir output/train/mnist/score

# MNIST Prox Hybrid
accelerate launch --multi_gpu --num_processes 4 train.py \
    --config configs/mnist/prox_hybrid.py --ckpt_dir output/train/mnist/prox_hybrid

# MNIST Prox Backward
accelerate launch --multi_gpu --num_processes 4 train.py \
    --config configs/mnist/prox_backward.py --ckpt_dir output/train/mnist/prox_backward

# CIFAR10 Score
accelerate launch --multi_gpu --num_processes 4 train.py \
    --config configs/cifar10/score.py --ckpt_dir output/train/cifar10/score

# CIFAR10 Prox Hybrid
accelerate launch --multi_gpu --num_processes 4 train.py \
    --config configs/cifar10/prox_hybrid.py --ckpt_dir output/train/cifar10/prox_hybrid

# CIFAR10 Score (5, 10, 20 steps)
accelerate launch --multi_gpu --num_processes 4 train.py \
    --config configs/cifar10/score_subset.py --ckpt_dir output/train/cifar10/score_subset

# CIFAR10 Prox Hybrid (5, 10, 20 steps)
accelerate launch --multi_gpu --num_processes 4 train.py \
    --config configs/cifar10/prox_hybrid_subset.py --ckpt_dir output/train/cifar10/prox_hybrid_subset

# CIFAR10 Prox Hybrid (5, 10, 20 steps, no heuristic)
accelerate launch --multi_gpu --num_processes 4 train.py \
    --config configs/cifar10/prox_hybrid_subset_noheur.py --ckpt_dir output/train/cifar10/prox_hybrid_subset_noheur

# CelebA-HQ Score (latent space)
accelerate launch --multi_gpu --num_processes 8 train.py \
    --config configs/celeba_hq_vae/score.py --ckpt_dir output/train/celeba_hq/score

# CelebA-HQ Prox Hybrid (latent space)
accelerate launch --multi_gpu --num_processes 8 train.py \
    --config configs/celeba_hq_vae/prox_hybrid.py --ckpt_dir output/train/celeba_hq/prox_hybrid
```

We provide commands to reproduce all FID results in the paper:
```shell
BATCH_SIZE=10000

# MNIST Score
python eval.py \
    --config configs/mnist/score.py \
    --ckpt_path assets/pretrained_models/mnist_score.pth --output_root output/eval \
    --config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE

# MNIST Score ODE
python eval.py \
    --config configs/mnist/score.py \
    --ckpt_path assets/pretrained_models/mnist_score.pth --output_root output/eval \
    --config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE \
    --config.fid.sample_method ode_euler_eps

# MNIST Prox Hybrid
python eval.py \
    --config configs/mnist/prox_hybrid.py \
    --ckpt_path assets/pretrained_models/mnist_prox_hybrid.pth --output_root output/eval \
    --config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE

# MNIST Prox Backward
python eval.py \
    --config configs/mnist/prox_backward.py \
    --ckpt_path assets/pretrained_models/mnist_prox_backward.pth --output_root output/eval \
    --config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE \
    --config.fid.steps "(20,50,100,1000)"

# CIFAR10 Score
python eval.py \
    --config configs/cifar10/score.py \
    --ckpt_path assets/pretrained_models/cifar10_score.pth --output_root output/eval \
    --config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE

# CIFAR10 Score ODE
python eval.py \
    --config configs/cifar10/score.py \
    --ckpt_path assets/pretrained_models/cifar10_score.pth --output_root output/eval \
    --config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE \
    --config.fid.sample_method ode_euler_eps

# CIFAR10 Prox Hybrid
python eval.py \
    --config configs/cifar10/prox_hybrid.py \
    --ckpt_path assets/pretrained_models/cifar10_prox_hybrid.pth --output_root output/eval \
    --config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE

# CIFAR10 Score (5, 10, 20 steps)
python eval.py \
    --config configs/cifar10/score_subset.py \
    --ckpt_path assets/pretrained_models/cifar10_score_subset.pth --output_root output/eval \
    --config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE \
    --config.fid.steps "(5,10,20)"

# CIFAR10 Score ODE (5, 10, 20 steps)
python eval.py \
    --config configs/cifar10/score_subset.py \
    --ckpt_path assets/pretrained_models/cifar10_score_subset.pth --output_root output/eval \
    --config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE \
    --config.fid.steps "(5,10,20)" \
    --config.fid.sample_method ode_euler_eps

# CIFAR10 Prox Hybrid (5, 10, 20 steps)
python eval.py \
    --config configs/cifar10/prox_hybrid_subset.py \
    --ckpt_path assets/pretrained_models/cifar10_prox_hybrid_subset.pth --output_root output/eval \
    --config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE \
    --config.fid.steps "(5,10,20)"

# CIFAR10 Prox Hybrid (5, 10, 20 steps, no heuristic)
python eval.py \
    --config configs/cifar10/prox_hybrid_subset_noheur.py \
    --ckpt_path assets/pretrained_models/cifar10_prox_hybrid_subset_noheur.pth --output_root output/eval \
    --config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE \
    --config.fid.steps "(5,10,20)"

# CelebA-HQ Score SDE
python eval.py \
    --config configs/celeba_hq_vae/score.py \
    --ckpt_path assets/pretrained_models/celeba_hq_score.pth --output_root output/eval \
    --config.fid.n_samples 30000

# CelebA-HQ Score ODE
python eval.py \
    --config configs/celeba_hq_vae/score.py \
    --ckpt_path assets/pretrained_models/celeba_hq_score.pth --output_root output/eval \
    --config.fid.n_samples 30000 \
    --config.fid.sample_method ode_euler_eps

# CelebA-HQ Prox Hybrid
python eval.py \
    --config configs/celeba_hq_vae/prox_hybrid.py \
    --ckpt_path assets/pretrained_models/celeba_hq_prox_hybrid.pth --output_root output/eval \
    --config.fid.n_samples 30000
```

If you find the code useful for your research, please consider citing:
```bibtex
@article{fang2025beyond,
  title={Beyond Scores: Proximal Diffusion Models},
  author={Fang, Zhenghan and D{\'\i}az, Mateo and Buchanan, Sam and Sulam, Jeremias},
  journal={arXiv preprint arXiv:2507.08956},
  year={2025}
}
```

This repository builds on:

- score_sde_pytorch (reference for the baseline score-based samplers)
- pytorch-ddpm (implementation of the U-Net architecture)
- pytorch-fid (FID implementation in PyTorch)
For questions or comments, please contact zfang23@jhu.edu.