sinkhorn_paper_code/
├── core/ # Shared drifting loss (Baseline + Sinkhorn)
│ ├── drifting_loss.py
│ └── models/ema.py
├── toy/ # Toy example experiments (Section 5.2)
│ ├── Gen_Modeling.py
│ └── plot_w2_meanstd.py
├── mnist/ # MNIST experiments (Section 5.3)
│ ├── models.py
│ ├── train_ae.py
│ ├── encode_latents.py
│ ├── train_drifting.py
│ ├── eval_emd.py
│ ├── eval_acc.py
│ └── make_figure.py
└── ffhq/ # FFHQ image generation experiments (Section 5.4)
├── drift_ffhq.py
├── eval_ckpt_fid_emd.py
└── fid_score.py
# Main toy comparison (Section 5.2), Gaussian kernel (dist-metric l2_sq).
python toy/Gen_Modeling.py \
  --targets 8-Gaussians,Checkerboard \
  --methods one-sided,two-sided,sinkhorn \
  --eps-list 0.01,0.05,0.1 \
  --drift-impl log \
  --dist-metric l2_sq \
  --sinkhorn-iters 20 \
  --hidden 256 --blocks 6 --dim-in 16 \
  --res-scale 0.9 --out-init-std 0.001 \
  --batch-size 1024 \
  --lr 2e-4 \
  --steps 5000 \
  --eval-every 100 --eval-n 1000 \
  --device cuda:0
This command reproduces the main toy comparison used for Section 5.2 with the
Gaussian kernel (dist_metric=l2_sq) and saves:
- per-eps generated-vs-target scatter grids
- per-eps W_2^2 curves
- per-eps combined figures
- final model checkpoints under the run directory
To aggregate W_2^2 over 5 random seeds:
# Repeat the main toy run over five seeds; each run is tagged so that
# plot_w2_meanstd.py can aggregate them by glob afterwards.
for SEED in 42 43 44 45 46; do
  python toy/Gen_Modeling.py \
    --targets 8-Gaussians,Checkerboard \
    --methods one-sided,two-sided,sinkhorn \
    --eps-list 0.01,0.05,0.1 \
    --drift-impl log \
    --dist-metric l2_sq \
    --sinkhorn-iters 20 \
    --hidden 256 --blocks 6 --dim-in 16 \
    --res-scale 0.9 --out-init-std 0.001 \
    --batch-size 1024 \
    --lr 2e-4 \
    --steps 5000 \
    --eval-every 100 --eval-n 1000 \
    --device cuda:0 \
    --seed "${SEED}" \
    --run-name "fig2_l2sq_seed${SEED}"
done
# Aggregate the per-seed W_2^2 curves into a mean +/- std figure.
python toy/plot_w2_meanstd.py \
  --runs-glob 'runs/*fig2_l2sq_seed*' \
  --targets '8-Gaussians,Checkerboard' \
  --methods 'one-sided,two-sided,sinkhorn' \
  --eps-list '0.01,0.05,0.1' \
  --eps-descending \
  --right-ylabel \
  --legend-subplot-col 1 \
  --title 'Mean ± std over 5 random seeds' \
  --fig-width-in 6.8 --fig-height-in 7.2 \
  --out-pdf figures/fig2_w2_meanstd.pdf \
  --out-png figures/fig2_w2_meanstd.png
Appendix toy figures can be generated by changing:
- --targets Moons,Spiral for the appendix target distributions
- --dist-metric l2 for the Laplacian-kernel experiments
- --dist-metric l2_sq for the Gaussian-kernel experiments
git clone https://github.com/mint-vu/sinkhorn_drift.git
cd sinkhorn_drift
pip install -e .
# Train the MNIST autoencoder, then encode the dataset into latents.
python -m mnist.train_ae --run-name mnist_ae
# Quote the checkpoint path: the <run> placeholder would otherwise be
# parsed by the shell as redirections.
python -m mnist.encode_latents --ae-ckpt "runs/mnist_ae/<run>/ae_final.pt"
The following commands reproduce Table 1. Each τ value uses a fixed run name tag
(e.g. tau0p005 for τ=0.005) to avoid shell floating-point formatting issues.
# Autoencoder checkpoint; replace <run> with the actual run directory.
# Quoted so the literal <>/placeholder is not parsed as shell redirections
# (unquoted, the assignment would silently truncate at the first '<').
AE_CKPT="runs/mnist_ae/<run>/ae_final.pt"
# Shared training flags (Gaussian kernel). Kept as one string and expanded
# UNQUOTED below so it word-splits back into individual flags.
COMMON="--ae-ckpt $AE_CKPT \
--nneg 64 --npos 64 --nuncond 16 --steps 5000 \
--lr 0.0002 --weight-decay 0.01 \
--warmup-steps 750 --grad-clip 2.0 --ema-decay 0.999 \
--omega-min 1.0 --omega-max 4.0 --omega-exponent 3.0 \
--dist-metric l2_sq --seed 0"
# Baseline
for tag_rho in "tau0p005 0.005" "tau0p01 0.01" "tau0p02 0.02" "tau0p025 0.025" \
  "tau0p03 0.03" "tau0p04 0.04" "tau0p05 0.05" "tau0p1 0.1"; do
  # Split the "tag rho" pair in one builtin instead of two cut(1) subshells.
  read -r tag rho <<<"$tag_rho"
  # shellcheck disable=SC2086 -- $COMMON must word-split into flags.
  python -m mnist.train_drifting $COMMON \
    --coupling partial_two_sided --drift-form alg2_joint \
    --sinkhorn-marginal none \
    --temps "$rho" --run-name "mnist_baseline_${tag}_l2sq"
