Sinkhorn-Drifting Generative Models


Codebase Structure

sinkhorn_paper_code/
├── core/                  # Shared drifting loss (Baseline + Sinkhorn)
│   ├── drifting_loss.py
│   └── models/ema.py
├── toy/                   # Toy example experiments (Section 5.2)
│   ├── Gen_Modeling.py
│   └── plot_w2_meanstd.py
├── mnist/                 # MNIST experiments (Section 5.3)
│   ├── models.py
│   ├── train_ae.py
│   ├── encode_latents.py
│   ├── train_drifting.py
│   ├── eval_emd.py
│   ├── eval_acc.py
│   └── make_figure.py
└── ffhq/                  # FFHQ image generation experiments (Section 5.4)
    ├── drift_ffhq.py
    ├── eval_ckpt_fid_emd.py
    └── fid_score.py

Toy Experiments (Section 5.2)

python toy/Gen_Modeling.py \
  --targets 8-Gaussians,Checkerboard \
  --methods one-sided,two-sided,sinkhorn \
  --eps-list 0.01,0.05,0.1 \
  --drift-impl log \
  --dist-metric l2_sq \
  --sinkhorn-iters 20 \
  --hidden 256 --blocks 6 --dim-in 16 \
  --res-scale 0.9 --out-init-std 0.001 \
  --batch-size 1024 \
  --lr 2e-4 \
  --steps 5000 \
  --eval-every 100 --eval-n 1000 \
  --device cuda:0

This command reproduces the main toy comparison of Section 5.2 with the Gaussian kernel (--dist-metric l2_sq) and saves:

  • per-eps generated-vs-target scatter grids
  • per-eps W₂² curves
  • per-eps combined figures
  • final model checkpoints under the run directory
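The Sinkhorn coupling selected by --methods sinkhorn can be sketched as standard entropic scaling iterations. The snippet below is a minimal illustration with uniform marginals and the squared-Euclidean (l2_sq) cost, in the spirit of --eps-list and --sinkhorn-iters; it is not the repo's drifting_loss implementation:

```python
import numpy as np

def sinkhorn_plan(x, y, eps=0.05, n_iters=20):
    """Minimal entropic Sinkhorn sketch (illustration only):
    alternate row/column scalings of the Gibbs kernel K = exp(-C / eps)
    with squared-Euclidean cost and uniform marginals."""
    C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # pairwise l2_sq cost
    K = np.exp(-C / eps)
    a = np.full(len(x), 1.0 / len(x))  # uniform source marginal
    b = np.full(len(y), 1.0 / len(y))  # uniform target marginal
    u = np.ones_like(a)
    for _ in range(n_iters):
        v = b / (K.T @ u)  # enforce column marginal
        u = a / (K @ v)    # enforce row marginal
    return u[:, None] * K * v[None, :]  # transport plan
```

Each scaling pair matches the row marginal exactly and drives the column marginal toward uniform; smaller eps needs more iterations to tighten the coupling.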

To aggregate W₂² over 5 random seeds:

for SEED in 42 43 44 45 46; do
  python toy/Gen_Modeling.py \
    --targets 8-Gaussians,Checkerboard \
    --methods one-sided,two-sided,sinkhorn \
    --eps-list 0.01,0.05,0.1 \
    --drift-impl log \
    --dist-metric l2_sq \
    --sinkhorn-iters 20 \
    --hidden 256 --blocks 6 --dim-in 16 \
    --res-scale 0.9 --out-init-std 0.001 \
    --batch-size 1024 \
    --lr 2e-4 \
    --steps 5000 \
    --eval-every 100 --eval-n 1000 \
    --device cuda:0 \
    --seed ${SEED} \
    --run-name fig2_l2sq_seed${SEED}
done

python toy/plot_w2_meanstd.py \
  --runs-glob 'runs/*fig2_l2sq_seed*' \
  --targets '8-Gaussians,Checkerboard' \
  --methods 'one-sided,two-sided,sinkhorn' \
  --eps-list '0.01,0.05,0.1' \
  --eps-descending \
  --right-ylabel \
  --legend-subplot-col 1 \
  --title 'Mean ± std over 5 random seeds' \
  --fig-width-in 6.8 --fig-height-in 7.2 \
  --out-pdf figures/fig2_w2_meanstd.pdf \
  --out-png figures/fig2_w2_meanstd.png

Appendix toy figures can be generated by changing:

  • --targets Moons,Spiral
  • --dist-metric l2 for the Laplacian-kernel experiments
  • --dist-metric l2_sq for the Gaussian-kernel experiments
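As context for these choices, the two --dist-metric values correspond to Gaussian and Laplacian kernels as named in the bullets above. A minimal sketch (illustrative only; the repo's exact scaling and normalization may differ):

```python
import numpy as np

def kernel(x, y, eps, dist_metric="l2_sq"):
    """Illustrative kernel sketch: dist_metric="l2_sq" gives a Gaussian
    kernel exp(-||x - y||^2 / eps), dist_metric="l2" gives a Laplacian
    kernel exp(-||x - y|| / eps)."""
    sq = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # squared distances
    d = sq if dist_metric == "l2_sq" else np.sqrt(sq)
    return np.exp(-d / eps)
```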


Installation

git clone https://github.com/mint-vu/sinkhorn_drift.git
cd sinkhorn_drift
pip install -e .

MNIST Experiments (Section 5.3)

1. Train Autoencoder

python -m mnist.train_ae --run-name mnist_ae

2. Precompute Test Latents

python -m mnist.encode_latents --ae-ckpt runs/mnist_ae/<run>/ae_final.pt

3. Train Generators (τ sweep, Gaussian kernel)

The following commands reproduce Table 1. Each τ value uses a fixed run name tag (e.g. tau0p005 for τ=0.005) to avoid shell floating-point formatting issues.
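The tag convention can be reproduced programmatically; tau_tag below is a small hypothetical helper (not part of the repo) that maps each τ to its run-name tag:

```python
def tau_tag(tau):
    """Run-name tag for a tau value, matching the convention above,
    e.g. 0.005 -> "tau0p005". Hypothetical helper for illustration."""
    return "tau" + f"{tau:g}".replace(".", "p")
```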

AE_CKPT=runs/mnist_ae/<run>/ae_final.pt

COMMON="--ae-ckpt $AE_CKPT \
        --nneg 64 --npos 64 --nuncond 16 --steps 5000 \
        --lr 0.0002 --weight-decay 0.01 \
        --warmup-steps 750 --grad-clip 2.0 --ema-decay 0.999 \
        --omega-min 1.0 --omega-max 4.0 --omega-exponent 3.0 \
        --dist-metric l2_sq --seed 0"

# Baseline
for tag_tau in "tau0p005 0.005" "tau0p01 0.01" "tau0p02 0.02" "tau0p025 0.025" \
               "tau0p03 0.03" "tau0p04 0.04" "tau0p05 0.05" "tau0p1 0.1"; do
    tag=$(echo $tag_tau | cut -d' ' -f1)
    tau=$(echo $tag_tau | cut -d' ' -f2)
    python -m mnist.train_drifting $COMMON \
        --coupling partial_two_sided --drift-form alg2_joint \
        --sinkhorn-marginal none \
        --temps $tau --run-name mnist_baseline_${tag}_l2sq
done

# Sinkhorn
for tag_tau in "tau0p005 0.005" "tau0p01 0.01" "tau0p02 0.02" "tau0p025 0.025" \
               "tau0p03 0.03" "tau0p04 0.04" "tau0p05 0.05" "tau0p1 0.1"; do
    tag=$(echo $tag_tau | cut -d' ' -f1)
    tau=$(echo $tag_tau | cut -d' ' -f2)
    python -m mnist.train_drifting $COMMON \
        --coupling sinkhorn --drift-form split \
        --sinkhorn-iters 20 --sinkhorn-marginal weighted_cols \
        --temps $tau --run-name mnist_sinkhorn_${tag}_l2sq
done

4. Evaluate (W₂² and Accuracy)

AE_CKPT=runs/mnist_ae/<run>/ae_final.pt

for tag in tau0p005 tau0p01 tau0p02 tau0p025 tau0p03 tau0p04 tau0p05 tau0p1; do
    for method in baseline sinkhorn; do
        python -m mnist.eval_emd \
            --gen-ckpt runs/mnist_drift/mnist_${method}_${tag}_l2sq/ckpt_final.pt \
            --ae-ckpt $AE_CKPT --omega 1.0 --data-root ./data
        python -m mnist.eval_acc \
            --gen-ckpt runs/mnist_drift/mnist_${method}_${tag}_l2sq/ckpt_final.pt \
            --ae-ckpt $AE_CKPT --omega 1.0 --data-root ./data
    done
done

5. Train Generators (τ sweep, Laplacian kernel — Appendix)

AE_CKPT=runs/mnist_ae/<run>/ae_final.pt

COMMON_LAP="--ae-ckpt $AE_CKPT \
        --nneg 64 --npos 64 --nuncond 16 --steps 5000 \
        --lr 0.0002 --weight-decay 0.01 \
        --warmup-steps 750 --grad-clip 2.0 --ema-decay 0.999 \
        --omega-min 1.0 --omega-max 4.0 --omega-exponent 3.0 \
        --dist-metric l2 --seed 0"

# Baseline (Laplacian)
for tag_tau in "tau0p005 0.005" "tau0p01 0.01" "tau0p02 0.02" "tau0p025 0.025" \
               "tau0p03 0.03" "tau0p04 0.04" "tau0p05 0.05" "tau0p1 0.1"; do
    tag=$(echo $tag_tau | cut -d' ' -f1)
    tau=$(echo $tag_tau | cut -d' ' -f2)
    python -m mnist.train_drifting $COMMON_LAP \
        --coupling partial_two_sided --drift-form alg2_joint \
        --sinkhorn-marginal none \
        --temps $tau --run-name mnist_baseline_${tag}
done

# Sinkhorn (Laplacian)
for tag_tau in "tau0p005 0.005" "tau0p01 0.01" "tau0p02 0.02" "tau0p025 0.025" \
               "tau0p03 0.03" "tau0p04 0.04" "tau0p05 0.05" "tau0p1 0.1"; do
    tag=$(echo $tag_tau | cut -d' ' -f1)
    tau=$(echo $tag_tau | cut -d' ' -f2)
    python -m mnist.train_drifting $COMMON_LAP \
        --coupling sinkhorn --drift-form split \
        --sinkhorn-iters 20 --sinkhorn-marginal weighted_cols \
        --temps $tau --run-name mnist_sinkhorn_${tag}
done

6. Generate Figures

# Gaussian kernel figure (mnist_gaussian.pdf)
python -m mnist.make_figure \
    --ae-ckpt runs/mnist_ae/<run>/ae_final.pt \
    --out figures/mnist_gaussian.pdf --kernel gaussian

# Laplacian kernel figure (mnist_laplacian.pdf)
python -m mnist.make_figure \
    --ae-ckpt runs/mnist_ae/<run>/ae_final.pt \
    --out figures/mnist_laplacian.pdf --kernel laplacian

FFHQ Experiments (Section 5.4, Table 2 only)

This repository includes the code needed to reproduce the quantitative FFHQ results in Table 2: latent EMD and image FID for baseline vs. Sinkhorn across τ ∈ {0.1, 1.0, 10.0}.

Data format

The FFHQ pipeline expects latent datasets stored as .npz files with exactly the following six keys:

  • male_children
  • male_adult
  • male_old
  • female_children
  • female_adult
  • female_old

Each value should be a float array of shape (N_c, 512), where 512 is the ALAE latent dimension. A typical layout is:

data/
└── ffhq_latents_6class/
    ├── train_latents_by_class.npz
    └── test_latents_by_class.npz

The training script uses:

  • train_latents_by_class.npz as the class-specific target pools
  • test_latents_by_class.npz for periodic EMD monitoring during training
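A latent file can be validated against this layout with a short sanity check. REQUIRED_KEYS and check_latents_npz below are hypothetical names for an illustrative helper, not part of the repo:

```python
import numpy as np

REQUIRED_KEYS = [
    "male_children", "male_adult", "male_old",
    "female_children", "female_adult", "female_old",
]

def check_latents_npz(path, latent_dim=512):
    """Check that the .npz has exactly the six class keys described above,
    each a (N_c, latent_dim) float array; returns per-class counts."""
    data = np.load(path)
    missing = [k for k in REQUIRED_KEYS if k not in data.files]
    assert not missing, f"missing keys: {missing}"
    for k in REQUIRED_KEYS:
        arr = data[k]
        assert arr.ndim == 2 and arr.shape[1] == latent_dim, \
            f"{k}: expected (N_c, {latent_dim}), got {arr.shape}"
    return {k: data[k].shape[0] for k in REQUIRED_KEYS}
```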

External ALAE dependency

Image FID in Table 2 is computed after decoding latents with the same frozen ALAE decoder used in the paper. This decoder is not bundled in this repository. You must provide an external ALAE checkout that contains:

  • alae_ffhq_inference.py
  • configs/ffhq.yaml
  • training_artifacts/ffhq/

Set:

export ALAE_ROOT=/path/to/ALAE

or pass --alae-root /path/to/ALAE to the evaluation script.

Train the six FFHQ models used in Table 2

Common settings from the paper:

  • architecture: d_z=512, d_e=64, hidden=1024, n_hidden=3
  • training: iters=1000, batch_size=4096, lr=3e-4, emb_lr=1e-3
  • drift: dist=l2_sq, sinkhorn_iters=30
  • logging/eval: emd_every=50, emd_samples=512, log_every=50
  • seed: 42

TRAIN_NPZ=data/ffhq_latents_6class/train_latents_by_class.npz
TEST_NPZ=data/ffhq_latents_6class/test_latents_by_class.npz

for EPS in 0.1 1.0 10.0; do
  TAG=$(echo "$EPS" | tr '.' 'p')

  python -m ffhq.drift_ffhq \
    --train-npz $TRAIN_NPZ \
    --test-npz $TEST_NPZ \
    --save-path runs/ffhq/eps_${TAG}_baseline/drift_ffhq_model.pt \
    --emd-plot runs/ffhq/eps_${TAG}_baseline/drift_ffhq_emd.png \
    --emd-perclass-plot runs/ffhq/eps_${TAG}_baseline/drift_ffhq_emd_perclass.png \
    --pca-plot runs/ffhq/eps_${TAG}_baseline/drift_ffhq_pca.png \
    --d-z 512 --d-e 64 --hidden 1024 --n-hidden 3 \
    --iters 1000 --batch-size 4096 --lr 3e-4 --emb-lr 1e-3 \
    --plan two-sided --eps $EPS --sinkhorn-iters 30 --dist l2_sq \
    --emd-every 50 --emd-samples 512 --log-every 50 \
    --seed 42

  python -m ffhq.drift_ffhq \
    --train-npz $TRAIN_NPZ \
    --test-npz $TEST_NPZ \
    --save-path runs/ffhq/eps_${TAG}_sinkhorn/drift_ffhq_model.pt \
    --emd-plot runs/ffhq/eps_${TAG}_sinkhorn/drift_ffhq_emd.png \
    --emd-perclass-plot runs/ffhq/eps_${TAG}_sinkhorn/drift_ffhq_emd_perclass.png \
    --pca-plot runs/ffhq/eps_${TAG}_sinkhorn/drift_ffhq_pca.png \
    --d-z 512 --d-e 64 --hidden 1024 --n-hidden 3 \
    --iters 1000 --batch-size 4096 --lr 3e-4 --emb-lr 1e-3 \
    --plan sinkhorn --eps $EPS --sinkhorn-iters 30 --dist l2_sq \
    --emd-every 50 --emd-samples 512 --log-every 50 \
    --seed 42
done

Evaluate Table 2 metrics

Table 2 reports:

  • latent EMD (solver=emd, metric=l2_sq)
  • image FID after ALAE decoding at 1024×1024

The evaluation uses 1000 real and 1000 generated samples per class.

TRAIN_NPZ=data/ffhq_latents_6class/train_latents_by_class.npz

for EPS in 0.1 1.0 10.0; do
  TAG=$(echo "$EPS" | tr '.' 'p')

  python -m ffhq.eval_ckpt_fid_emd \
    --ckpt-path runs/ffhq/eps_${TAG}_baseline/drift_ffhq_model.pt \
    --real-npz $TRAIN_NPZ \
    --n-per-class 1000 \
    --seed 42 \
    --device cuda:0 \
    --gen-batch 512 \
    --decode-batch 8 \
    --decode-impl batch \
    --fid-batch 64 \
    --save-size 1024 \
    --solver emd \
    --metric l2_sq \
    --ot-iters 200000 \
    --alae-root $ALAE_ROOT \
    --output-dir runs/ffhq/eps_${TAG}_baseline/eval_table2

  python -m ffhq.eval_ckpt_fid_emd \
    --ckpt-path runs/ffhq/eps_${TAG}_sinkhorn/drift_ffhq_model.pt \
    --real-npz $TRAIN_NPZ \
    --n-per-class 1000 \
    --seed 42 \
    --device cuda:0 \
    --gen-batch 512 \
    --decode-batch 8 \
    --decode-impl batch \
    --fid-batch 64 \
    --save-size 1024 \
    --solver emd \
    --metric l2_sq \
    --ot-iters 200000 \
    --alae-root $ALAE_ROOT \
    --output-dir runs/ffhq/eps_${TAG}_sinkhorn/eval_table2
done

Each evaluation directory will contain:

  • metrics_fid_emd.json
  • sampled_latents_by_class.npz
  • real_images/<class>/
  • fake_images/<class>/

The paper numbers are the mean of the six per-class values in each metrics_fid_emd.json.
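Those means can be computed in a few lines. The key layout assumed below, {class: {"fid": ..., "emd": ...}}, is an assumption for illustration; adapt it to the actual structure of metrics_fid_emd.json:

```python
import json
from statistics import mean

def table2_numbers(path):
    """Average the six per-class values in metrics_fid_emd.json.
    The {class: {"fid": ..., "emd": ...}} layout is assumed, not verified
    against the repo; adjust the keys to match the real file."""
    with open(path) as f:
        metrics = json.load(f)
    return {
        "fid": mean(m["fid"] for m in metrics.values()),
        "emd": mean(m["emd"] for m in metrics.values()),
    }
```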


Pretrained Checkpoints (MNIST)

Pretrained checkpoints are included in pretrained/. The directory structure is:

pretrained/
├── ae/
│   └── ae_final.pt                   # Autoencoder (latent_dim=6, trained 50 epochs)
├── mnist_gaussian/                   # Table 1 — Gaussian kernel 
└── mnist_laplacian/                  # Appendix Table 3 — Laplacian kernel

About

Generative Modeling by Drifting through the Gradient Flow of Sinkhorn Divergence
