Skip to content

Commit 051c8a1

Browse files
authored
Fix Stable Diffusion 3.x pooled prompt embedding with multiple images (#12306)
1 parent d54622c commit 051c8a1

File tree

7 files changed

+7
-7
lines changed

7 files changed

+7
-7
lines changed

src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def _get_clip_prompt_embeds(
355355
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
356356
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
357357

358-
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
358+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
359359
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
360360

361361
return prompt_embeds, pooled_prompt_embeds

src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def _get_clip_prompt_embeds(
373373
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
374374
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
375375

376-
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
376+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
377377
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
378378

379379
return prompt_embeds, pooled_prompt_embeds

src/diffusers/pipelines/pag/pipeline_pag_sd_3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def _get_clip_prompt_embeds(
326326
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
327327
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
328328

329-
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
329+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
330330
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
331331

332332
return prompt_embeds, pooled_prompt_embeds

src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def _get_clip_prompt_embeds(
342342
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
343343
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
344344

345-
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
345+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
346346
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
347347

348348
return prompt_embeds, pooled_prompt_embeds

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def _get_clip_prompt_embeds(
336336
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
337337
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
338338

339-
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
339+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
340340
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
341341

342342
return prompt_embeds, pooled_prompt_embeds

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def _get_clip_prompt_embeds(
361361
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
362362
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
363363

364-
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
364+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
365365
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
366366

367367
return prompt_embeds, pooled_prompt_embeds

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def _get_clip_prompt_embeds(
367367
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
368368
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
369369

370-
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
370+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
371371
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
372372

373373
return prompt_embeds, pooled_prompt_embeds

0 commit comments

Comments
 (0)