done
# Sinkhorn
for tag_rho in "tau0p005 0.005" "tau0p01 0.01" "tau0p02 0.02" "tau0p025 0.025" \
  "tau0p03 0.03" "tau0p04 0.04" "tau0p05 0.05" "tau0p1 0.1"; do
  # Split the "tag rho" pair in one builtin instead of two cut(1) subshells.
  read -r tag rho <<<"$tag_rho"
  # shellcheck disable=SC2086 -- $COMMON must word-split into flags.
  python -m mnist.train_drifting $COMMON \
    --coupling sinkhorn --drift-form split \
    --sinkhorn-iters 20 --sinkhorn-marginal weighted_cols \
    --temps "$rho" --run-name "mnist_sinkhorn_${tag}_l2sq"
done
AE_CKPT="runs/mnist_ae/<run>/ae_final.pt"
# Evaluate latent EMD and classifier accuracy for every run of both methods.
for tag in tau0p005 tau0p01 tau0p02 tau0p025 tau0p03 tau0p04 tau0p05 tau0p1; do
  for method in baseline sinkhorn; do
    python -m mnist.eval_emd \
      --gen-ckpt "runs/mnist_drift/mnist_${method}_${tag}_l2sq/ckpt_final.pt" \
      --ae-ckpt "$AE_CKPT" --omega 1.0 --data-root ./data
    python -m mnist.eval_acc \
      --gen-ckpt "runs/mnist_drift/mnist_${method}_${tag}_l2sq/ckpt_final.pt" \
      --ae-ckpt "$AE_CKPT" --omega 1.0 --data-root ./data
  done
done
AE_CKPT="runs/mnist_ae/<run>/ae_final.pt"
# Shared training flags for the Laplacian-kernel runs (dist-metric l2).
# Kept as one string and expanded UNQUOTED below so it word-splits into flags.
COMMON_LAP="--ae-ckpt $AE_CKPT \
--nneg 64 --npos 64 --nuncond 16 --steps 5000 \
--lr 0.0002 --weight-decay 0.01 \
--warmup-steps 750 --grad-clip 2.0 --ema-decay 0.999 \
--omega-min 1.0 --omega-max 4.0 --omega-exponent 3.0 \
--dist-metric l2 --seed 0"
# Baseline (Laplacian)
for tag_rho in "tau0p005 0.005" "tau0p01 0.01" "tau0p02 0.02" "tau0p025 0.025" \
  "tau0p03 0.03" "tau0p04 0.04" "tau0p05 0.05" "tau0p1 0.1"; do
  # Split the "tag rho" pair in one builtin instead of two cut(1) subshells.
  read -r tag rho <<<"$tag_rho"
  # shellcheck disable=SC2086 -- $COMMON_LAP must word-split into flags.
  python -m mnist.train_drifting $COMMON_LAP \
    --coupling partial_two_sided --drift-form alg2_joint \
    --sinkhorn-marginal none \
    --temps "$rho" --run-name "mnist_baseline_${tag}"
done
# Sinkhorn (Laplacian)
for tag_rho in "tau0p005 0.005" "tau0p01 0.01" "tau0p02 0.02" "tau0p025 0.025" \
  "tau0p03 0.03" "tau0p04 0.04" "tau0p05 0.05" "tau0p1 0.1"; do
  # Split the "tag rho" pair in one builtin instead of two cut(1) subshells.
  read -r tag rho <<<"$tag_rho"
  # shellcheck disable=SC2086 -- $COMMON_LAP must word-split into flags.
  python -m mnist.train_drifting $COMMON_LAP \
    --coupling sinkhorn --drift-form split \
    --sinkhorn-iters 20 --sinkhorn-marginal weighted_cols \
    --temps "$rho" --run-name "mnist_sinkhorn_${tag}"
done
# Gaussian kernel figure (mnist_gaussian.pdf)
# Checkpoint paths are quoted so the <run> placeholder is not parsed as
# shell redirections.
python -m mnist.make_figure \
  --ae-ckpt "runs/mnist_ae/<run>/ae_final.pt" \
  --out figures/mnist_gaussian.pdf --kernel gaussian
# Laplacian kernel figure (mnist_laplacian.pdf)
python -m mnist.make_figure \
  --ae-ckpt "runs/mnist_ae/<run>/ae_final.pt" \
  --out figures/mnist_laplacian.pdf --kernel laplacian
