By Yingheng Wang, Yair Schiff, Aaron Gokaslan, Weishen Pan, Fei Wang, Chris De Sa, Volodymyr Kuleshov
We introduce InfoDiffusion, a principled probabilistic extension of diffusion models that supports low-dimensional latents with associated variational learning objectives that are regularized with a mutual information term. We show that these algorithms simultaneously yield high-quality samples and latent representations, achieving competitive performance with state-of-the-art methods on both fronts.
In this repo, we release:
- The Auxiliary-Variable Diffusion Models (AVDM):
  - Diffusion decoder conditioned on an auxiliary variable using AdaNorm
  - Simplified loss calculation for auxiliary latent variables with a semantic prior
- Baseline implementations [Examples]:
  - A set of model variants from the VAE family (VAE, $\beta$-VAE, InfoVAE) with different priors (Gaussian, Mixture of Gaussians, spiral).
  - A simplified version of Diffusion Autoencoder (DiffAE) within our AVDM framework.
  - A minimal and efficient implementation of vanilla diffusion models.
- Evaluation metrics:
- Samplers:
  - DDPM sampling and DDIM sampling.
  - Two-phase sampling, where the two phases sample from a regular diffusion model and an AVDM consecutively.
  - Latent sampling, which uses an auxiliary latent diffusion model to sample $\mathbf{z}_t$ along with $\mathbf{x}_t$.
  - Reverse DDIM sampling to visualize the latent $\mathbf{x}_T$ from $\mathbf{x}_0$.
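The relationship between DDIM sampling and reverse DDIM sampling can be illustrated with a minimal NumPy sketch. This is not the code in `sampling.py`: the function names, the toy noise predictor, and the linear `alpha_bars` schedule are all assumptions made for illustration. It only shows that the deterministic DDIM update (with no added noise) can be run in either direction along the noise schedule.

```python
import numpy as np

def ddim_step(x, eps, alpha_bar_from, alpha_bar_to):
    """Deterministic DDIM update from one noise level to another."""
    # Clean sample implied by the current sample and the noise estimate.
    x0_pred = (x - np.sqrt(1.0 - alpha_bar_from) * eps) / np.sqrt(alpha_bar_from)
    # Move the prediction to the target noise level, reusing the same eps.
    return np.sqrt(alpha_bar_to) * x0_pred + np.sqrt(1.0 - alpha_bar_to) * eps

def ddim_sample(x_T, eps_model, alpha_bars):
    """Walk from the noisiest level down to the clean sample x_0."""
    x = x_T
    for t in range(len(alpha_bars) - 1, 0, -1):
        x = ddim_step(x, eps_model(x, t), alpha_bars[t], alpha_bars[t - 1])
    return x

def ddim_encode(x_0, eps_model, alpha_bars):
    """Reverse DDIM: walk from the clean sample x_0 up to the latent x_T."""
    x = x_0
    for t in range(len(alpha_bars) - 1):
        x = ddim_step(x, eps_model(x, t), alpha_bars[t], alpha_bars[t + 1])
    return x

# Hypothetical schedule: alpha_bars[0] = 1 (clean), alpha_bars[-1] small (noisy).
alpha_bars = np.linspace(1.0, 0.02, 11)
# Toy stand-in for a trained noise-prediction network (assumption).
toy_eps = lambda x, t: np.full_like(x, 0.3)

x_0 = np.array([0.5, -1.0, 2.0])
x_T = ddim_encode(x_0, toy_eps, alpha_bars)
x_rec = ddim_sample(x_T, toy_eps, alpha_bars)
```

With this constant toy predictor the round trip is exact. With a trained network, the encoding is only approximately invertible because the noise estimate changes between steps.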
- `run.py`: Routines for training and evaluation
- `models.py`: Diffusion models (InfoDiffusion, DiffAE, regular diffusion), VAEs (InfoVAE, $\beta$-VAE, VAE)
- `modules.py`: Neural network blocks
- `sampling.py`: DDPM/DDIM sampler, Reverse DDIM sampler, Two-phase sampler, Latent sampler
- `utils.py`: LR scheduler, logging, utils to calculate priors
- `gen_fid_stats.py`: Generate stats used for FID calculation
- `calc_fid.py`: Calculate FID scores
To get started, create a conda environment containing the required dependencies.
conda create -n infodiffusion
conda activate infodiffusion
pip install -r requirements.txt
Run the training using the bash script:
bash run.sh
or
python run.py --model diff --mode train --mmd_weight 0.1 --a_dim 32 --epochs 50 --dataset celeba --batch_size 32 --save_epochs 5 --deterministic --prior regular --r_seed 64
The arguments in this command train a diffusion model (`--model diff`) using Maximum Mean Discrepancy (MMD) regularization (`--mmd_weight 0.1`) with a regular Gaussian prior (`--prior regular`) on CelebA (`--dataset celeba`).
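For readers unfamiliar with MMD, the sketch below shows one common way to estimate it between encoded auxiliary variables and prior samples, using an RBF kernel. It is an illustration under stated assumptions, not the loss implemented in `models.py`: the kernel choice, the `bandwidth` parameter, the function names, and the toy 2-dimensional data are all hypothetical.

```python
import numpy as np

def rbf_kernel(a, b, bandwidth):
    """Kernel matrix with entries exp(-||a_i - b_j||^2 / (2 * bandwidth^2))."""
    sq_dists = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * bandwidth ** 2))

def mmd_squared(z_q, z_p, bandwidth=1.0):
    """Biased (V-statistic) estimate of the squared MMD between two sample sets."""
    k_qq = rbf_kernel(z_q, z_q, bandwidth).mean()
    k_pp = rbf_kernel(z_p, z_p, bandwidth).mean()
    k_qp = rbf_kernel(z_q, z_p, bandwidth).mean()
    return k_qq + k_pp - 2.0 * k_qp

rng = np.random.default_rng(0)
z_prior = rng.normal(size=(256, 2))            # samples from a standard Gaussian prior
z_matched = rng.normal(size=(256, 2))          # encodings that already match the prior
z_shifted = rng.normal(loc=2.0, size=(256, 2)) # encodings that drifted away from it

mmd_weight = 0.1  # plays the role of --mmd_weight in this toy example
penalty_matched = mmd_weight * mmd_squared(z_matched, z_prior)
penalty_shifted = mmd_weight * mmd_squared(z_shifted, z_prior)
```

The penalty is small when the encoded samples follow the prior and grows as they drift away from it. A fixed bandwidth works for this low-dimensional toy data; higher-dimensional latents usually need the bandwidth scaled with the dimension.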
Below, we describe the steps required for evaluating the trained diffusion models.
Throughout, the main entry point for running experiments is the `run.py` script.
We also provide sample `bash` scripts for launching these evaluation runs.
In general, different evaluation runs can be switched using `--mode`, which takes one of the following values:
- `eval`: sampling images from the trained diffusion model.
- `eval_fid`: sampling images for FID score calculation.
- `save_latent`: save the auxiliary variables.
- `disentangle`: run evaluation on auxiliary variable disentanglement.
- `interpolate`: run interpolation between two given input images.
- `latent_quality`: save the auxiliary variables and latent variables for classification.
- `train_latent_ddim`: train the latent diffusion models used in latent sampler.
- `plot_latent`: plot the latent space.

However, the quantitative disentanglement evaluation, the latent classification, and the FID score calculation need multiple steps.
To evaluate latent disentanglement, we need to conduct the following steps:
- `save_latent.sh`: save the auxiliary variables $\mathbf{z}$ and latent variables $\mathbf{x}_T$.
- `eval_disentangle.sh`: evaluate the latent disentanglement by computing DCI and TAD scores.
python run.py --model diff --mode save_latent --a_dim 256 --mmd_weight 0.1 --epochs 50 --dataset celeba --sampling_number 16 --deterministic --prior regular --r_seed 64
python eval_disentanglement.py --model diff --a_dim 256 --mmd_weight 0.1 --epochs 50 --dataset celeba --sampling_number 16 --deterministic --prior regular --r_seed 64
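If you prefer to chain the two steps from Python rather than from the shell scripts, a small driver along these lines may help. It is a sketch, not part of the repo: the helper names and the `dry_run` switch are hypothetical, and the flags are copied verbatim from the two commands above.

```python
import shlex
import subprocess
import sys

# Flags shared by both commands above (copied verbatim).
SHARED_FLAGS = (
    "--model diff --a_dim 256 --mmd_weight 0.1 --epochs 50 --dataset celeba "
    "--sampling_number 16 --deterministic --prior regular --r_seed 64"
)

def build_commands(python=sys.executable):
    """Return the two commands, in the order they must run."""
    flags = shlex.split(SHARED_FLAGS)
    save_latent = [python, "run.py", "--mode", "save_latent", *flags]
    evaluate = [python, "eval_disentanglement.py", *flags]
    return [save_latent, evaluate]

def run_pipeline(dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in build_commands():
        print(shlex.join(cmd))
        if not dry_run:
            # check=True stops the pipeline if a step exits with an error.
            subprocess.run(cmd, check=True)

commands = build_commands(python="python")
```

Calling `run_pipeline(dry_run=False)` from the repository root would run the saving step first and the evaluation step only if it succeeds.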
To run latent classification, we need to conduct the following steps:
- `save_latent.sh`: save the auxiliary variables $\mathbf{z}$ and latent variables $\mathbf{x}_T$ used to train the classifier.
- `eval_disentangle.sh`: use the same evaluation script for disentanglement to train the classifier and obtain the classification accuracy.
python run.py --model diff --mode save_latent --a_dim 256 --mmd_weight 0.1 --epochs 50 --dataset celeba --sampling_number 16 --deterministic --prior regular --r_seed 64
python eval_disentanglement.py --model diff --a_dim 256 --mmd_weight 0.1 --epochs 50 --dataset celeba --sampling_number 16 --deterministic --prior regular --r_seed 64
To calculate the FID scores, we need to conduct the following steps:
- `eval_fid.sh`: train diffusion models and latent diffusion models and generate samples from them.
- `gen_fid.sh`: generate FID stats given the dataset name and the folder storing the preprocessed images from this dataset.
- `calc_fid.sh`: calculate FID scores given the dataset name and the folder storing the generated samples.
We also provide the commands in the above steps:
python run.py --model diff --mode train --mmd_weight 0.1 --a_dim 256 --epochs 50 --dataset celeba --batch_size 32 --save_epochs 5 --deterministic --prior regular --r_seed 64
python run.py --model diff --mode save_latent --disent_metric tad --mmd_weight 0.1 --a_dim 256 --epochs 50 --dataset celeba --deterministic --prior regular --r_seed 64
python run.py --model diff --mode train_latent_ddim --a_dim 256 --epochs 50 --mmd_weight 0.1 --dataset celeba --deterministic --save_epoch 10 --prior regular --r_seed 64
python run.py --model diff --mode eval_fid --split_step 500 --a_dim 256 --batch_size 256 --mmd_weight 0.1 --sampling_number 10000 --epochs 50 --dataset celeba --is_latent --prior regular --r_seed 64
python gen_fid_stats.py celeba ./celeba_imgs
python calc_fid.py celeba ./imgs/celeba_32d_0.1mmd/eval-fid-latent
Note: please refer to clean-fid for more options to calculate FID.
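To make the two FID steps concrete: the stats step reduces a set of image features to a mean and covariance, and the score step computes the Fréchet distance between two such Gaussians. The NumPy sketch below illustrates only that final formula. It is not the code in `gen_fid_stats.py` or `calc_fid.py`, the function names are hypothetical, and the Inception feature extraction (handled by clean-fid) is replaced here by random placeholder features.

```python
import numpy as np

def gaussian_stats(features):
    """Mean and covariance of a (num_samples, feature_dim) feature array."""
    return features.mean(axis=0), np.cov(features, rowvar=False)

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))."""
    diff = mu1 - mu2
    # Tr(sqrt(sigma1 @ sigma2)) equals the sum of the square roots of the
    # eigenvalues of sigma1 @ sigma2, which are real and non-negative for
    # positive semi-definite inputs (clip guards against round-off).
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt)

# Placeholder features standing in for Inception activations (assumption).
rng = np.random.default_rng(0)
real_feats = rng.normal(size=(500, 8))
fake_feats = rng.normal(loc=0.5, size=(500, 8))

mu_r, sigma_r = gaussian_stats(real_feats)
mu_f, sigma_f = gaussian_stats(fake_feats)
score = frechet_distance(mu_r, sigma_r, mu_f, sigma_f)
```

Scores computed this way are only comparable when both sets of features come from the same extractor and preprocessing, which is why the reference statistics are generated once per dataset and reused.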
The baselines can be easily switched by using the argument `--model`, which takes one of the following values: `['diff', 'vae', 'vanilla']`, where `'diff'` is for AVDM, `'vae'` is for the VAE model family, and `'vanilla'` is for the regular diffusion models. Below is an example to train InfoVAE:
python run.py --model vae --mode train --mmd_weight 0.1 --a_dim 32 --epochs 50 --dataset celeba --batch_size 32 --save_epochs 5 --prior regular --r_seed 64
The `main` branch provides code optimized for representation learning tasks, and the `InfoDiffusion-dev` branch provides code closer to the version used to produce the results reported in the paper.
This research code is provided as-is, without any support or guarantee of quality. However, if you identify any issues or areas for improvement, please feel free to raise an issue or submit a pull request. We will do our best to address them.
@inproceedings{wang2023infodiffusion,
title={Infodiffusion: Representation learning using information maximizing diffusion models},
author={Wang, Yingheng and Schiff, Yair and Gokaslan, Aaron and Pan, Weishen and Wang, Fei and De Sa, Christopher and Kuleshov, Volodymyr},
booktitle={International Conference on Machine Learning},
pages={36336--36354},
year={2023},
organization={PMLR}
}