
AutoencoderDC bug with pipe.enable_vae_slicing() and decoding multiple images #12338

@mingyi456

Describe the bug

When using the Sana_Sprint_1.6B_1024px and SANA1.5_4.8B_1024px models, I cannot enable VAE slicing when generating multiple images per prompt. I expect this affects the other Sana model and pipeline configurations as well, since they all use the same AutoencoderDC model.

I traced the issue to the slicing branch of AutoencoderDC.decode (autoencoder_dc.py line 620 in the traceback below); removing the .sample attribute access appears to fix it.

I intend to submit a PR with my proposed fix. Can someone confirm that this is the correct solution?
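For reference, a minimal sketch of the change I have in mind, assuming _decode returns a plain torch.Tensor rather than a DecoderOutput (which is what the traceback indicates):

# Current line in AutoencoderDC.decode: _decode returns a torch.Tensor,
# so the .sample attribute access raises AttributeError
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]

# Proposed fix: drop the .sample access; torch.cat on the resulting
# list of tensors works unchanged
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]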

Reproduction

import torch
from diffusers import SanaSprintPipeline

# The text encoder is loaded from the checkpoint by default
pipe = SanaSprintPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
pipe.enable_vae_slicing()  # slicing plus a batch size > 1 triggers the crash

prompt = "A girl"
num_images_per_prompt = 8  # any value > 1 reproduces the error
output = pipe(
    prompt=prompt,
    height=1024,
    width=1024,
    num_inference_steps=2,
    num_images_per_prompt=num_images_per_prompt,
    intermediate_timesteps=1.3,
    max_timesteps=1.56830,
    timesteps=None,
).images
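Until a fix lands, a possible stop-gap (a sketch only, again assuming _decode returns a plain tensor) is to monkey-patch AutoencoderDC.decode from the user script before running the pipeline:

import torch
from diffusers.models.autoencoders.autoencoder_dc import AutoencoderDC
from diffusers.models.autoencoders.vae import DecoderOutput
from diffusers.utils.accelerate_utils import apply_forward_hook

def _patched_decode(self, z, return_dict=True):
    # Same control flow as the current AutoencoderDC.decode, minus the
    # .sample access on the tensor returned by _decode
    if self.use_slicing and z.shape[0] > 1:
        decoded = torch.cat([self._decode(z_slice) for z_slice in z.split(1)])
    else:
        decoded = self._decode(z)
    if not return_dict:
        return (decoded,)
    return DecoderOutput(sample=decoded)

# Re-wrap with apply_forward_hook so model offloading hooks keep working
AutoencoderDC.decode = apply_forward_hook(_patched_decode)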

Logs

Traceback (most recent call last):
  File "F:\AI setups\Diffusers\scripts\inference sana-sprint.py", line 24, in <module>
    output = pipe(
             ^^^^^
  File "F:\AI setups\Diffusers\diffusers-venv\Lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "F:\AI setups\Diffusers\diffusers-venv\Lib\site-packages\diffusers\pipelines\sana\pipeline_sana_sprint.py", line 874, in __call__
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "F:\AI setups\Diffusers\diffusers-venv\Lib\site-packages\diffusers\utils\accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "F:\AI setups\Diffusers\diffusers-venv\Lib\site-packages\diffusers\models\autoencoders\autoencoder_dc.py", line 620, in decode
    decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "F:\AI setups\Diffusers\diffusers-venv\Lib\site-packages\diffusers\models\autoencoders\autoencoder_dc.py", line 620, in <listcomp>
    decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'Tensor' object has no attribute 'sample'

System Info

  • 🤗 Diffusers version: 0.36.0.dev0
  • Platform: Windows-10-10.0.26100-SP0
  • Running on Google Colab?: No
  • Python version: 3.11.9
  • PyTorch version (GPU?): 2.7.1+cu128 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.34.4
  • Transformers version: 4.55.0
  • Accelerate version: 1.10.0
  • PEFT version: 0.17.0
  • Bitsandbytes version: 0.47.0
  • Safetensors version: 0.6.2
  • xFormers version: 0.0.31.post1
  • Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help?

@yiyixuxu @DN6
