Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/diffusers/modular_pipelines/qwenimage/before_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def expected_components(self) -> List[ComponentSpec]:
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("latents"),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to how it's done in the other pipelines.

InputParam(name="height"),
InputParam(name="width"),
InputParam(name="num_images_per_prompt", default=1),
Expand Down Expand Up @@ -196,11 +197,11 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

block_state.latents = randn_tensor(
shape, generator=block_state.generator, device=device, dtype=block_state.dtype
)
block_state.latents = components.pachifier.pack_latents(block_state.latents)
if block_state.latents is None:
block_state.latents = randn_tensor(
shape, generator=block_state.generator, device=device, dtype=block_state.dtype
)
block_state.latents = components.pachifier.pack_latents(block_state.latents)

self.set_block_state(state, block_state)
return components, state
Expand Down Expand Up @@ -549,7 +550,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
block_state.width // components.vae_scale_factor // 2,
)
]
* block_state.batch_size
for _ in range(block_state.batch_size)
Copy link
Member Author

@sayakpaul sayakpaul Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have two options:

  1. Either this
  2. Or how it's done in edit:
    img_shapes = [
    [
    (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
    (1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2),
    ]
    ] * batch_size

Regardless, the current implementation isn't exactly the same as how the standard pipeline implements it and would break for the batched input tests we have.

]
block_state.txt_seq_lens = (
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
Expand Down
3 changes: 2 additions & 1 deletion src/diffusers/modular_pipelines/qwenimage/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,9 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
block_state = self.get_block_state(state)

# YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular
vae_scale_factor = 2 ** len(components.vae.temperal_downsample)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keeping vae_scale_factor fixed to 8, for example, would break the tests as we use a smaller VAE.

block_state.latents = components.pachifier.unpack_latents(
block_state.latents, block_state.height, block_state.width
block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor
)
block_state.latents = block_state.latents.to(components.vae.dtype)

Expand Down
6 changes: 6 additions & 0 deletions src/diffusers/modular_pipelines/qwenimage/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
block_state.prompt_embeds = block_state.prompt_embeds[:, : block_state.max_sequence_length]
block_state.prompt_embeds_mask = block_state.prompt_embeds_mask[:, : block_state.max_sequence_length]

block_state.negative_prompt_embeds = None
block_state.negative_prompt_embeds_mask = None
Comment on lines +506 to +507
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise, no CFG settings would break.

if components.requires_unconditional_embeds:
negative_prompt = block_state.negative_prompt or ""
block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds(
Expand Down Expand Up @@ -627,6 +629,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
device=device,
)

block_state.negative_prompt_embeds = None
block_state.negative_prompt_embeds_mask = None
if components.requires_unconditional_embeds:
negative_prompt = block_state.negative_prompt or " "
block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit(
Expand Down Expand Up @@ -679,6 +683,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
device=device,
)

block_state.negative_prompt_embeds = None
block_state.negative_prompt_embeds_mask = None
if components.requires_unconditional_embeds:
negative_prompt = block_state.negative_prompt or " "
block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@ class QwenImagePachifier(ConfigMixin):
config_name = "config.json"

@register_to_config
def __init__(
self,
patch_size: int = 2,
):
def __init__(self, patch_size: int = 2):
super().__init__()

def pack_latents(self, latents):
Expand Down
12 changes: 12 additions & 0 deletions tests/modular_pipelines/flux/test_modular_pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ def get_dummy_inputs(self, seed=0):
}
return inputs

# @pytest.mark.skipif(torch_device == "cpu", reason="Test needs an accelerator.")
def test_float16_inference(self):
super().test_float16_inference(9e-2)


class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = FluxModularPipeline
Expand Down Expand Up @@ -118,6 +122,10 @@ def test_save_from_pretrained(self):

assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3

# @pytest.mark.skipif(torch_device == "cpu", reason="Test needs an accelerator.")
def test_float16_inference(self):
super().test_float16_inference(8e-2)


