Official Implementation of Diffusion Bridge Implicit Models.
DBIM offers a suite of fast samplers tailored for Denoising Diffusion Bridge Models (DDBMs). We have cleaned the codebase to support a broad range of diffusion bridges, enabling unified training and sampling workflows. We have also streamlined deployment by replacing the cumbersome MPI-based distributed launcher with the more efficient and engineer-friendly `torchrun`.
To install all packages in this codebase along with their dependencies, run:

```bash
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
pip install blobfile piq matplotlib opencv-python joblib lmdb scipy clean-fid easydict torchmetrics rich ipdb
```
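Optionally, you can sanity-check that PyTorch was installed with CUDA support (this check is ours, not part of the original instructions):

```bash
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```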
Please put the downloaded checkpoints under `assets/ckpts/`.
For image translation, we directly adopt the pretrained checkpoints from DDBM:
- Edges2Handbags: e2h_ema_0.9999_420000.pt
- DIODE: diode_ema_0.9999_440000.pt
We have removed the dependency on external packages such as `flash_attn`, whose functionality is already supported natively by PyTorch. After downloading the two checkpoints above, please run `python preprocess_ckpt.py` to complete the conversion.
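As a minimal sketch, assuming the checkpoints were downloaded to the current directory and that `preprocess_ckpt.py` finds them under `assets/ckpts/` without extra arguments:

```bash
# Place the downloaded DDBM checkpoints where the codebase expects them.
mkdir -p assets/ckpts
mv e2h_ema_0.9999_420000.pt diode_ema_0.9999_440000.pt assets/ckpts/
# Convert the DDBM checkpoints for this codebase (assumed to need no arguments).
python preprocess_ckpt.py
```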
For image restoration:
- Center 128x128 Inpainting on ImageNet 256x256: imagenet256_inpaint_ema_0.9999_400000.pt
Please put (or link) the datasets under `assets/datasets/`.
- For Edges2Handbags, please follow the instructions from here. The resulting folder structure should be `assets/datasets/edges2handbags/train` and `assets/datasets/edges2handbags/val`.
- For DIODE, please download the training dataset and the data list from here. The resulting folder structure should be `assets/datasets/DIODE/train` and `assets/datasets/DIODE/data_list`.
- For ImageNet, please download the dataset from here. The resulting folder structure should be `assets/datasets/ImageNet/train` and `assets/datasets/ImageNet/val`.
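Putting the bullets above together (paths taken directly from the instructions), the expected layout under `assets/datasets/` is:

```
assets/datasets/
├── edges2handbags/
│   ├── train
│   └── val
├── DIODE/
│   ├── train
│   └── data_list
└── ImageNet/
    ├── train
    └── val
```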
We also provide automatic download scripts:

```bash
cd assets/datasets
bash download_extract_edges2handbags.sh
bash download_extract_DIODE.sh
bash download_extract_ImageNet.sh
```
After downloading, the DIODE dataset requires preprocessing by running `python preprocess_depth.py`.
To draw samples, run

```bash
bash scripts/sample.sh $DATASET_NAME $NFE $SAMPLER ($AUX)
```

where:
- `$DATASET_NAME` can be chosen from `e2h`/`diode`/`imagenet_inpaint_center`.
- `$NFE` is the number of function evaluations, which is proportional to the sampling time.
- `$SAMPLER` can be chosen from `heun`/`dbim`/`dbim_high_order`.
  - `heun` is the vanilla sampler of DDBM, which alternately simulates the SDE and ODE steps. In this case, `$AUX` is not required.
  - `dbim` and `dbim_high_order` are our proposed samplers. When using `dbim`, `$AUX` corresponds to $\eta$, which controls the stochasticity level (a floating-point value in $[0,1]$). When using `dbim_high_order`, `$AUX` corresponds to the order (2 or 3). Example invocations are given after this list.
- The samples will be saved to `workdir/`.
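For example (the concrete argument values below are illustrative choices of ours, following the conventions above):

```bash
# DBIM with eta = 0 (fully deterministic) at NFE = 20 on Edges2Handbags
bash scripts/sample.sh e2h 20 dbim 0.0

# High-order DBIM of order 3 at NFE = 20 on DIODE
bash scripts/sample.sh diode 20 dbim_high_order 3
```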
Before evaluating the image translation results, please download the reference statistics from DDBM and put them under `assets/stats/`:
- Reference stats for Edges2Handbags: `edges2handbags_ref_64_data.npz`.
- Reference stats for DIODE: `diode_ref_256_data.npz`.
Evaluation proceeds automatically when you specify the same dataset and sampler arguments used for sampling:

```bash
bash scripts/evaluate.sh $DATASET_NAME $NFE $SAMPLER ($AUX)
```
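For example, matching the first sampling invocation shown earlier (illustrative values):

```bash
bash scripts/evaluate.sh e2h 20 dbim 0.0
```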
This codebase is built upon DDBM and I2SB.
If you find this method and/or code useful, please consider citing
```bibtex
@article{zheng2024diffusion,
  title={Diffusion Bridge Implicit Models},
  author={Zheng, Kaiwen and He, Guande and Chen, Jianfei and Bao, Fan and Zhu, Jun},
  journal={arXiv preprint arXiv:2405.15885},
  year={2024}
}
```