We evaluate different tasks on the following 4 image datasets:

- CLEVRTex: object segmentation, image reconstruction, compositional generation
- CelebA: image reconstruction, compositional generation
- VOC/COCO: object segmentation
We take SlotDiffusion on CLEVRTex as an example. The 2 baselines, Slot Attention and SLATE, follow similar steps.
To run on other datasets, simply replace the config file with the desired one (see the CelebA example after the training commands below).
SlotDiffusion training involves 2 steps: first train a VQ-VAE to discretize images into patch tokens, and then train a slot-conditioned Latent Diffusion Model (LDM) to reconstruct these tokens.
Run the following command to train the VQ-VAE (requires 2 or 4 GPUs; set `--nproc_per_node` accordingly):
python -m torch.distributed.launch --nproc_per_node=2 --master_port=29501 \
scripts/train.py --task img_based \
--params slotdiffusion/img_based/configs/sa_ldm/vqvae_clevrtex_params-res128.py \
--fp16 --ddp --cudnn
Alternatively, we provide pre-trained VQ-VAE weights at `pretrained/vqvae_clevrtex_params-res128.pth`.
Run the following command to train SlotDiffusion on VQ-VAE tokens:
python scripts/train.py --task img_based \
--params slotdiffusion/img_based/configs/sa_ldm/sa_ldm_clevrtex_params-res128.py \
--fp16 --cudnn
Alternatively, we provide pre-trained SlotDiffusion weights at `pretrained/sa_ldm_clevrtex_params-res128.pth`.
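For instance, to train SlotDiffusion on CelebA instead, swap in the corresponding config file. The filename below is inferred from the CLEVRTex naming pattern and is only illustrative; check the configs folder for the exact name:

# config filename inferred from the naming pattern -- verify against the repo
python scripts/train.py --task img_based \
--params slotdiffusion/img_based/configs/sa_ldm/sa_ldm_celeba_params-res128.py \
--fp16 --cudnn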
Run the following command to evaluate the object segmentation performance:
python slotdiffusion/img_based/test_seg.py \
--params slotdiffusion/img_based/configs/sa_ldm/sa_ldm_clevrtex_params-res128.py \
--weight $WEIGHT \
--bs 64 # optional, change to desired value
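Here `$WEIGHT` is the path to the checkpoint you want to evaluate, e.g. the pre-trained weights above:

WEIGHT=pretrained/sa_ldm_clevrtex_params-res128.pth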
Run the following command to evaluate the image reconstruction performance (we support DDP testing because reconstruction is slow, especially for SLATE; a single-process variant is shown below the command if you do not need DDP):
python -m torch.distributed.launch --nproc_per_node=$NUM_GPU --master_port=29501 \
slotdiffusion/img_based/test_recon.py \
--params slotdiffusion/img_based/configs/sa_ldm/sa_ldm_clevrtex_params-res128.py \
--weight $WEIGHT \
--bs 64 # optional, change to desired value
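If you do not need DDP, the single-process equivalent (same arguments, as noted above) is:

python slotdiffusion/img_based/test_recon.py \
--params slotdiffusion/img_based/configs/sa_ldm/sa_ldm_clevrtex_params-res128.py \
--weight $WEIGHT \
--bs 64 # optional, change to desired value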
Run the following command to evaluate the compositional generation performance (again using DDP to speed up testing; replace the launcher with `python slotdiffusion/img_based/test_comp_gen.py ...` if DDP is not needed):
python -m torch.distributed.launch --nproc_per_node=$NUM_GPU --master_port=29501 \
slotdiffusion/img_based/test_comp_gen.py \
--params slotdiffusion/img_based/configs/sa_ldm/sa_ldm_clevrtex_params-res128.py \
--weight $WEIGHT \
--bs 64 # optional, change to desired value
Note:

- The compositional generation implemented here is a simplified version, where we randomly compose slots within a batch to generate novel samples. In our experiments, its FID is close to that of the visual concept library method described in Section 3.3 of the paper, so we use it here to simplify the evaluation process.
- To compute the FID, you need to manually call the `pytorch-fid` package. Suppose you test the weight located at `xxx/model.pth`; we will save the GT images under `xxx/eval/gt_imgs/` and the generated images under `xxx/eval/comp_imgs/`. Run `python -m pytorch_fid xxx/eval/gt_imgs xxx/eval/comp_imgs` to compute the FID (a worked example follows this list).
- The reconstructed images produced by `test_recon.py` are saved under `xxx/eval/recon_imgs/`.
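For example, if you test the pre-trained weights at `pretrained/sa_ldm_clevrtex_params-res128.pth`, the images land under `pretrained/eval/` per the layout above, so the FID call would be (paths inferred from that layout; adjust if your output directory differs):

# paths inferred from the layout described above
python -m pytorch_fid pretrained/eval/gt_imgs pretrained/eval/comp_imgs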
Slot Attention training does not require any pre-trained tokenizers. You can train it with the provided config files.
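For example, a training command following the same pattern as above; the Slot Attention config path here is illustrative, so substitute the actual filename from the configs folder:

# config path illustrative -- use the provided Slot Attention config
python scripts/train.py --task img_based \
--params slotdiffusion/img_based/configs/sa/sa_clevrtex_params-res128.py \
--fp16 --cudnn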
Similar to SlotDiffusion, SLATE training consists of 2 steps: first pre-train a dVAE tokenizer, then train SLATE on its tokens. You can train both with the provided config files, as sketched below.
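The two steps mirror the SlotDiffusion commands above; the config filenames here are illustrative, so use the provided SLATE configs:

# step 1: pre-train the dVAE tokenizer (config filename illustrative)
python scripts/train.py --task img_based \
--params slotdiffusion/img_based/configs/slate/dvae_clevrtex_params-res128.py \
--fp16 --cudnn

# step 2: train SLATE on dVAE tokens (config filename illustrative)
python scripts/train.py --task img_based \
--params slotdiffusion/img_based/configs/slate/slate_clevrtex_params-res128.py \
--fp16 --cudnn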