ZhenghanFang/ProxDM

ProxDM

Official implementation of Beyond Scores: Proximal Diffusion Models.

To appear in NeurIPS 2025.

by Zhenghan Fang, Mateo Díaz, Sam Buchanan, and Jeremias Sulam


We introduce a new framework for diffusion models based on backward discretization and proximal operators. Our Proximal Diffusion Models (ProxDM) achieve a faster theoretical convergence rate and better empirical sample quality when the number of sampling steps is limited.
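
To make the idea of backward discretization concrete, here is a minimal sketch, not taken from this repository, that compares a forward (explicit gradient) step with a backward (proximal) step on the toy potential f(x) = x^2 / (2 * sigma2), the negative log-density of a zero-mean Gaussian up to a constant. The function names and the step size `lam` are illustrative assumptions, and the closed-form proximal operator below holds only for this quadratic potential.

```python
def forward_step(y: float, lam: float, sigma2: float) -> float:
    """Explicit step: y - lam * f'(y), with f(x) = x**2 / (2 * sigma2)."""
    return y - lam * y / sigma2


def prox_step(y: float, lam: float, sigma2: float) -> float:
    """Implicit step: argmin_x f(x) + (x - y)**2 / (2 * lam).

    Setting the derivative to zero gives x / sigma2 + (x - y) / lam = 0,
    so x = y / (1 + lam / sigma2).
    """
    return y / (1.0 + lam / sigma2)


if __name__ == "__main__":
    y, sigma2 = 2.0, 1.0
    for lam in (0.5, 1.0, 3.0):
        print(lam, forward_step(y, lam, sigma2), prox_step(y, lam, sigma2))
```

With a large step (`lam = 3`, `sigma2 = 1`), the explicit step overshoots the minimizer at 0 and flips sign, while the proximal step still contracts toward it. This is the usual stability argument for implicit schemes at large step sizes.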

Figure: FID vs. NFE (number of function evaluations) on CIFAR-10. ProxDM achieves lower FID with fewer steps.


🚀 Updates

  • 2025/12/01: Added the CelebA-HQ (256x256) experiment, using a latent-space variant of ProxDM.

Environment

Install the dependencies via conda:

conda env create -f environment.yml
conda activate prox_diffusion

or uv:

uv sync
source .venv/bin/activate

Pretrained Models

All pretrained models are available on HuggingFace.

Download the models manually to assets/pretrained_models/, or run scripts/download_models_and_fid_stats.sh.

Data Preparation

1. Toy 2D dataset: Datasaurus Dozen

Download the Datasaurus Dozen CSV file and save it as data/datasaurus/datasaurus.csv.

2. Standard benchmarks: MNIST and CIFAR-10

You can download the datasets using torchvision by running:

python scripts/download_datasets.py

  • MNIST: stored in data/mnist/
  • CIFAR-10: stored in data/cifar10/

3. High-resolution dataset: CelebA-HQ (256x256)

Download the dataset from Kaggle; see scripts/download_celeba_hq_256.sh.

4. Reference statistics for FID

Reference FID statistics for MNIST, CIFAR-10, and CelebA-HQ are available on HuggingFace.

Download the stat files manually to assets/fid_stats/, or run scripts/download_models_and_fid_stats.sh.

To compute the reference statistics yourself, see scripts/prepare_fid_stats.py.
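
For context, FID is the Fréchet distance between two Gaussians fitted to Inception features of generated and reference images. Assuming the reference statistics files store the feature mean $\mu_r$ and covariance $\Sigma_r$ (the usual convention, not verified against scripts/prepare_fid_stats.py), and writing $\mu_g$, $\Sigma_g$ for the same statistics of the generated samples, the metric is:

```latex
\mathrm{FID} = \lVert \mu_g - \mu_r \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_g + \Sigma_r - 2 \left( \Sigma_g \Sigma_r \right)^{1/2} \right)
```

Lower values indicate that the generated feature distribution is closer to the reference one.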

Synthetic Example (Datasaurus "dino")

This experiment evaluates score-based and proximal samplers on a 2D mixture of Dirac deltas (the "dino" dataset in the Datasaurus Dozen), using the exact score and exact proximal operators, and computes Wasserstein distances to the target distribution.

  1. Generate samples and compute Wasserstein distances:
python scripts/low_dim_experiment/sample.py

Results will be saved in output/low_dim/dino/.

  2. Run scripts/low_dim_experiment/plot.ipynb to produce the figures in the paper.
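
The exact score used in this experiment is available in closed form because the target is a finite set of points: after scaling the points by `alpha` and adding Gaussian noise with standard deviation `sigma`, the noisy density is a Gaussian mixture. The sketch below is a plain-Python illustration of that formula under these assumptions. It is not the code in scripts/low_dim_experiment/, and the names `alpha`, `sigma`, and `mixture_score` are hypothetical.

```python
import math


def mixture_score(x, points, alpha, sigma):
    """Score at x of the density (1/N) * sum_i N(alpha * p_i, sigma^2 * I) in 2D."""
    centers = [(alpha * px, alpha * py) for px, py in points]
    # Log of the unnormalized component weights at x.
    logits = [
        -((x[0] - cx) ** 2 + (x[1] - cy) ** 2) / (2.0 * sigma**2)
        for cx, cy in centers
    ]
    # Softmax with max-subtraction for numerical stability.
    m = max(logits)
    weights = [math.exp(v - m) for v in logits]
    total = sum(weights)
    weights = [w / total for w in weights]
    # Weighted average of the per-component Gaussian scores (c - x) / sigma^2.
    gx = sum(w * (cx - x[0]) for w, (cx, _) in zip(weights, centers)) / sigma**2
    gy = sum(w * (cy - x[1]) for w, (_, cy) in zip(weights, centers)) / sigma**2
    return gx, gy


if __name__ == "__main__":
    # Two symmetric points: the score vanishes at the midpoint.
    print(mixture_score((0.0, 0.0), [(-1.0, 0.0), (1.0, 0.0)], alpha=1.0, sigma=0.5))
```

As `sigma` shrinks, the weights concentrate on the nearest scaled point, so the score points almost directly at that point.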

Experiments on MNIST, CIFAR-10, and CelebA-HQ

Training

We provide commands to reproduce all ProxDM training experiments in the paper:

# MNIST Score
accelerate launch --multi_gpu --num_processes 4 train.py \
--config configs/mnist/score.py --ckpt_dir output/train/mnist/score

# MNIST Prox Hybrid
accelerate launch --multi_gpu --num_processes 4 train.py \
--config configs/mnist/prox_hybrid.py --ckpt_dir output/train/mnist/prox_hybrid

# MNIST Prox Backward
accelerate launch --multi_gpu --num_processes 4 train.py \
--config configs/mnist/prox_backward.py --ckpt_dir output/train/mnist/prox_backward

# CIFAR10 Score
accelerate launch --multi_gpu --num_processes 4 train.py \
--config configs/cifar10/score.py --ckpt_dir output/train/cifar10/score

# CIFAR10 Prox Hybrid
accelerate launch --multi_gpu --num_processes 4 train.py \
--config configs/cifar10/prox_hybrid.py --ckpt_dir output/train/cifar10/prox_hybrid

# CIFAR10 Score (5, 10, 20 steps)
accelerate launch --multi_gpu --num_processes 4 train.py \
--config configs/cifar10/score_subset.py --ckpt_dir output/train/cifar10/score_subset

# CIFAR10 Prox Hybrid (5, 10, 20 steps)
accelerate launch --multi_gpu --num_processes 4 train.py \
--config configs/cifar10/prox_hybrid_subset.py --ckpt_dir output/train/cifar10/prox_hybrid_subset

# CIFAR10 Prox Hybrid (5, 10, 20 steps, No heur.)
accelerate launch --multi_gpu --num_processes 4 train.py \
--config configs/cifar10/prox_hybrid_subset_noheur.py --ckpt_dir output/train/cifar10/prox_hybrid_subset_noheur

# CelebA-HQ Score (latent space)
accelerate launch --multi_gpu --num_processes 8 \
train.py --config configs/celeba_hq_vae/score.py \
--ckpt_dir output/train/celeba_hq/score

# CelebA-HQ Prox Hybrid (latent space)
accelerate launch --multi_gpu --num_processes 8 \
train.py --config configs/celeba_hq_vae/prox_hybrid.py \
--ckpt_dir output/train/celeba_hq/prox_hybrid

Evaluation (Compute FID)

We provide commands to reproduce all FID results in the paper:

BATCH_SIZE=10000

# MNIST Score
python eval.py \
--config configs/mnist/score.py \
--ckpt_path assets/pretrained_models/mnist_score.pth --output_root output/eval \
--config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE

# MNIST Score ODE
python eval.py \
--config configs/mnist/score.py \
--ckpt_path assets/pretrained_models/mnist_score.pth --output_root output/eval \
--config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE \
--config.fid.sample_method ode_euler_eps

# MNIST Prox Hybrid
python eval.py \
--config configs/mnist/prox_hybrid.py \
--ckpt_path assets/pretrained_models/mnist_prox_hybrid.pth --output_root output/eval \
--config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE

# MNIST Prox Backward
python eval.py \
--config configs/mnist/prox_backward.py \
--ckpt_path assets/pretrained_models/mnist_prox_backward.pth --output_root output/eval \
--config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE \
--config.fid.steps "(20,50,100,1000)"

# CIFAR10 Score
python eval.py \
--config configs/cifar10/score.py \
--ckpt_path assets/pretrained_models/cifar10_score.pth --output_root output/eval \
--config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE

# CIFAR10 Score ODE
python eval.py \
--config configs/cifar10/score.py \
--ckpt_path assets/pretrained_models/cifar10_score.pth --output_root output/eval \
--config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE \
--config.fid.sample_method ode_euler_eps

# CIFAR10 Prox Hybrid
python eval.py \
--config configs/cifar10/prox_hybrid.py \
--ckpt_path assets/pretrained_models/cifar10_prox_hybrid.pth --output_root output/eval \
--config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE

# CIFAR10 Score (5, 10, 20 steps)
python eval.py \
--config configs/cifar10/score_subset.py \
--ckpt_path assets/pretrained_models/cifar10_score_subset.pth --output_root output/eval \
--config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE \
--config.fid.steps "(5,10,20)"

# CIFAR10 Score ODE (5, 10, 20 steps)
python eval.py \
--config configs/cifar10/score_subset.py \
--ckpt_path assets/pretrained_models/cifar10_score_subset.pth --output_root output/eval \
--config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE \
--config.fid.steps "(5,10,20)" \
--config.fid.sample_method ode_euler_eps

# CIFAR10 Prox Hybrid (5, 10, 20 steps)
python eval.py \
--config configs/cifar10/prox_hybrid_subset.py \
--ckpt_path assets/pretrained_models/cifar10_prox_hybrid_subset.pth --output_root output/eval \
--config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE \
--config.fid.steps "(5,10,20)"

# CIFAR10 Prox Hybrid (5, 10, 20 steps, No heur.)
python eval.py \
--config configs/cifar10/prox_hybrid_subset_noheur.py \
--ckpt_path assets/pretrained_models/cifar10_prox_hybrid_subset_noheur.pth --output_root output/eval \
--config.fid.n_samples 50000 --config.fid.batch_size $BATCH_SIZE \
--config.fid.steps "(5,10,20)"

# CelebA-HQ Score SDE
python eval.py \
--config configs/celeba_hq_vae/score.py \
--ckpt_path assets/pretrained_models/celeba_hq_score.pth --output_root output/eval \
--config.fid.n_samples 30000

# CelebA-HQ Score ODE
python eval.py \
--config configs/celeba_hq_vae/score.py \
--ckpt_path assets/pretrained_models/celeba_hq_score.pth --output_root output/eval \
--config.fid.n_samples 30000 \
--config.fid.sample_method ode_euler_eps

# CelebA-HQ Prox Hybrid
python eval.py \
--config configs/celeba_hq_vae/prox_hybrid.py \
--ckpt_path assets/pretrained_models/celeba_hq_prox_hybrid.pth --output_root output/eval \
--config.fid.n_samples 30000
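
The `--config.fid.*` flags in the commands above override nested fields of the config file passed via `--config`. As a rough illustration of that dotted-path behavior only, here is a plain-Python sketch on a nested dict. It is a toy under stated assumptions, not the repository's flag parser; `apply_override` and the default values shown are hypothetical.

```python
import ast


def apply_override(config: dict, dotted_key: str, raw_value: str) -> None:
    """Set config[a][b]... = value for a key such as "fid.n_samples"."""
    *parents, leaf = dotted_key.split(".")
    node = config
    for name in parents:
        node = node[name]  # raises KeyError if the path does not exist
    try:
        # Parse numbers and tuples such as "(5,10,20)".
        value = ast.literal_eval(raw_value)
    except (ValueError, SyntaxError):
        # Fall back to a plain string, e.g. "ode_euler_eps".
        value = raw_value
    node[leaf] = value


# Hypothetical defaults, for illustration only.
config = {"fid": {"n_samples": 1000, "steps": (1000,), "sample_method": "default"}}
apply_override(config, "fid.n_samples", "50000")
apply_override(config, "fid.steps", "(5,10,20)")
apply_override(config, "fid.sample_method", "ode_euler_eps")
print(config["fid"])
```

Quoting the tuple on the command line, as in `"(5,10,20)"`, keeps the shell from interpreting the parentheses.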

References

If you find this code useful for your research, please consider citing:

@article{fang2025beyond,
  title={Beyond Scores: Proximal Diffusion Models},
  author={Fang, Zhenghan and D{\'\i}az, Mateo and Buchanan, Sam and Sulam, Jeremias},
  journal={arXiv preprint arXiv:2507.08956},
  year={2025}
}

Acknowledgements

For questions or comments, please contact zfang23@jhu.edu.