This repository includes the code needed to reproduce the quantitative FFHQ results in Table 2: latent EMD and image FID for baseline vs. Sinkhorn across τ in {0.1, 1.0, 10.0}.
The FFHQ pipeline expects latent datasets stored as .npz files with exactly
the following six keys:
male_children, male_adult, male_old, female_children, female_adult, female_old
Each value should be a float array of shape (N_c, 512), where 512 is the
ALAE latent dimension. A typical layout is:
data/
└── ffhq_latents_6class/
├── train_latents_by_class.npz
└── test_latents_by_class.npz
The training script uses:
- train_latents_by_class.npz as the class-specific target pools
- test_latents_by_class.npz for periodic EMD monitoring during training
Image FID in Table 2 is computed after decoding latents with the same frozen ALAE decoder used in the paper. This decoder is not bundled in this repository. You must provide an external ALAE checkout that contains:
- alae_ffhq_inference.py
- configs/ffhq.yaml
- training_artifacts/ffhq/
Set:
export ALAE_ROOT=/path/to/ALAE
or pass --alae-root /path/to/ALAE to the evaluation script.
Common settings from the paper:
- architecture: d_z=512, d_e=64, hidden=1024, n_hidden=3
- training: iters=1000, batch_size=4096, lr=3e-4, emb_lr=1e-3
- drift: dist=l2_sq, sinkhorn_iters=30
- logging/eval: emd_every=50, emd_samples=512, log_every=50
- seed: 42
TRAIN_NPZ=data/ffhq_latents_6class/train_latents_by_class.npz
TEST_NPZ=data/ffhq_latents_6class/test_latents_by_class.npz
# Train both Table 2 variants at each entropic-regularization eps.
# Each "name plan" pair maps a run-directory suffix to its coupling plan:
# baseline -> two-sided, sinkhorn -> sinkhorn.
for EPS in 0.1 1.0 10.0; do
  TAG=${EPS//./p}  # e.g. 10.0 -> 10p0; keeps dots out of directory names
  for name_plan in "baseline two-sided" "sinkhorn sinkhorn"; do
    read -r NAME PLAN <<<"$name_plan"
    OUT="runs/ffhq/eps_${TAG}_${NAME}"
    python -m ffhq.drift_ffhq \
      --train-npz "$TRAIN_NPZ" \
      --test-npz "$TEST_NPZ" \
      --save-path "$OUT/drift_ffhq_model.pt" \
      --emd-plot "$OUT/drift_ffhq_emd.png" \
      --emd-perclass-plot "$OUT/drift_ffhq_emd_perclass.png" \
      --pca-plot "$OUT/drift_ffhq_pca.png" \
      --d-z 512 --d-e 64 --hidden 1024 --n-hidden 3 \
      --iters 1000 --batch-size 4096 --lr 3e-4 --emb-lr 1e-3 \
      --plan "$PLAN" --eps "$EPS" --sinkhorn-iters 30 --dist l2_sq \
      --emd-every 50 --emd-samples 512 --log-every 50 \
      --seed 42
  done
done
Table 2 reports:
- latent EMD (solver=emd, metric=l2_sq)
- image FID after ALAE decoding at 1024 x 1024
The evaluation uses 1000 real and 1000 generated samples per class.
TRAIN_NPZ=data/ffhq_latents_6class/train_latents_by_class.npz
# Table 2 evaluation: latent EMD (exact solver) plus image FID after ALAE
# decoding, for both trained variants at every eps. ALAE_ROOT must be set
# (or replace "$ALAE_ROOT" with an explicit path).
for EPS in 0.1 1.0 10.0; do
  TAG=${EPS//./p}  # must match the training run-directory tags
  for METHOD in baseline sinkhorn; do
    python -m ffhq.eval_ckpt_fid_emd \
      --ckpt-path "runs/ffhq/eps_${TAG}_${METHOD}/drift_ffhq_model.pt" \
      --real-npz "$TRAIN_NPZ" \
      --n-per-class 1000 \
      --seed 42 \
      --device cuda:0 \
      --gen-batch 512 \
      --decode-batch 8 \
      --decode-impl batch \
      --fid-batch 64 \
      --save-size 1024 \
      --solver emd \
      --metric l2_sq \
      --ot-iters 200000 \
      --alae-root "$ALAE_ROOT" \
      --output-dir "runs/ffhq/eps_${TAG}_${METHOD}/eval_table2"
  done
done
Each evaluation directory will contain:
- metrics_fid_emd.json
- sampled_latents_by_class.npz
- real_images/<class>/
- fake_images/<class>/
The paper numbers are the mean of the six per-class values in each
metrics_fid_emd.json.
Pretrained checkpoints are included in pretrained/. The directory structure is:
pretrained/
├── ae/
│ └── ae_final.pt # Autoencoder (latent_dim=6, trained 50 epochs)
├── mnist_gaussian/ # Table 1 — Gaussian kernel
└── mnist_laplacian/ # Appendix Table 3 — Laplacian kernel