Skip to content

Commit 8f80dda

Browse files
sayakpaul and yiyixuxu authored
[tests] add tests for flux modular (t2i, i2i, kontext) (#12566)
* start flux modular tests. * up * add kontext * up * up * up * Update src/diffusers/modular_pipelines/flux/denoise.py Co-authored-by: YiYi Xu <[email protected]> * up * up --------- Co-authored-by: YiYi Xu <[email protected]>
1 parent cdbf0ad commit 8f80dda

File tree

8 files changed

+152
-27
lines changed

8 files changed

+152
-27
lines changed

src/diffusers/modular_pipelines/components_manager.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,11 @@ def __call__(self, hooks, model_id, model, execution_device):
164164

165165
device_type = execution_device.type
166166
device_module = getattr(torch, device_type, torch.cuda)
167-
mem_on_device = device_module.mem_get_info(execution_device.index)[0]
167+
try:
168+
mem_on_device = device_module.mem_get_info(execution_device.index)[0]
169+
except AttributeError:
170+
raise AttributeError(f"Do not know how to obtain obtain memory info for {str(device_module)}.")
171+
168172
mem_on_device = mem_on_device - self.memory_reserve_margin
169173
if current_module_size < mem_on_device:
170174
return []
@@ -699,6 +703,8 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None,
699703
if not is_accelerate_available():
700704
raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
701705

706+
# TODO: add a warning if mem_get_info isn't available on `device`.
707+
702708
for name, component in self.components.items():
703709
if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
704710
remove_hook_from_module(component, recurse=True)

src/diffusers/modular_pipelines/flux/before_denoise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
598598
and getattr(block_state, "image_width", None) is not None
599599
):
600600
image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
601-
image_latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
601+
image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2))
602602
img_ids = FluxPipeline._prepare_latent_image_ids(
603603
None, image_latent_height // 2, image_latent_width // 2, device, dtype
604604
)

src/diffusers/modular_pipelines/flux/denoise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def inputs(self) -> List[Tuple[str, Any]]:
5959
),
6060
InputParam(
6161
"guidance",
62-
required=True,
62+
required=False,
6363
type_hint=torch.Tensor,
6464
description="Guidance scale as a tensor",
6565
),
@@ -141,7 +141,7 @@ def inputs(self) -> List[Tuple[str, Any]]:
141141
),
142142
InputParam(
143143
"guidance",
144-
required=True,
144+
required=False,
145145
type_hint=torch.Tensor,
146146
description="Guidance scale as a tensor",
147147
),

src/diffusers/modular_pipelines/flux/encoders.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def expected_components(self) -> List[ComponentSpec]:
9595
ComponentSpec(
9696
"image_processor",
9797
VaeImageProcessor,
98-
config=FrozenDict({"vae_scale_factor": 16}),
98+
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
9999
default_creation_method="from_config",
100100
),
101101
]
@@ -143,10 +143,6 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState):
143143
class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
144144
model_name = "flux-kontext"
145145

146-
def __init__(self, _auto_resize=True):
147-
self._auto_resize = _auto_resize
148-
super().__init__()
149-
150146
@property
151147
def description(self) -> str:
152148
return (
@@ -167,7 +163,7 @@ def expected_components(self) -> List[ComponentSpec]:
167163

168164
@property
169165
def inputs(self) -> List[InputParam]:
170-
return [InputParam("image")]
166+
return [InputParam("image"), InputParam("_auto_resize", type_hint=bool, default=True)]
171167

172168
@property
173169
def intermediate_outputs(self) -> List[OutputParam]:
@@ -195,7 +191,8 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState):
195191
img = images[0]
196192
image_height, image_width = components.image_processor.get_default_height_width(img)
197193
aspect_ratio = image_width / image_height
198-
if self._auto_resize:
194+
_auto_resize = block_state._auto_resize
195+
if _auto_resize:
199196
# Kontext is trained on specific resolutions, using one of them is recommended
200197
_, image_width, image_height = min(
201198
(abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS

src/diffusers/modular_pipelines/flux/inputs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
112112
block_state.prompt_embeds = block_state.prompt_embeds.view(
113113
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
114114
)
115+
pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt)
116+
block_state.pooled_prompt_embeds = pooled_prompt_embeds.view(
117+
block_state.batch_size * block_state.num_images_per_prompt, -1
118+
)
115119
self.set_block_state(state, block_state)
116120

117121
return components, state

tests/modular_pipelines/flux/__init__.py

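Whitespace-only changes.

[File header missing from this capture: the 130 added lines below belong to a new test module under tests/modular_pipelines/flux/ (presumably test_modular_pipeline_flux.py), not to __init__.py.]

Lines changed: 130 additions & 0 deletions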
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# coding=utf-8
2+
# Copyright 2025 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import random
17+
import tempfile
18+
import unittest
19+
20+
import numpy as np
21+
import PIL
22+
import torch
23+
24+
from diffusers.image_processor import VaeImageProcessor
25+
from diffusers.modular_pipelines import (
26+
FluxAutoBlocks,
27+
FluxKontextAutoBlocks,
28+
FluxKontextModularPipeline,
29+
FluxModularPipeline,
30+
ModularPipeline,
31+
)
32+
33+
from ...testing_utils import floats_tensor, torch_device
34+
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
35+
36+
37+
class FluxModularTests:
38+
pipeline_class = FluxModularPipeline
39+
pipeline_blocks_class = FluxAutoBlocks
40+
repo = "hf-internal-testing/tiny-flux-modular"
41+
42+
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
43+
pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
44+
pipeline.load_components(torch_dtype=torch_dtype)
45+
return pipeline
46+
47+
def get_dummy_inputs(self, device, seed=0):
48+
if str(device).startswith("mps"):
49+
generator = torch.manual_seed(seed)
50+
else:
51+
generator = torch.Generator(device=device).manual_seed(seed)
52+
inputs = {
53+
"prompt": "A painting of a squirrel eating a burger",
54+
"generator": generator,
55+
"num_inference_steps": 2,
56+
"guidance_scale": 5.0,
57+
"height": 8,
58+
"width": 8,
59+
"max_sequence_length": 48,
60+
"output_type": "np",
61+
}
62+
return inputs
63+
64+
65+
class FluxModularPipelineFastTests(FluxModularTests, ModularPipelineTesterMixin, unittest.TestCase):
66+
params = frozenset(["prompt", "height", "width", "guidance_scale"])
67+
batch_params = frozenset(["prompt"])
68+
69+
70+
class FluxImg2ImgModularPipelineFastTests(FluxModularTests, ModularPipelineTesterMixin, unittest.TestCase):
71+
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
72+
batch_params = frozenset(["prompt", "image"])
73+
74+
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
75+
pipeline = super().get_pipeline(components_manager, torch_dtype)
76+
# Override `vae_scale_factor` here as currently, `image_processor` is initialized with
77+
# fixed constants instead of the way the standard Flux img2img pipeline sets it up, see:
78+
# https://github.com/huggingface/diffusers/blob/d54622c2679d700b425ad61abce9b80fc36212c0/src/diffusers/pipelines/flux/pipeline_flux_img2img.py#L230C9-L232C10
79+
pipeline.image_processor = VaeImageProcessor(vae_scale_factor=2)
80+
return pipeline
81+
82+
def get_dummy_inputs(self, device, seed=0):
83+
inputs = super().get_dummy_inputs(device, seed)
84+
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
85+
image = image / 2 + 0.5
86+
inputs["image"] = image
87+
inputs["strength"] = 0.8
88+
inputs["height"] = 8
89+
inputs["width"] = 8
90+
return inputs
91+
92+
def test_save_from_pretrained(self):
93+
pipes = []
94+
base_pipe = self.get_pipeline().to(torch_device)
95+
pipes.append(base_pipe)
96+
97+
with tempfile.TemporaryDirectory() as tmpdirname:
98+
base_pipe.save_pretrained(tmpdirname)
99+
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
100+
pipe.load_components(torch_dtype=torch.float32)
101+
pipe.to(torch_device)
102+
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
103+
104+
pipes.append(pipe)
105+
106+
image_slices = []
107+
for pipe in pipes:
108+
inputs = self.get_dummy_inputs(torch_device)
109+
image = pipe(**inputs, output="images")
110+
111+
image_slices.append(image[0, -3:, -3:, -1].flatten())
112+
113+
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
114+
115+
116+
class FluxKontextModularPipelineFastTests(FluxImg2ImgModularPipelineFastTests):
117+
pipeline_class = FluxKontextModularPipeline
118+
pipeline_blocks_class = FluxKontextAutoBlocks
119+
repo = "hf-internal-testing/tiny-flux-kontext-pipe"
120+
121+
def get_dummy_inputs(self, device, seed=0):
122+
inputs = super().get_dummy_inputs(device, seed)
123+
image = PIL.Image.new("RGB", (32, 32), 0)
124+
_ = inputs.pop("strength")
125+
inputs["image"] = image
126+
inputs["height"] = 8
127+
inputs["width"] = 8
128+
inputs["max_area"] = 8 * 8
129+
inputs["_auto_resize"] = False
130+
return inputs

tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,24 +21,12 @@
2121
import torch
2222
from PIL import Image
2323

24-
from diffusers import (
25-
ClassifierFreeGuidance,
26-
StableDiffusionXLAutoBlocks,
27-
StableDiffusionXLModularPipeline,
28-
)
24+
from diffusers import ClassifierFreeGuidance, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
2925
from diffusers.loaders import ModularIPAdapterMixin
3026

31-
from ...models.unets.test_models_unet_2d_condition import (
32-
create_ip_adapter_state_dict,
33-
)
34-
from ...testing_utils import (
35-
enable_full_determinism,
36-
floats_tensor,
37-
torch_device,
38-
)
39-
from ..test_modular_pipelines_common import (
40-
ModularPipelineTesterMixin,
41-
)
27+
from ...models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
28+
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
29+
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
4230

4331

4432
enable_full_determinism()

0 commit comments

Comments
 (0)