Standardization of additional token identifiers across pipelines #11334

Open
@sayakpaul

Description

FluxPipeline has utilities that give us img_ids and txt_ids:

```python
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
    ...
```

```python
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
```
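
For reference, here is a minimal sketch of what such a helper computes: a `(height * width, 3)` tensor whose last two channels hold each latent position's row/column index. This is an approximation for illustration, not the exact diffusers implementation:

```python
import torch

def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
    # Each latent position gets a 3-vector id: (0, row_index, col_index).
    latent_image_ids = torch.zeros(height, width, 3)
    latent_image_ids[..., 1] += torch.arange(height)[:, None]
    latent_image_ids[..., 2] += torch.arange(width)[None, :]
    # Flatten the spatial grid to (height * width, 3) before handing it to the transformer.
    return latent_image_ids.reshape(height * width, 3).to(device=device, dtype=dtype)
```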

As such, these ids are created in the pipeline and not inside the transformer class.

In HiDream, by contrast, we have something different.

text_ids are created inside the transformer class.

img_ids are overwritten:
https://github.com/huggingface/diffusers/blob/ce1063acfa0cbc2168a7e9dddd4282ab8013b810/src/diffusers/models/transformers/transformer_hidream_image.py#L771C13-L771C20 (probably intentional because it's conditioned)

Then the entire computation below happens inside the pipeline's `__call__()`:

```python
if latents.shape[-2] != latents.shape[-1]:
    B, C, H, W = latents.shape
    pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size
    img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
    img_ids = torch.zeros(pH, pW, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
    img_ids = img_ids.reshape(pH * pW, -1)
    img_ids_pad = torch.zeros(self.transformer.max_seq, 3)
    img_ids_pad[: pH * pW, :] = img_ids
    img_sizes = img_sizes.unsqueeze(0).to(latents.device)
    img_ids = img_ids_pad.unsqueeze(0).to(latents.device)
    if self.do_classifier_free_guidance:
        img_sizes = img_sizes.repeat(2 * B, 1)
        img_ids = img_ids.repeat(2 * B, 1, 1)
else:
    img_sizes = img_ids = None
```

Maybe this could be factored into a helper method, similar to how FluxPipeline does it?
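
For illustration, such a helper could look roughly like the sketch below. The `_prepare_img_sizes_and_ids` name, signature, and free-function form are hypothetical; in practice it would presumably be a static method on the HiDream pipeline, analogous to `FluxPipeline._prepare_latent_image_ids`:

```python
import torch

def _prepare_img_sizes_and_ids(latents, patch_size, max_seq, do_classifier_free_guidance):
    """Hypothetical helper that factors the img_sizes/img_ids computation out of __call__."""
    # HiDream only builds explicit sizes/ids for non-square latents.
    if latents.shape[-2] == latents.shape[-1]:
        return None, None

    B, C, H, W = latents.shape
    pH, pW = H // patch_size, W // patch_size

    img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)

    # Same (0, row, col) position ids as in Flux, but padded out to max_seq.
    img_ids = torch.zeros(pH, pW, 3)
    img_ids[..., 1] += torch.arange(pH)[:, None]
    img_ids[..., 2] += torch.arange(pW)[None, :]
    img_ids = img_ids.reshape(pH * pW, -1)

    img_ids_pad = torch.zeros(max_seq, 3)
    img_ids_pad[: pH * pW, :] = img_ids

    img_sizes = img_sizes.unsqueeze(0).to(latents.device)
    img_ids = img_ids_pad.unsqueeze(0).to(latents.device)

    if do_classifier_free_guidance:
        img_sizes = img_sizes.repeat(2 * B, 1)
        img_ids = img_ids.repeat(2 * B, 1, 1)

    return img_sizes, img_ids
```

`__call__()` would then reduce to something like `img_sizes, img_ids = self._prepare_img_sizes_and_ids(latents, self.transformer.config.patch_size, self.transformer.max_seq, self.do_classifier_free_guidance)`, which is closer to how FluxPipeline prepares its latent image ids.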

In general, the way these ids are prepared could be standardized a bit across pipelines.

Cc: @yiyixuxu @a-r-r-o-w
