Commit 8ac17cd

DN6 and sayakpaul authored

[Modular] Some clean up for Modular tests (#12579)

* update

* update

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent e4393fa commit 8ac17cd

File tree: 3 files changed, +269 −199 lines


tests/modular_pipelines/flux/test_modular_pipeline_flux.py

Lines changed: 76 additions & 34 deletions
@@ -15,7 +15,6 @@

 import random
 import tempfile
-import unittest

 import numpy as np
 import PIL
@@ -34,21 +33,16 @@
 from ..test_modular_pipelines_common import ModularPipelineTesterMixin


-class FluxModularTests:
+class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
     pipeline_class = FluxModularPipeline
     pipeline_blocks_class = FluxAutoBlocks
     repo = "hf-internal-testing/tiny-flux-modular"

-    def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
-        pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
-        pipeline.load_components(torch_dtype=torch_dtype)
-        return pipeline
+    params = frozenset(["prompt", "height", "width", "guidance_scale"])
+    batch_params = frozenset(["prompt"])

-    def get_dummy_inputs(self, device, seed=0):
-        if str(device).startswith("mps"):
-            generator = torch.manual_seed(seed)
-        else:
-            generator = torch.Generator(device=device).manual_seed(seed)
+    def get_dummy_inputs(self, seed=0):
+        generator = self.get_generator(seed)
         inputs = {
             "prompt": "A painting of a squirrel eating a burger",
             "generator": generator,
@@ -57,36 +51,47 @@ def get_dummy_inputs(self, device, seed=0):
             "height": 8,
             "width": 8,
             "max_sequence_length": 48,
-            "output_type": "np",
+            "output_type": "pt",
         }
         return inputs


-class FluxModularPipelineFastTests(FluxModularTests, ModularPipelineTesterMixin, unittest.TestCase):
-    params = frozenset(["prompt", "height", "width", "guidance_scale"])
-    batch_params = frozenset(["prompt"])
-
+class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
+    pipeline_class = FluxModularPipeline
+    pipeline_blocks_class = FluxAutoBlocks
+    repo = "hf-internal-testing/tiny-flux-modular"

-class FluxImg2ImgModularPipelineFastTests(FluxModularTests, ModularPipelineTesterMixin, unittest.TestCase):
     params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
     batch_params = frozenset(["prompt", "image"])

     def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
         pipeline = super().get_pipeline(components_manager, torch_dtype)
+
         # Override `vae_scale_factor` here as currently, `image_processor` is initialized with
         # fixed constants instead of
         # https://github.com/huggingface/diffusers/blob/d54622c2679d700b425ad61abce9b80fc36212c0/src/diffusers/pipelines/flux/pipeline_flux_img2img.py#L230C9-L232C10
         pipeline.image_processor = VaeImageProcessor(vae_scale_factor=2)
         return pipeline

-    def get_dummy_inputs(self, device, seed=0):
-        inputs = super().get_dummy_inputs(device, seed)
-        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
-        image = image / 2 + 0.5
-        inputs["image"] = image
-        inputs["strength"] = 0.8
-        inputs["height"] = 8
-        inputs["width"] = 8
+    def get_dummy_inputs(self, seed=0):
+        generator = self.get_generator(seed)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 4,
+            "guidance_scale": 5.0,
+            "height": 8,
+            "width": 8,
+            "max_sequence_length": 48,
+            "output_type": "pt",
+        }
+        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(torch_device)
+        image = image.cpu().permute(0, 2, 3, 1)[0]
+        init_image = PIL.Image.fromarray(np.uint8(image)).convert("RGB")
+
+        inputs["image"] = init_image
+        inputs["strength"] = 0.5
+
         return inputs

     def test_save_from_pretrained(self):
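Note: the per-class get_pipeline helper and the mps-aware generator setup were deleted above; the super().get_pipeline(...) and self.get_generator(seed) calls that replace them presumably resolve to the shared ModularPipelineTesterMixin. A rough sketch of what that mixin would provide, based mostly on the code removed in this diff (the actual implementation in tests/modular_pipelines/test_modular_pipelines_common.py may differ):

import torch

from diffusers.utils.testing_utils import torch_device


class ModularPipelineTesterMixin:
    # Sketch only: approximates the shared helpers the Flux test classes rely on.
    pipeline_class = None
    pipeline_blocks_class = None
    repo = None

    def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
        # Mirrors the helper removed from the Flux test class above: build the auto blocks,
        # turn them into a pipeline for `self.repo`, then load its model components.
        pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
        pipeline.load_components(torch_dtype=torch_dtype)
        return pipeline

    def get_generator(self, seed=0):
        # Assumption: folds in the old mps/else branching that used to live in get_dummy_inputs.
        if str(torch_device).startswith("mps"):
            return torch.manual_seed(seed)
        return torch.Generator(device=torch_device).manual_seed(seed)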
@@ -96,6 +101,7 @@ def test_save_from_pretrained(self):

         with tempfile.TemporaryDirectory() as tmpdirname:
             base_pipe.save_pretrained(tmpdirname)
+
             pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
             pipe.load_components(torch_dtype=torch.float32)
             pipe.to(torch_device)
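Note: both the get_pipeline override in the img2img test above and the Kontext save/load test in the next hunk force pipe.image_processor = VaeImageProcessor(vae_scale_factor=2). For reference, the non-modular Flux img2img pipeline linked in the earlier comment derives this factor from the VAE config rather than a fixed constant; a hedged sketch of that derivation (build_flux_image_processor is a hypothetical name and the exact expression in pipeline_flux_img2img.py may differ):

from diffusers.image_processor import VaeImageProcessor


def build_flux_image_processor(vae):
    # Hypothetical helper: derive the spatial downsampling factor from the VAE config
    # instead of hard-coding it, as the comment in the img2img test above alludes to.
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) if vae is not None else 8
    # Flux packs latents into 2x2 patches, so the image processor is sized with twice that factor.
    return VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2)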
@@ -105,26 +111,62 @@ def test_save_from_pretrained(self):

         image_slices = []
         for pipe in pipes:
-            inputs = self.get_dummy_inputs(torch_device)
+            inputs = self.get_dummy_inputs()
             image = pipe(**inputs, output="images")

             image_slices.append(image[0, -3:, -3:, -1].flatten())

-        assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+        assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3


-class FluxKontextModularPipelineFastTests(FluxImg2ImgModularPipelineFastTests):
+class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
     pipeline_class = FluxKontextModularPipeline
     pipeline_blocks_class = FluxKontextAutoBlocks
     repo = "hf-internal-testing/tiny-flux-kontext-pipe"

-    def get_dummy_inputs(self, device, seed=0):
-        inputs = super().get_dummy_inputs(device, seed)
+    params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
+    batch_params = frozenset(["prompt", "image"])
+
+    def get_dummy_inputs(self, seed=0):
+        generator = self.get_generator(seed)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 5.0,
+            "height": 8,
+            "width": 8,
+            "max_sequence_length": 48,
+            "output_type": "pt",
+        }
         image = PIL.Image.new("RGB", (32, 32), 0)
-        _ = inputs.pop("strength")
+
         inputs["image"] = image
-        inputs["height"] = 8
-        inputs["width"] = 8
-        inputs["max_area"] = 8 * 8
+        inputs["max_area"] = inputs["height"] * inputs["width"]
         inputs["_auto_resize"] = False
+
         return inputs
+
+    def test_save_from_pretrained(self):
+        pipes = []
+        base_pipe = self.get_pipeline().to(torch_device)
+        pipes.append(base_pipe)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            base_pipe.save_pretrained(tmpdirname)
+
+            pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
+            pipe.load_components(torch_dtype=torch.float32)
+            pipe.to(torch_device)
+            pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
+
+        pipes.append(pipe)
+
+        image_slices = []
+        for pipe in pipes:
+            inputs = self.get_dummy_inputs()
+            image = pipe(**inputs, output="images")
+
+            image_slices.append(image[0, -3:, -3:, -1].flatten())
+
+        assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
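With unittest.TestCase dropped, these are plain pytest-style classes; the new Test* names are what lets pytest collect them. A minimal way to run just this module from a local diffusers checkout (assumes pytest is installed; equivalent to the usual CLI invocation):

import pytest

if __name__ == "__main__":
    # Collect and run only the Flux modular pipeline tests shown in this diff.
    raise SystemExit(pytest.main(["tests/modular_pipelines/flux/test_modular_pipeline_flux.py", "-q"]))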