class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = FluxKontextModularPipeline
Expand Down Expand Up @@ -170,3 +178,7 @@ def test_save_from_pretrained(self):
image_slices.append(image[0, -3:, -3:, -1].flatten())

assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3

# @pytest.mark.skipif(torch_device == "cpu", reason="Test needs an accelerator.")
def test_float16_inference(self):
super().test_float16_inference(9e-2)
Empty file.
120 changes: 120 additions & 0 deletions tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import PIL
import pytest

from diffusers.modular_pipelines import (
QwenImageAutoBlocks,
QwenImageEditAutoBlocks,
QwenImageEditModularPipeline,
QwenImageEditPlusAutoBlocks,
QwenImageEditPlusModularPipeline,
QwenImageModularPipeline,
)

from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPipelineTesterMixin


class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
pipeline_class = QwenImageModularPipeline
pipeline_blocks_class = QwenImageAutoBlocks
repo = "hf-internal-testing/tiny-qwenimage-modular"

params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])

def get_dummy_inputs(self):
generator = self.get_generator()
inputs = {
"prompt": "dance monkey",
"negative_prompt": "bad quality",
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs

def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=5e-4)


class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
pipeline_class = QwenImageEditModularPipeline
pipeline_blocks_class = QwenImageEditAutoBlocks
repo = "hf-internal-testing/tiny-qwenimage-edit-modular"

params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])

def get_dummy_inputs(self):
generator = self.get_generator()
inputs = {
"prompt": "dance monkey",
"negative_prompt": "bad quality",
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"output_type": "pt",
}
inputs["image"] = PIL.Image.new("RGB", (32, 32), 0)
return inputs

def test_guider_cfg(self):
super().test_guider_cfg(7e-5)


class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
pipeline_class = QwenImageEditPlusModularPipeline
pipeline_blocks_class = QwenImageEditPlusAutoBlocks
repo = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"

# No `mask_image` yet.
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"])
batch_params = frozenset(["prompt", "negative_prompt", "image"])

def get_dummy_inputs(self):
generator = self.get_generator()
inputs = {
"prompt": "dance monkey",
"negative_prompt": "bad quality",
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"output_type": "pt",
}
inputs["image"] = PIL.Image.new("RGB", (32, 32), 0)
return inputs

@pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
def test_num_images_per_prompt(self):
super().test_num_images_per_prompt()

@pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
def test_inference_batch_consistent():
super().test_inference_batch_consistent()

@pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
def test_inference_batch_single_identical():
super().test_inference_batch_single_identical()
Comment on lines +107 to +117
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are skipped in the standard pipeline tests, too.


def test_guider_cfg(self):
super().test_guider_cfg(1e-3)
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from ...models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPipelineTesterMixin


