Skip to content
289 changes: 2 additions & 287 deletions src/diffusers/pipelines/flux/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,13 @@
from ...models import AutoencoderKL, FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_flux_utils import FluxMixin
from .pipeline_output import FluxPipelineOutput


Expand Down Expand Up @@ -146,6 +143,7 @@ def retrieve_timesteps(

class FluxPipeline(
DiffusionPipeline,
FluxMixin,
FluxLoraLoaderMixin,
FromSingleFileMixin,
TextualInversionLoaderMixin,
Expand Down Expand Up @@ -215,178 +213,6 @@ def __init__(
)
self.default_sample_size = 128

def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype

prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)

if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)

text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids

if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)

prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]

dtype = self.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

_, seq_len, _ = prompt_embeds.shape

# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

return prompt_embeds

def _get_clip_prompt_embeds(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
):
device = device or self._execution_device

prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)

if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer_max_length,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)

text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)

# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

return prompt_embeds

def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
):
r"""

Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in all text-encoders
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
device = device or self._execution_device

# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale

# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder_2, lora_scale)

prompt = [prompt] if isinstance(prompt, str) else prompt

if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

# We only use the pooled prompt output from the CLIPTextModel
pooled_prompt_embeds = self._get_clip_prompt_embeds(
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
)
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt_2,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)

if self.text_encoder is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)

if self.text_encoder_2 is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2, lora_scale)

dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

return prompt_embeds, pooled_prompt_embeds, text_ids

def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype

Expand Down Expand Up @@ -503,97 +329,6 @@ def check_inputs(
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")

@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)

return latent_image_ids.to(device=device, dtype=dtype)

@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

return latents

@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape

# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))

latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)

latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

return latents

def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
deprecate(
"enable_vae_slicing",
"0.40.0",
depr_message,
)
self.vae.enable_slicing()

def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
deprecate(
"disable_vae_slicing",
"0.40.0",
depr_message,
)
self.vae.disable_slicing()

def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
deprecate(
"enable_vae_tiling",
"0.40.0",
depr_message,
)
self.vae.enable_tiling()

def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
deprecate(
"disable_vae_tiling",
"0.40.0",
depr_message,
)
self.vae.disable_tiling()

def prepare_latents(
self,
batch_size,
Expand Down Expand Up @@ -629,26 +364,6 @@ def prepare_latents(

return latents, latent_image_ids

@property
def guidance_scale(self):
return self._guidance_scale

@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs

@property
def num_timesteps(self):
return self._num_timesteps

@property
def current_timestep(self):
return self._current_timestep

@property
def interrupt(self):
return self._interrupt

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
Expand Down
Loading
Loading