enable_full_determinism()
Expand All @@ -37,13 +37,11 @@ class SDXLModularTesterMixin:
"""

def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, expected_max_diff=1e-2):
sd_pipe = self.get_pipeline()
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
sd_pipe = self.get_pipeline().to(torch_device)

inputs = self.get_dummy_inputs()
image = sd_pipe(**inputs, output="images")
image_slice = image[0, -3:, -3:, -1]
image_slice = image[0, -3:, -3:, -1].cpu()

assert image.shape == expected_image_shape
max_diff = torch.abs(image_slice.flatten() - expected_slice).max()
Expand Down Expand Up @@ -110,7 +108,7 @@ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N
pipe = blocks.init_pipeline(self.repo)
pipe.load_components(torch_dtype=torch.float32)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

cross_attention_dim = pipe.unet.config.get("cross_attention_dim")

# forward pass without ip adapter
Expand Down Expand Up @@ -219,9 +217,7 @@ def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N
# compare against static slices and that can be shaky (with a VVVV low probability).
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff

pipe = self.get_pipeline()
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe = self.get_pipeline().to(torch_device)

# forward pass without controlnet
inputs = self.get_dummy_inputs()
Expand Down Expand Up @@ -251,9 +247,7 @@ def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N
assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference"

def test_controlnet_cfg(self):
pipe = self.get_pipeline()
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe = self.get_pipeline().to(torch_device)

# forward pass with CFG not applied
guider = ClassifierFreeGuidance(guidance_scale=1.0)
Expand All @@ -273,35 +267,11 @@ def test_controlnet_cfg(self):
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"


class SDXLModularGuiderTesterMixin:
def test_guider_cfg(self):
pipe = self.get_pipeline()
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

# forward pass with CFG not applied
guider = ClassifierFreeGuidance(guidance_scale=1.0)
pipe.update_components(guider=guider)

inputs = self.get_dummy_inputs()
out_no_cfg = pipe(**inputs, output="images")

# forward pass with CFG applied
guider = ClassifierFreeGuidance(guidance_scale=7.5)
pipe.update_components(guider=guider)
inputs = self.get_dummy_inputs()
out_cfg = pipe(**inputs, output="images")

assert out_cfg.shape == out_no_cfg.shape
max_diff = np.abs(out_cfg - out_no_cfg).max()
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"


class TestSDXLModularPipelineFast(
SDXLModularTesterMixin,
SDXLModularIPAdapterTesterMixin,
SDXLModularControlNetTesterMixin,
SDXLModularGuiderTesterMixin,
ModularGuiderTesterMixin,
ModularPipelineTesterMixin,
):
"""Test cases for Stable Diffusion XL modular pipeline fast tests."""
Expand Down Expand Up @@ -335,18 +305,7 @@ def test_stable_diffusion_xl_euler(self):
self._test_stable_diffusion_xl_euler(
expected_image_shape=self.expected_image_output_shape,
expected_slice=torch.tensor(
[
0.5966781,
0.62939394,
0.48465094,
0.51573336,
0.57593524,
0.47035995,
0.53410417,
0.51436996,
0.47313565,
],
device=torch_device,
[0.3886, 0.4685, 0.4953, 0.4217, 0.4317, 0.3945, 0.4847, 0.4704, 0.4731],
),
expected_max_diff=1e-2,
)
Expand All @@ -359,7 +318,7 @@ class TestSDXLImg2ImgModularPipelineFast(
SDXLModularTesterMixin,
SDXLModularIPAdapterTesterMixin,
SDXLModularControlNetTesterMixin,
SDXLModularGuiderTesterMixin,
ModularGuiderTesterMixin,
ModularPipelineTesterMixin,
):
"""Test cases for Stable Diffusion XL image-to-image modular pipeline fast tests."""
Expand Down Expand Up @@ -400,20 +359,7 @@ def get_dummy_inputs(self, seed=0):
def test_stable_diffusion_xl_euler(self):
self._test_stable_diffusion_xl_euler(
expected_image_shape=self.expected_image_output_shape,
expected_slice=torch.tensor(
[
0.56943184,
0.4702148,
0.48048905,
0.6235963,
0.551138,
0.49629188,
0.60031277,
0.5688907,
0.43996853,
],
device=torch_device,
),
expected_slice=torch.tensor([0.5246, 0.4466, 0.444, 0.3246, 0.4443, 0.5108, 0.5225, 0.559, 0.5147]),
expected_max_diff=1e-2,
)

Expand All @@ -425,7 +371,7 @@ class SDXLInpaintingModularPipelineFastTests(
SDXLModularTesterMixin,
SDXLModularIPAdapterTesterMixin,
SDXLModularControlNetTesterMixin,
SDXLModularGuiderTesterMixin,
ModularGuiderTesterMixin,
ModularPipelineTesterMixin,
):
"""Test cases for Stable Diffusion XL inpainting modular pipeline fast tests."""
Expand Down
Loading
Loading