From b8b6465e38a03a5cc4833c19ea4105a98c6e2a76 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 8 Apr 2025 18:31:12 +0300 Subject: [PATCH 01/76] initial commit --- examples/dreambooth/README_hidream.md | 0 .../train_dreambooth_lora_hidream.py | 1973 +++++++++++++++++ 2 files changed, 1973 insertions(+) create mode 100644 examples/dreambooth/README_hidream.md create mode 100644 examples/dreambooth/train_dreambooth_lora_hidream.py diff --git a/examples/dreambooth/README_hidream.md b/examples/dreambooth/README_hidream.md new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py new file mode 100644 index 000000000000..8040da7efd72 --- /dev/null +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -0,0 +1,1973 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import copy +import itertools +import logging +import math +import os +import random +import shutil +import warnings +from contextlib import nullcontext +from pathlib import Path + +import numpy as np +import torch +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib +from peft import LoraConfig, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm +from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast, LlamaForCausalLM + +import diffusers +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + UniPCMultistepScheduler, + HiDreamImagePipeline, + HiDreamImageTransformer2DModel, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import ( + _set_state_dict_into_text_encoder, + cast_training_params, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, + free_memory, +) +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.33.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model: str = None, + train_text_encoder=False, + instance_prompt=None, + validation_prompt=None, + repo_folder=None, +): + widget_dict = [] + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + ) + + model_description = f""" +# Hi Dream Image DreamBooth LoRA - {repo_id} + + + +## Model description + +These are {repo_id} DreamBooth LoRA weights for {base_model}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [HiDream diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_hidream.md). + +Was LoRA for the text encoder enabled? {train_text_encoder}. + +## Trigger words + +You should use `{instance_prompt}` to trigger the image generation. + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +from diffusers import AutoPipelineForText2Image +import torch +pipeline = AutoPipelineForText2Image.from_pretrained("", torch_dtype=torch.bfloat16).to('cuda') +pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors') +image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## License + +Please adhere to the licensing terms as described [here](). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=instance_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-image", + "diffusers-training", + "diffusers", + "lora", + "hidream", + "hidream-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def load_text_encoders(class_one, class_two, class_three, class_four): + text_encoder_one = class_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder_two = class_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant + ) + text_encoder_three = class_three.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant + ) + # text_encoder_four = class_four.from_pretrained( + # args.pretrained_model_name_or_path, subfolder="text_encoder_4", revision=args.revision, variant=args.variant + # ) + return text_encoder_one, text_encoder_two, text_encoder_three + + +def log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + torch_dtype, + is_final_validation=False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None + # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() + autocast_ctx = nullcontext() + + with autocast_ctx: + images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return images + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + if model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel + elif model_class == "LlamaForCausalLM": + from transformers import LlamaForCausalLM + + return LlamaForCausalLM + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help=("A folder containing the training data. "), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--max_sequence_length", + type=int, + default=512, + help="Maximum sequence length to use with with the T5 text encoder", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="flux-dreambooth-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--text_encoder_lr", + type=float, + default=5e-6, + help="Text encoder learning rate to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + ) + + parser.add_argument( + "--lora_layers", + type=str, + default=None, + help=( + 'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only' + ), + ) + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--upcast_before_saving", + action="store_true", + default=False, + help=( + "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). " + "Defaults to precision dtype used for training to save memory" + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + self.class_prompt = class_prompt + + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + for img in instance_images: + self.instance_images.extend(itertools.repeat(img, repeats)) + + self.pixel_values = [] + train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + train_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + for image in self.instance_images: + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + image = train_resize(image) + if args.random_flip and random.random() < 0.5: + # flip + image = train_flip(image) + if args.center_crop: + y1 = max(0, int(round((image.height - args.resolution) / 2.0))) + x1 = max(0, int(round((image.width - args.resolution) / 2.0))) + image = train_crop(image) + else: + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + image = crop(image, y1, x1, h, w) + image = train_transforms(image) + self.pixel_values.append(image) + + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = self.pixel_values[index % self.num_instance_images] + example["instance_images"] = instance_image + + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + example["instance_prompt"] = caption + else: + example["instance_prompt"] = self.instance_prompt + + else: # custom prompts were provided, but length does not match size of image dataset + example["instance_prompt"] = self.instance_prompt + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt"] = self.class_prompt + + return example + + +def collate_fn(examples, with_prior_preservation=False): + pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + pixel_values += [example["class_images"] for example in examples] + prompts += [example["class_prompt"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "prompts": prompts} + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def tokenize_prompt(tokenizer, prompt, max_sequence_length): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + return text_input_ids + +def _encode_prompt_with_llama3( + text_encoder, + tokenizer, + prompt: str, + max_sequence_length: int = 128, + device=None, + text_input_ids=None, + num_images_per_prompt: int = 1, +): + device = device or self._execution_device + dtype = dtype or self.text_encoder_4.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_4( + prompt, + padding="max_length", + max_length=min(max_sequence_length, self.tokenizer_4.model_max_length), + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer_4(prompt, padding="longest", return_tensors="pt").input_ids +def _encode_prompt_with_t5( + text_encoder, + tokenizer, + max_sequence_length=512, + prompt=None, + num_images_per_prompt=1, + device=None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + + if hasattr(text_encoder, "module"): + dtype = text_encoder.module.dtype + else: + dtype = text_encoder.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + +def _encode_prompt_with_clip( + text_encoder, + tokenizer, + prompt: str, + device=None, + text_input_ids=None, + num_images_per_prompt: int = 1, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) + + if hasattr(text_encoder, "module"): + dtype = text_encoder.module.dtype + else: + dtype = text_encoder.dtype + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + +def encode_prompt( + text_encoders, + tokenizers, + prompt: str, + max_sequence_length, + device=None, + num_images_per_prompt: int = 1, + text_input_ids_list=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + + if hasattr(text_encoders[0], "module"): + dtype = text_encoders[0].module.dtype + else: + dtype = text_encoders[0].dtype + + pooled_prompt_embeds = _encode_prompt_with_clip( + text_encoder=text_encoders[0], + tokenizer=tokenizers[0], + prompt=prompt, + device=device if device is not None else text_encoders[0].device, + num_images_per_prompt=num_images_per_prompt, + text_input_ids=text_input_ids_list[0] if text_input_ids_list else None, + ) + + prompt_embeds = _encode_prompt_with_t5( + text_encoder=text_encoders[1], + tokenizer=tokenizers[1], + max_sequence_length=max_sequence_length, + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + device=device if device is not None else text_encoders[1].device, + text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, + ) + + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() + torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = FluxPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + revision=args.revision, + variant=args.variant, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Load the tokenizers + tokenizer_one = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + tokenizer_two = T5TokenizerFast.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + + # Load scheduler and models + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + transformer = FluxTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant + ) + + # We only train the additional adapter LoRA layers + transformer.requires_grad_(False) + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + vae.to(accelerator.device, dtype=weight_dtype) + transformer.to(accelerator.device, dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder_one.gradient_checkpointing_enable() + + if args.lora_layers is not None: + target_modules = [layer.strip() for layer in args.lora_layers.split(",")] + else: + target_modules = [ + "attn.to_k", + "attn.to_q", + "attn.to_v", + "attn.to_out.0", + "attn.add_k_proj", + "attn.add_q_proj", + "attn.add_v_proj", + "attn.to_add_out", + "ff.net.0.proj", + "ff.net.2", + "ff_context.net.0.proj", + "ff_context.net.2", + ] + + # now we will add new LoRA weights the transformer layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + transformer.add_adapter(transformer_lora_config) + if args.train_text_encoder: + text_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], + ) + text_encoder_one.add_adapter(text_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + text_encoder_one_lora_layers_to_save = None + + for model in models: + if isinstance(model, type(unwrap_model(transformer))): + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + elif isinstance(model, type(unwrap_model(text_encoder_one))): + text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + FluxPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + transformer_ = None + text_encoder_one_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(unwrap_model(transformer))): + transformer_ = model + elif isinstance(model, type(unwrap_model(text_encoder_one))): + text_encoder_one_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + lora_state_dict = FluxPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + if args.train_text_encoder: + # Do we need to call `scale_lora_layers()` here? + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [transformer_] + if args.train_text_encoder: + models.extend([text_encoder_one_]) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [transformer] + if args.train_text_encoder: + models.extend([text_encoder_one]) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + if args.train_text_encoder: + text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + if args.train_text_encoder: + # different learning rate for text encoder and unet + text_parameters_one_with_lr = { + "params": text_lora_parameters_one, + "weight_decay": args.adam_weight_decay_text_encoder, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + } + params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr] + else: + params_to_optimize = [transformer_parameters_with_lr] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + if args.train_text_encoder and args.text_encoder_lr: + logger.warning( + f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:" + f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " + f"When using prodigy only learning_rate is used as the initial learning rate." + ) + # changes the learning rate of text_encoder_parameters_one to be + # --learning_rate + params_to_optimize[1]["lr"] = args.learning_rate + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_prompt=args.class_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_num=args.num_class_images, + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + if not args.train_text_encoder: + tokenizers = [tokenizer_one, tokenizer_two] + text_encoders = [text_encoder_one, text_encoder_two] + + def compute_text_embeddings(prompt, text_encoders, tokenizers): + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + text_encoders, tokenizers, prompt, args.max_sequence_length + ) + prompt_embeds = prompt_embeds.to(accelerator.device) + pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) + text_ids = text_ids.to(accelerator.device) + return prompt_embeds, pooled_prompt_embeds, text_ids + + # If no type of tuning is done on the text_encoder and custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. + if not args.train_text_encoder and not train_dataset.custom_instance_prompts: + instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings( + args.instance_prompt, text_encoders, tokenizers + ) + + # Handle class prompt for prior-preservation. + if args.with_prior_preservation: + if not args.train_text_encoder: + class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings( + args.class_prompt, text_encoders, tokenizers + ) + + # Clear the memory here + if not args.train_text_encoder and not train_dataset.custom_instance_prompts: + del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two + free_memory() + + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't + # have to pass them to the dataloader. + + if not train_dataset.custom_instance_prompts: + if not args.train_text_encoder: + prompt_embeds = instance_prompt_hidden_states + pooled_prompt_embeds = instance_pooled_prompt_embeds + text_ids = instance_text_ids + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) + text_ids = torch.cat([text_ids, class_text_ids], dim=0) + # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) + # we need to tokenize and encode the batch prompts on all training steps + else: + tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77) + tokens_two = tokenize_prompt( + tokenizer_two, args.instance_prompt, max_sequence_length=args.max_sequence_length + ) + if args.with_prior_preservation: + class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77) + class_tokens_two = tokenize_prompt( + tokenizer_two, args.class_prompt, max_sequence_length=args.max_sequence_length + ) + tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) + tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) + + vae_config_shift_factor = vae.config.shift_factor + vae_config_scaling_factor = vae.config.scaling_factor + vae_config_block_out_channels = vae.config.block_out_channels + if args.cache_latents: + latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=weight_dtype + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + + if args.validation_prompt is None: + del vae + free_memory() + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + if args.train_text_encoder: + ( + transformer, + text_encoder_one, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare( + transformer, + text_encoder_one, + optimizer, + train_dataloader, + lr_scheduler, + ) + else: + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = "dreambooth-flux-dev-lora" + accelerator.init_trackers(tracker_name, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + if args.train_text_encoder: + text_encoder_one.train() + # set top parameter requires_grad = True for gradient checkpointing works + unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + if args.train_text_encoder: + models_to_accumulate.extend([text_encoder_one]) + with accelerator.accumulate(models_to_accumulate): + prompts = batch["prompts"] + + # encode batch prompts when custom prompts are provided for each image - + if train_dataset.custom_instance_prompts: + if not args.train_text_encoder: + prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( + prompts, text_encoders, tokenizers + ) + else: + tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77) + tokens_two = tokenize_prompt( + tokenizer_two, prompts, max_sequence_length=args.max_sequence_length + ) + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + text_encoders=[text_encoder_one, text_encoder_two], + tokenizers=[None, None], + text_input_ids_list=[tokens_one, tokens_two], + max_sequence_length=args.max_sequence_length, + device=accelerator.device, + prompt=prompts, + ) + else: + elems_to_repeat = len(prompts) + if args.train_text_encoder: + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + text_encoders=[text_encoder_one, text_encoder_two], + tokenizers=[None, None], + text_input_ids_list=[ + tokens_one.repeat(elems_to_repeat, 1), + tokens_two.repeat(elems_to_repeat, 1), + ], + max_sequence_length=args.max_sequence_length, + device=accelerator.device, + prompt=args.instance_prompt, + ) + + # Convert images to latent space + if args.cache_latents: + model_input = latents_cache[step].sample() + else: + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor + model_input = model_input.to(dtype=weight_dtype) + + vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1) + + latent_image_ids = FluxPipeline._prepare_latent_image_ids( + model_input.shape[0], + model_input.shape[2] // 2, + model_input.shape[3] // 2, + accelerator.device, + weight_dtype, + ) + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + packed_noisy_model_input = FluxPipeline._pack_latents( + noisy_model_input, + batch_size=model_input.shape[0], + num_channels_latents=model_input.shape[1], + height=model_input.shape[2], + width=model_input.shape[3], + ) + + # handle guidance + if unwrap_model(transformer).config.guidance_embeds: + guidance = torch.tensor([args.guidance_scale], device=accelerator.device) + guidance = guidance.expand(model_input.shape[0]) + else: + guidance = None + + # Predict the noise residual + model_pred = transformer( + hidden_states=packed_noisy_model_input, + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + timestep=timesteps / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + model_pred = FluxPipeline._unpack_latents( + model_pred, + height=model_input.shape[2] * vae_scale_factor, + width=model_input.shape[3] * vae_scale_factor, + vae_scale_factor=vae_scale_factor, + ) + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = noise - model_input + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute prior loss + prior_loss = torch.mean( + (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + target_prior.shape[0], -1 + ), + 1, + ) + prior_loss = prior_loss.mean() + + # Compute regular loss. + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + if args.with_prior_preservation: + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(transformer.parameters(), text_encoder_one.parameters()) + if args.train_text_encoder + else transformer.parameters() + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + # create pipeline + if not args.train_text_encoder: + text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + text_encoder_one.to(weight_dtype) + text_encoder_two.to(weight_dtype) + pipeline = FluxPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + text_encoder=unwrap_model(text_encoder_one), + text_encoder_2=unwrap_model(text_encoder_two), + transformer=unwrap_model(transformer), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline_args = {"prompt": args.validation_prompt} + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + torch_dtype=weight_dtype, + ) + if not args.train_text_encoder: + del text_encoder_one, text_encoder_two + free_memory() + + images = None + del pipeline + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + transformer = unwrap_model(transformer) + if args.upcast_before_saving: + transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + if args.train_text_encoder: + text_encoder_one = unwrap_model(text_encoder_one) + text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32)) + else: + text_encoder_lora_layers = None + + FluxPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + text_encoder_lora_layers=text_encoder_lora_layers, + ) + + # Final inference + # Load previous pipeline + pipeline = FluxPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + pipeline_args = {"prompt": args.validation_prompt} + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + is_final_validation=True, + torch_dtype=weight_dtype, + ) + + if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + instance_prompt=args.instance_prompt, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + images = None + del pipeline + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) From d728168afa5eee47075470a6449a7ef9df930cac Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 8 Apr 2025 19:31:30 +0300 Subject: [PATCH 02/76] initial commit --- .../train_dreambooth_lora_hidream.py | 231 +++++++++++------- 1 file changed, 139 insertions(+), 92 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 8040da7efd72..8dd2d0b8c42f 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -42,7 +42,7 @@ from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm -from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast, LlamaForCausalLM +from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast, PreTrainedTokenizerFast import diffusers from diffusers import ( @@ -165,10 +165,10 @@ def load_text_encoders(class_one, class_two, class_three, class_four): text_encoder_three = class_three.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant ) - # text_encoder_four = class_four.from_pretrained( - # args.pretrained_model_name_or_path, subfolder="text_encoder_4", revision=args.revision, variant=args.variant - # ) - return text_encoder_one, text_encoder_two, text_encoder_three + text_encoder_four = class_four.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_4", revision=args.revision, variant=args.variant + ) + return text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four def log_validation( @@ -919,26 +919,53 @@ def _encode_prompt_with_llama3( prompt: str, max_sequence_length: int = 128, device=None, + attention_mask=None, text_input_ids=None, num_images_per_prompt: int = 1, ): - device = device or self._execution_device - dtype = dtype or self.text_encoder_4.dtype - prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) - text_inputs = self.tokenizer_4( - prompt, - padding="max_length", - max_length=min(max_sequence_length, self.tokenizer_4.model_max_length), - truncation=True, - add_special_tokens=True, - return_tensors="pt", + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=min(max_sequence_length, tokenizer.model_max_length), + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + else: + if text_input_ids is None: + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + if attention_mask is None: + raise ValueError("attention_mask must be provided when the tokenizer is not specified") + + if hasattr(text_encoder, "module"): + dtype = text_encoder.module.dtype + else: + dtype = text_encoder.dtype + + + outputs = text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask.to(device), + output_hidden_states=True, + output_attentions=True, ) - text_input_ids = text_inputs.input_ids - attention_mask = text_inputs.attention_mask - untruncated_ids = self.tokenizer_4(prompt, padding="longest", return_tensors="pt").input_ids + + prompt_embeds = outputs.hidden_states[1:] + prompt_embeds = torch.stack(prompt_embeds, dim=0) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + _, _, seq_len, dim = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, 1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim) + return prompt_embeds + def _encode_prompt_with_t5( text_encoder, tokenizer, @@ -948,7 +975,6 @@ def _encode_prompt_with_t5( device=None, text_input_ids=None, ): - prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if tokenizer is not None: @@ -989,16 +1015,16 @@ def _encode_prompt_with_clip( prompt: str, device=None, text_input_ids=None, + max_sequence_length=218, num_images_per_prompt: int = 1, ): - prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if tokenizer is not None: text_inputs = tokenizer( prompt, padding="max_length", - max_length=77, + max_length=min(max_sequence_length, 218), truncation=True, return_overflowing_tokens=False, return_length=False, @@ -1016,6 +1042,7 @@ def _encode_prompt_with_clip( dtype = text_encoder.module.dtype else: dtype = text_encoder.dtype + # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -1035,36 +1062,56 @@ def encode_prompt( device=None, num_images_per_prompt: int = 1, text_input_ids_list=None, + llama_attention_mask=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt - if hasattr(text_encoders[0], "module"): - dtype = text_encoders[0].module.dtype - else: - dtype = text_encoders[0].dtype - - pooled_prompt_embeds = _encode_prompt_with_clip( + pooled_prompt_embeds_1 = _encode_prompt_with_clip( text_encoder=text_encoders[0], tokenizer=tokenizers[0], prompt=prompt, + max_sequence_length=max_sequence_length, device=device if device is not None else text_encoders[0].device, num_images_per_prompt=num_images_per_prompt, text_input_ids=text_input_ids_list[0] if text_input_ids_list else None, ) - prompt_embeds = _encode_prompt_with_t5( + pooled_prompt_embeds_2 = _encode_prompt_with_clip( text_encoder=text_encoders[1], tokenizer=tokenizers[1], - max_sequence_length=max_sequence_length, prompt=prompt, - num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, device=device if device is not None else text_encoders[1].device, + num_images_per_prompt=num_images_per_prompt, text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, ) - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1) - return prompt_embeds, pooled_prompt_embeds, text_ids + t5_prompt_embeds = _encode_prompt_with_t5( + text_encoder=text_encoders[2], + tokenizer=tokenizers[2], + max_sequence_length=max_sequence_length, + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + device=device if device is not None else text_encoders[2].device, + text_input_ids=text_input_ids_list[2] if text_input_ids_list else None, + ) + + llama3_prompt_embeds = _encode_prompt_with_llama3( + text_encoder=text_encoders[3], + tokenizer=tokenizers[3], + max_sequence_length=max_sequence_length, + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + device=device if device is not None else text_encoders[3].device, + text_input_ids=text_input_ids_list[3] if text_input_ids_list else None, + attention_mask=llama_attention_mask + ) + + prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds] + + return prompt_embeds, pooled_prompt_embeds def main(args): @@ -1134,7 +1181,7 @@ def main(args): torch_dtype = torch.float16 elif args.prior_generation_precision == "bf16": torch_dtype = torch.bfloat16 - pipeline = FluxPipeline.from_pretrained( + pipeline = HiDreamImagePipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, revision=args.revision, @@ -1182,11 +1229,21 @@ def main(args): subfolder="tokenizer", revision=args.revision, ) - tokenizer_two = T5TokenizerFast.from_pretrained( + tokenizer_two = CLIPTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, ) + tokenizer_three = T5TokenizerFast.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_3", + revision=args.revision, + ) + tokenizer_four = PreTrainedTokenizerFast.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_4", + revision=args.revision, + ) # import correct text encoder classes text_encoder_cls_one = import_model_class_from_model_name_or_path( @@ -1196,19 +1253,30 @@ def main(args): args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" ) + text_encoder_cls_three = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3" + ) + + text_encoder_cls_four = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_4" + ) + # Load scheduler and models noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( args.pretrained_model_name_or_path, subfolder="scheduler" ) noise_scheduler_copy = copy.deepcopy(noise_scheduler) - text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(text_encoder_cls_one, + text_encoder_cls_two, + text_encoder_cls_three, + text_encoder_cls_four) vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant, ) - transformer = FluxTransformer2DModel.from_pretrained( + transformer = HiDreamImageTransformer2DModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant ) @@ -1217,6 +1285,8 @@ def main(args): vae.requires_grad_(False) text_encoder_one.requires_grad_(False) text_encoder_two.requires_grad_(False) + text_encoder_three.requires_grad_(False) + text_encoder_four.requires_grad_(False) # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. @@ -1236,6 +1306,8 @@ def main(args): transformer.to(accelerator.device, dtype=weight_dtype) text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) + text_encoder_three.to(accelerator.device, dtype=weight_dtype) + text_encoder_four.to(accelerator.device, dtype=weight_dtype) if args.gradient_checkpointing: transformer.enable_gradient_checkpointing() @@ -1299,7 +1371,7 @@ def save_model_hook(models, weights, output_dir): # make sure to pop weight so that corresponding model is not saved again weights.pop() - FluxPipeline.save_lora_weights( + HiDreamImagePipeline.save_lora_weights( output_dir, transformer_lora_layers=transformer_lora_layers_to_save, text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, @@ -1319,7 +1391,7 @@ def load_model_hook(models, input_dir): else: raise ValueError(f"unexpected save model: {model.__class__}") - lora_state_dict = FluxPipeline.lora_state_dict(input_dir) + lora_state_dict = HiDreamImagePipeline.lora_state_dict(input_dir) transformer_state_dict = { f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") @@ -1474,37 +1546,37 @@ def load_model_hook(models, input_dir): ) if not args.train_text_encoder: - tokenizers = [tokenizer_one, tokenizer_two] - text_encoders = [text_encoder_one, text_encoder_two] + tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three, tokenizer_four] + text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four] def compute_text_embeddings(prompt, text_encoders, tokenizers): with torch.no_grad(): - prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + prompt_embeds, pooled_prompt_embeds = encode_prompt( text_encoders, tokenizers, prompt, args.max_sequence_length ) prompt_embeds = prompt_embeds.to(accelerator.device) pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) - text_ids = text_ids.to(accelerator.device) - return prompt_embeds, pooled_prompt_embeds, text_ids + + return prompt_embeds, pooled_prompt_embeds # If no type of tuning is done on the text_encoder and custom instance prompts are NOT # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid # the redundant encoding. if not args.train_text_encoder and not train_dataset.custom_instance_prompts: - instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings( + instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( args.instance_prompt, text_encoders, tokenizers ) # Handle class prompt for prior-preservation. if args.with_prior_preservation: if not args.train_text_encoder: - class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings( + class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings( args.class_prompt, text_encoders, tokenizers ) # Clear the memory here if not args.train_text_encoder and not train_dataset.custom_instance_prompts: - del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two + del text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four, tokenizer_one, tokenizer_two, tokenizer_three, tokenizer_four free_memory() # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), @@ -1515,11 +1587,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): if not args.train_text_encoder: prompt_embeds = instance_prompt_hidden_states pooled_prompt_embeds = instance_pooled_prompt_embeds - text_ids = instance_text_ids if args.with_prior_preservation: prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) - text_ids = torch.cat([text_ids, class_text_ids], dim=0) + # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) # we need to tokenize and encode the batch prompts on all training steps else: @@ -1535,9 +1606,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) - vae_config_shift_factor = vae.config.shift_factor + vae_config_scaling_factor = vae.config.scaling_factor - vae_config_block_out_channels = vae.config.block_out_channels + if args.cache_latents: latents_cache = [] for batch in tqdm(train_dataloader, desc="Caching latents"): @@ -1678,7 +1749,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: if not args.train_text_encoder: - prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( + prompt_embeds, pooled_prompt_embeds = compute_text_embeddings( prompts, text_encoders, tokenizers ) else: @@ -1686,7 +1757,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): tokens_two = tokenize_prompt( tokenizer_two, prompts, max_sequence_length=args.max_sequence_length ) - prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + prompt_embeds, pooled_prompt_embeds = encode_prompt( text_encoders=[text_encoder_one, text_encoder_two], tokenizers=[None, None], text_input_ids_list=[tokens_one, tokens_two], @@ -1697,7 +1768,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): else: elems_to_repeat = len(prompts) if args.train_text_encoder: - prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + prompt_embeds, pooled_prompt_embeds = encode_prompt( text_encoders=[text_encoder_one, text_encoder_two], tokenizers=[None, None], text_input_ids_list=[ @@ -1715,18 +1786,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): else: pixel_values = batch["pixel_values"].to(dtype=vae.dtype) model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor model_input = model_input.to(dtype=weight_dtype) vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1) - latent_image_ids = FluxPipeline._prepare_latent_image_ids( - model_input.shape[0], - model_input.shape[2] // 2, - model_input.shape[3] // 2, - accelerator.device, - weight_dtype, - ) # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) bsz = model_input.shape[0] @@ -1743,44 +1808,21 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) - # Add noise according to flow matching. - # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise - packed_noisy_model_input = FluxPipeline._pack_latents( - noisy_model_input, - batch_size=model_input.shape[0], - num_channels_latents=model_input.shape[1], - height=model_input.shape[2], - width=model_input.shape[3], - ) - - # handle guidance - if unwrap_model(transformer).config.guidance_embeds: - guidance = torch.tensor([args.guidance_scale], device=accelerator.device) - guidance = guidance.expand(model_input.shape[0]) - else: - guidance = None - # Predict the noise residual model_pred = transformer( hidden_states=packed_noisy_model_input, - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) - timestep=timesteps / 1000, - guidance=guidance, + timestep=timesteps, pooled_projections=pooled_prompt_embeds, encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, + img_sizes=img_sizes, + img_ids=img_ids, return_dict=False, )[0] - model_pred = FluxPipeline._unpack_latents( - model_pred, - height=model_input.shape[2] * vae_scale_factor, - width=model_input.shape[3] * vae_scale_factor, - vae_scale_factor=vae_scale_factor, - ) + model_pred = -model_pred # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss @@ -1869,14 +1911,19 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.validation_prompt is not None and epoch % args.validation_epochs == 0: # create pipeline if not args.train_text_encoder: - text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three, text_encoder_cls_four) text_encoder_one.to(weight_dtype) text_encoder_two.to(weight_dtype) - pipeline = FluxPipeline.from_pretrained( + text_encoder_three.to(weight_dtype) + text_encoder_four.to(weight_dtype) + + pipeline = HiDreamImagePipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, text_encoder=unwrap_model(text_encoder_one), text_encoder_2=unwrap_model(text_encoder_two), + text_encoder_3=unwrap_model(text_encoder_three), + text_encoder_4=unwrap_model(text_encoder_four), transformer=unwrap_model(transformer), revision=args.revision, variant=args.variant, @@ -1914,7 +1961,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): else: text_encoder_lora_layers = None - FluxPipeline.save_lora_weights( + HiDreamImagePipeline.save_lora_weights( save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers, text_encoder_lora_layers=text_encoder_lora_layers, @@ -1922,7 +1969,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Final inference # Load previous pipeline - pipeline = FluxPipeline.from_pretrained( + pipeline = HiDreamImagePipeline.from_pretrained( args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, From 0fa099327af4a9fdcd00d7332154af712ab8daf2 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 10 Apr 2025 12:00:49 +0300 Subject: [PATCH 03/76] initial commit --- .../train_dreambooth_lora_hidream.py | 547 ++++++------------ 1 file changed, 165 insertions(+), 382 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 8dd2d0b8c42f..9744edfa3330 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -42,19 +42,17 @@ from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm -from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast, PreTrainedTokenizerFast +from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast import diffusers from diffusers import ( AutoencoderKL, FlowMatchEulerDiscreteScheduler, - UniPCMultistepScheduler, HiDreamImagePipeline, HiDreamImageTransformer2DModel, ) from diffusers.optimization import get_scheduler from diffusers.training_utils import ( - _set_state_dict_into_text_encoder, cast_training_params, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, @@ -66,6 +64,7 @@ is_wandb_available, ) from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_torch_npu_available from diffusers.utils.torch_utils import is_compiled_module @@ -77,13 +76,16 @@ logger = get_logger(__name__) +if is_torch_npu_available(): + torch.npu.config.allow_internal_format = False + def save_model_card( repo_id: str, images=None, base_model: str = None, - train_text_encoder=False, instance_prompt=None, + system_prompt=None, validation_prompt=None, repo_folder=None, ): @@ -96,7 +98,7 @@ def save_model_card( ) model_description = f""" -# Hi Dream Image DreamBooth LoRA - {repo_id} +# HiDream Image DreamBooth LoRA - {repo_id} @@ -104,14 +106,15 @@ def save_model_card( These are {repo_id} DreamBooth LoRA weights for {base_model}. -The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [HiDream diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_hidream.md). +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [HiDream Image diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_hidream.md). -Was LoRA for the text encoder enabled? {train_text_encoder}. ## Trigger words You should use `{instance_prompt}` to trigger the image generation. +The following `system_prompt` was also used used during training (ignore if `None`): {system_prompt}. + ## Download model [Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. @@ -119,23 +122,15 @@ def save_model_card( ## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) ```py -from diffusers import AutoPipelineForText2Image -import torch -pipeline = AutoPipelineForText2Image.from_pretrained("", torch_dtype=torch.bfloat16).to('cuda') -pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors') -image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] +TODO ``` For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) - -## License - -Please adhere to the licensing terms as described [here](). """ model_card = load_or_create_model_card( repo_id_or_path=repo_id, from_training=True, - license="other", + license="apache-2.0", base_model=base_model, prompt=instance_prompt, model_description=model_description, @@ -154,7 +149,6 @@ def save_model_card( model_card = populate_model_card(model_card, tags=tags) model_card.save(os.path.join(repo_folder, "README.md")) - def load_text_encoders(class_one, class_two, class_three, class_four): text_encoder_one = class_one.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant @@ -170,27 +164,25 @@ def load_text_encoders(class_one, class_two, class_three, class_four): ) return text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four - def log_validation( pipeline, args, accelerator, pipeline_args, epoch, - torch_dtype, is_final_validation=False, ): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) + pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None - # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() - autocast_ctx = nullcontext() + autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() with autocast_ctx: images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] @@ -204,7 +196,7 @@ def log_validation( tracker.log( { phase_name: [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + wandb.Image(image, caption=f"{i}: {pipeline_args['prompt']}") for i, image in enumerate(images) ] } ) @@ -215,18 +207,17 @@ def log_validation( return images - def import_model_class_from_model_name_or_path( - pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" ): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder=subfolder, revision=revision ) model_class = text_encoder_config.architectures[0] if model_class == "CLIPTextModelWithProjection": - from transformers import CLIPTextModelWithProjection + from transformers import CLIPTextModel - return CLIPTextModelWithProjection + return CLIPTextModel elif model_class == "T5EncoderModel": from transformers import T5EncoderModel @@ -238,7 +229,6 @@ def import_model_class_from_model_name_or_path( else: raise ValueError(f"{model_class} is not supported.") - def parse_args(input_args=None): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( @@ -331,8 +321,14 @@ def parse_args(input_args=None): parser.add_argument( "--max_sequence_length", type=int, - default=512, - help="Maximum sequence length to use with with the T5 text encoder", + default=256, + help="Maximum sequence length to use with with the Gemma2 model", + ) + parser.add_argument( + "--system_prompt", + type=str, + default=None, + help="System prompt to use during inference to give the Gemma2 model certain characteristics.", ) parser.add_argument( "--validation_prompt", @@ -340,6 +336,12 @@ def parse_args(input_args=None): default=None, help="A prompt that is used during validation to verify that the model is learning.", ) + parser.add_argument( + "--final_validation_prompt", + type=str, + default=None, + help="A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.", + ) parser.add_argument( "--num_validation_images", type=int, @@ -380,7 +382,7 @@ def parse_args(input_args=None): parser.add_argument( "--output_dir", type=str, - default="flux-dreambooth-lora", + default="lumina2-dreambooth-lora", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") @@ -407,11 +409,6 @@ def parse_args(input_args=None): action="store_true", help="whether to randomly flip images horizontally", ) - parser.add_argument( - "--train_text_encoder", - action="store_true", - help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", - ) parser.add_argument( "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) @@ -467,20 +464,6 @@ def parse_args(input_args=None): default=1e-4, help="Initial learning rate (after the potential warmup period) to use.", ) - - parser.add_argument( - "--guidance_scale", - type=float, - default=3.5, - help="the FLUX.1 dev variant is a guidance distilled model", - ) - - parser.add_argument( - "--text_encoder_lr", - type=float, - default=5e-6, - help="Text encoder learning rate to use.", - ) parser.add_argument( "--scale_lr", action="store_true", @@ -561,16 +544,12 @@ def parse_args(input_args=None): ) parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") - parser.add_argument( - "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" - ) - parser.add_argument( "--lora_layers", type=str, default=None, help=( - 'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only' + 'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only' ), ) @@ -656,14 +635,9 @@ def parse_args(input_args=None): ), ) parser.add_argument( - "--prior_generation_precision", - type=str, - default=None, - choices=["no", "fp32", "fp16", "bf16"], - help=( - "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." - ), + "--offload", + action="store_true", + help="Whether to offload the VAE and the text encoder to CPU when they are not used.", ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") @@ -899,31 +873,17 @@ def __getitem__(self, index): example["index"] = index return example - -def tokenize_prompt(tokenizer, prompt, max_sequence_length): - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - return text_input_ids - -def _encode_prompt_with_llama3( +def _encode_prompt_with_llama( text_encoder, tokenizer, - prompt: str, - max_sequence_length: int = 128, + max_sequence_length=128, + prompt=None, + num_images_per_prompt=1, device=None, - attention_mask=None, text_input_ids=None, - num_images_per_prompt: int = 1, + attention_mask=None, ): - + prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if tokenizer is not None: @@ -935,30 +895,30 @@ def _encode_prompt_with_llama3( add_special_tokens=True, return_tensors="pt", ) + text_input_ids = text_inputs.input_ids attention_mask = text_inputs.attention_mask + else: if text_input_ids is None: raise ValueError("text_input_ids must be provided when the tokenizer is not specified") if attention_mask is None: - raise ValueError("attention_mask must be provided when the tokenizer is not specified") - - if hasattr(text_encoder, "module"): - dtype = text_encoder.module.dtype - else: - dtype = text_encoder.dtype - + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") - outputs = text_encoder( + outputs = self.text_encoder_4( text_input_ids.to(device), attention_mask=attention_mask.to(device), output_hidden_states=True, output_attentions=True, ) - prompt_embeds = outputs.hidden_states[1:] + if hasattr(text_encoder, "module"): + dtype = text_encoder.module.dtype + else: + dtype = text_encoder.dtype + + prompt_embeds = outputs.hidden_states[1:].to(dtype=dtype, device=device) prompt_embeds = torch.stack(prompt_embeds, dim=0) - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) _, _, seq_len, dim = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method @@ -966,40 +926,44 @@ def _encode_prompt_with_llama3( prompt_embeds = prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim) return prompt_embeds + def _encode_prompt_with_t5( text_encoder, tokenizer, - max_sequence_length=512, + max_sequence_length=128, prompt=None, num_images_per_prompt=1, device=None, text_input_ids=None, + attention_mask=None, ): + prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if tokenizer is not None: text_inputs = tokenizer( prompt, padding="max_length", - max_length=max_sequence_length, + max_length=min(max_sequence_length, tokenizer.model_max_length), truncation=True, - return_length=False, - return_overflowing_tokens=False, + add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask else: if text_input_ids is None: raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + if attention_mask is None: + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") - prompt_embeds = text_encoder(text_input_ids.to(device))[0] + prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0] if hasattr(text_encoder, "module"): dtype = text_encoder.module.dtype else: dtype = text_encoder.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method @@ -1013,11 +977,12 @@ def _encode_prompt_with_clip( text_encoder, tokenizer, prompt: str, + max_sequence_length=128, device=None, text_input_ids=None, - max_sequence_length=218, num_images_per_prompt: int = 1, ): + prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if tokenizer is not None: @@ -1026,8 +991,6 @@ def _encode_prompt_with_clip( padding="max_length", max_length=min(max_sequence_length, 218), truncation=True, - return_overflowing_tokens=False, - return_length=False, return_tensors="pt", ) @@ -1036,15 +999,14 @@ def _encode_prompt_with_clip( if text_input_ids is None: raise ValueError("text_input_ids must be provided when the tokenizer is not specified") - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) if hasattr(text_encoder, "module"): dtype = text_encoder.module.dtype else: dtype = text_encoder.dtype - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) # duplicate text embeddings for each generation per prompt, using mps friendly method @@ -1062,15 +1024,19 @@ def encode_prompt( device=None, num_images_per_prompt: int = 1, text_input_ids_list=None, - llama_attention_mask=None, + attention_mask_list=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt + if hasattr(text_encoders[0], "module"): + dtype = text_encoders[0].module.dtype + else: + dtype = text_encoders[0].dtype + pooled_prompt_embeds_1 = _encode_prompt_with_clip( text_encoder=text_encoders[0], tokenizer=tokenizers[0], prompt=prompt, - max_sequence_length=max_sequence_length, device=device if device is not None else text_encoders[0].device, num_images_per_prompt=num_images_per_prompt, text_input_ids=text_input_ids_list[0] if text_input_ids_list else None, @@ -1080,7 +1046,6 @@ def encode_prompt( text_encoder=text_encoders[1], tokenizer=tokenizers[1], prompt=prompt, - max_sequence_length=max_sequence_length, device=device if device is not None else text_encoders[1].device, num_images_per_prompt=num_images_per_prompt, text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, @@ -1096,9 +1061,10 @@ def encode_prompt( num_images_per_prompt=num_images_per_prompt, device=device if device is not None else text_encoders[2].device, text_input_ids=text_input_ids_list[2] if text_input_ids_list else None, + attention_mask=attention_mask_list[0] if attention_mask_list else None, ) - llama3_prompt_embeds = _encode_prompt_with_llama3( + llama3_prompt_embeds = _encode_prompt_with_llama( text_encoder=text_encoders[3], tokenizer=tokenizers[3], max_sequence_length=max_sequence_length, @@ -1106,14 +1072,13 @@ def encode_prompt( num_images_per_prompt=num_images_per_prompt, device=device if device is not None else text_encoders[3].device, text_input_ids=text_input_ids_list[3] if text_input_ids_list else None, - attention_mask=llama_attention_mask + attention_mask=attention_mask_list[1] if attention_mask_list else None, ) prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds] return prompt_embeds, pooled_prompt_embeds - def main(args): if args.report_to == "wandb" and args.hub_token is not None: raise ValueError( @@ -1173,17 +1138,9 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() - torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 - if args.prior_generation_precision == "fp32": - torch_dtype = torch.float32 - elif args.prior_generation_precision == "fp16": - torch_dtype = torch.float16 - elif args.prior_generation_precision == "bf16": - torch_dtype = torch.bfloat16 pipeline = HiDreamImagePipeline.from_pretrained( args.pretrained_model_name_or_path, - torch_dtype=torch_dtype, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, revision=args.revision, variant=args.variant, ) @@ -1209,8 +1166,7 @@ def main(args): image.save(image_filename) del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() + free_memory() # Handle the repository creation if accelerator.is_main_process: @@ -1223,70 +1179,35 @@ def main(args): exist_ok=True, ).repo_id - # Load the tokenizers - tokenizer_one = CLIPTokenizer.from_pretrained( + # Load the tokenizer + tokenizer = AutoTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, ) - tokenizer_two = CLIPTokenizer.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="tokenizer_2", - revision=args.revision, - ) - tokenizer_three = T5TokenizerFast.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="tokenizer_3", - revision=args.revision, - ) - tokenizer_four = PreTrainedTokenizerFast.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="tokenizer_4", - revision=args.revision, - ) - - # import correct text encoder classes - text_encoder_cls_one = import_model_class_from_model_name_or_path( - args.pretrained_model_name_or_path, args.revision - ) - text_encoder_cls_two = import_model_class_from_model_name_or_path( - args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" - ) - - text_encoder_cls_three = import_model_class_from_model_name_or_path( - args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3" - ) - - text_encoder_cls_four = import_model_class_from_model_name_or_path( - args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_4" - ) # Load scheduler and models noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( - args.pretrained_model_name_or_path, subfolder="scheduler" + args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision ) noise_scheduler_copy = copy.deepcopy(noise_scheduler) - text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(text_encoder_cls_one, - text_encoder_cls_two, - text_encoder_cls_three, - text_encoder_cls_four) + text_encoder = Gemma2Model.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant, ) - transformer = HiDreamImageTransformer2DModel.from_pretrained( + transformer = Lumina2Transformer2DModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant ) # We only train the additional adapter LoRA layers transformer.requires_grad_(False) vae.requires_grad_(False) - text_encoder_one.requires_grad_(False) - text_encoder_two.requires_grad_(False) - text_encoder_three.requires_grad_(False) - text_encoder_four.requires_grad_(False) + text_encoder.requires_grad_(False) # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. @@ -1302,35 +1223,28 @@ def main(args): "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." ) - vae.to(accelerator.device, dtype=weight_dtype) + # keep VAE in FP32 to ensure numerical stability. + vae.to(dtype=torch.float32) transformer.to(accelerator.device, dtype=weight_dtype) - text_encoder_one.to(accelerator.device, dtype=weight_dtype) - text_encoder_two.to(accelerator.device, dtype=weight_dtype) - text_encoder_three.to(accelerator.device, dtype=weight_dtype) - text_encoder_four.to(accelerator.device, dtype=weight_dtype) + # because Gemma2 is particularly suited for bfloat16. + text_encoder.to(dtype=torch.bfloat16) + + # Initialize a text encoding pipeline and keep it to CPU for now. + text_encoding_pipeline = HiDreamImagePipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=None, + transformer=None, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) if args.gradient_checkpointing: transformer.enable_gradient_checkpointing() - if args.train_text_encoder: - text_encoder_one.gradient_checkpointing_enable() if args.lora_layers is not None: target_modules = [layer.strip() for layer in args.lora_layers.split(",")] else: - target_modules = [ - "attn.to_k", - "attn.to_q", - "attn.to_v", - "attn.to_out.0", - "attn.add_k_proj", - "attn.add_q_proj", - "attn.add_v_proj", - "attn.to_add_out", - "ff.net.0.proj", - "ff.net.2", - "ff_context.net.0.proj", - "ff_context.net.2", - ] + target_modules = ["to_k", "to_q", "to_v", "to_out.0"] # now we will add new LoRA weights the transformer layers transformer_lora_config = LoraConfig( @@ -1340,14 +1254,6 @@ def main(args): target_modules=target_modules, ) transformer.add_adapter(transformer_lora_config) - if args.train_text_encoder: - text_lora_config = LoraConfig( - r=args.rank, - lora_alpha=args.rank, - init_lora_weights="gaussian", - target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], - ) - text_encoder_one.add_adapter(text_lora_config) def unwrap_model(model): model = accelerator.unwrap_model(model) @@ -1358,13 +1264,10 @@ def unwrap_model(model): def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: transformer_lora_layers_to_save = None - text_encoder_one_lora_layers_to_save = None for model in models: if isinstance(model, type(unwrap_model(transformer))): transformer_lora_layers_to_save = get_peft_model_state_dict(model) - elif isinstance(model, type(unwrap_model(text_encoder_one))): - text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1374,20 +1277,16 @@ def save_model_hook(models, weights, output_dir): HiDreamImagePipeline.save_lora_weights( output_dir, transformer_lora_layers=transformer_lora_layers_to_save, - text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, ) def load_model_hook(models, input_dir): transformer_ = None - text_encoder_one_ = None while len(models) > 0: model = models.pop() if isinstance(model, type(unwrap_model(transformer))): transformer_ = model - elif isinstance(model, type(unwrap_model(text_encoder_one))): - text_encoder_one_ = model else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1406,17 +1305,12 @@ def load_model_hook(models, input_dir): f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " f" {unexpected_keys}. " ) - if args.train_text_encoder: - # Do we need to call `scale_lora_layers()` here? - _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) # Make sure the trainable params are in float32. This is again needed since the base models # are in `weight_dtype`. More details: # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 if args.mixed_precision == "fp16": models = [transformer_] - if args.train_text_encoder: - models.extend([text_encoder_one_]) # only upcast trainable parameters (LoRA) into fp32 cast_training_params(models) @@ -1436,27 +1330,14 @@ def load_model_hook(models, input_dir): # Make sure the trainable params are in float32. if args.mixed_precision == "fp16": models = [transformer] - if args.train_text_encoder: - models.extend([text_encoder_one]) # only upcast trainable parameters (LoRA) into fp32 cast_training_params(models, dtype=torch.float32) transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) - if args.train_text_encoder: - text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) # Optimization parameters transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} - if args.train_text_encoder: - # different learning rate for text encoder and unet - text_parameters_one_with_lr = { - "params": text_lora_parameters_one, - "weight_decay": args.adam_weight_decay_text_encoder, - "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, - } - params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr] - else: - params_to_optimize = [transformer_parameters_with_lr] + params_to_optimize = [transformer_parameters_with_lr] # Optimizer creation if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): @@ -1504,15 +1385,6 @@ def load_model_hook(models, input_dir): logger.warning( "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" ) - if args.train_text_encoder and args.text_encoder_lr: - logger.warning( - f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:" - f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " - f"When using prodigy only learning_rate is used as the initial learning rate." - ) - # changes the learning rate of text_encoder_parameters_one to be - # --learning_rate - params_to_optimize[1]["lr"] = args.learning_rate optimizer = optimizer_class( params_to_optimize, @@ -1545,76 +1417,57 @@ def load_model_hook(models, input_dir): num_workers=args.dataloader_num_workers, ) - if not args.train_text_encoder: - tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three, tokenizer_four] - text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four] - - def compute_text_embeddings(prompt, text_encoders, tokenizers): - with torch.no_grad(): - prompt_embeds, pooled_prompt_embeds = encode_prompt( - text_encoders, tokenizers, prompt, args.max_sequence_length - ) - prompt_embeds = prompt_embeds.to(accelerator.device) - pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) - - return prompt_embeds, pooled_prompt_embeds + def compute_text_embeddings(prompt, text_encoding_pipeline): + text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) + with torch.no_grad(): + prompt_embeds, prompt_attention_mask, _, _ = text_encoding_pipeline.encode_prompt( + prompt, + max_sequence_length=args.max_sequence_length, + system_prompt=args.system_prompt, + ) + if args.offload: + text_encoding_pipeline = text_encoding_pipeline.to("cpu") + prompt_embeds = prompt_embeds.to(transformer.dtype) + return prompt_embeds, prompt_attention_mask # If no type of tuning is done on the text_encoder and custom instance prompts are NOT # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid # the redundant encoding. - if not args.train_text_encoder and not train_dataset.custom_instance_prompts: - instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( - args.instance_prompt, text_encoders, tokenizers + if not train_dataset.custom_instance_prompts: + instance_prompt_hidden_states, instance_prompt_attention_mask = compute_text_embeddings( + args.instance_prompt, text_encoding_pipeline ) # Handle class prompt for prior-preservation. if args.with_prior_preservation: - if not args.train_text_encoder: - class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings( - args.class_prompt, text_encoders, tokenizers - ) + class_prompt_hidden_states, class_prompt_attention_mask = compute_text_embeddings( + args.class_prompt, text_encoding_pipeline + ) # Clear the memory here - if not args.train_text_encoder and not train_dataset.custom_instance_prompts: - del text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four, tokenizer_one, tokenizer_two, tokenizer_three, tokenizer_four + if not train_dataset.custom_instance_prompts: + del text_encoder, tokenizer free_memory() # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), # pack the statically computed variables appropriately here. This is so that we don't # have to pass them to the dataloader. - if not train_dataset.custom_instance_prompts: - if not args.train_text_encoder: - prompt_embeds = instance_prompt_hidden_states - pooled_prompt_embeds = instance_pooled_prompt_embeds - if args.with_prior_preservation: - prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) - pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) - - # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) - # we need to tokenize and encode the batch prompts on all training steps - else: - tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77) - tokens_two = tokenize_prompt( - tokenizer_two, args.instance_prompt, max_sequence_length=args.max_sequence_length - ) - if args.with_prior_preservation: - class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77) - class_tokens_two = tokenize_prompt( - tokenizer_two, args.class_prompt, max_sequence_length=args.max_sequence_length - ) - tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) - tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) - + prompt_embeds = instance_prompt_hidden_states + prompt_attention_mask = instance_prompt_attention_mask + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + prompt_attention_mask = torch.cat([prompt_attention_mask, class_prompt_attention_mask], dim=0) vae_config_scaling_factor = vae.config.scaling_factor - + vae_config_shift_factor = vae.config.shift_factor if args.cache_latents: latents_cache = [] + vae = vae.to(accelerator.device) for batch in tqdm(train_dataloader, desc="Caching latents"): with torch.no_grad(): batch["pixel_values"] = batch["pixel_values"].to( - accelerator.device, non_blocking=True, dtype=weight_dtype + accelerator.device, non_blocking=True, dtype=vae.dtype ) latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) @@ -1639,24 +1492,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) # Prepare everything with our `accelerator`. - if args.train_text_encoder: - ( - transformer, - text_encoder_one, - optimizer, - train_dataloader, - lr_scheduler, - ) = accelerator.prepare( - transformer, - text_encoder_one, - optimizer, - train_dataloader, - lr_scheduler, - ) - else: - transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - transformer, optimizer, train_dataloader, lr_scheduler - ) + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -1668,7 +1506,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - tracker_name = "dreambooth-flux-dev-lora" + tracker_name = "dreambooth-lumina2-lora" accelerator.init_trackers(tracker_name, config=vars(args)) # Train! @@ -1734,64 +1572,28 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): for epoch in range(first_epoch, args.num_train_epochs): transformer.train() - if args.train_text_encoder: - text_encoder_one.train() - # set top parameter requires_grad = True for gradient checkpointing works - unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) for step, batch in enumerate(train_dataloader): models_to_accumulate = [transformer] - if args.train_text_encoder: - models_to_accumulate.extend([text_encoder_one]) - with accelerator.accumulate(models_to_accumulate): - prompts = batch["prompts"] + prompts = batch["prompts"] + with accelerator.accumulate(models_to_accumulate): # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: - if not args.train_text_encoder: - prompt_embeds, pooled_prompt_embeds = compute_text_embeddings( - prompts, text_encoders, tokenizers - ) - else: - tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77) - tokens_two = tokenize_prompt( - tokenizer_two, prompts, max_sequence_length=args.max_sequence_length - ) - prompt_embeds, pooled_prompt_embeds = encode_prompt( - text_encoders=[text_encoder_one, text_encoder_two], - tokenizers=[None, None], - text_input_ids_list=[tokens_one, tokens_two], - max_sequence_length=args.max_sequence_length, - device=accelerator.device, - prompt=prompts, - ) - else: - elems_to_repeat = len(prompts) - if args.train_text_encoder: - prompt_embeds, pooled_prompt_embeds = encode_prompt( - text_encoders=[text_encoder_one, text_encoder_two], - tokenizers=[None, None], - text_input_ids_list=[ - tokens_one.repeat(elems_to_repeat, 1), - tokens_two.repeat(elems_to_repeat, 1), - ], - max_sequence_length=args.max_sequence_length, - device=accelerator.device, - prompt=args.instance_prompt, - ) + prompt_embeds, prompt_attention_mask = compute_text_embeddings(prompts, text_encoding_pipeline) # Convert images to latent space if args.cache_latents: model_input = latents_cache[step].sample() else: + vae = vae.to(accelerator.device) pixel_values = batch["pixel_values"].to(dtype=vae.dtype) model_input = vae.encode(pixel_values).latent_dist.sample() - + if args.offload: + vae = vae.to("cpu") model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor model_input = model_input.to(dtype=weight_dtype) - vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1) - # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) bsz = model_input.shape[0] @@ -1808,28 +1610,33 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) - + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + # Lumina2 reverses the lerp i.e., sigma of 1.0 should mean `model_input` sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) - noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + noisy_model_input = (1.0 - sigmas) * noise + sigmas * model_input # Predict the noise residual + # scale the timesteps (reversal not needed as we used a reverse lerp above already) + timesteps = timesteps / noise_scheduler.config.num_train_timesteps model_pred = transformer( - hidden_states=packed_noisy_model_input, + hidden_states=noisy_model_input, + encoder_hidden_states=prompt_embeds.repeat(len(prompts), 1, 1) + if not train_dataset.custom_instance_prompts + else prompt_embeds, + encoder_attention_mask=prompt_attention_mask.repeat(len(prompts), 1) + if not train_dataset.custom_instance_prompts + else prompt_attention_mask, timestep=timesteps, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - img_sizes=img_sizes, - img_ids=img_ids, return_dict=False, )[0] - model_pred = -model_pred # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) - # flow matching loss - target = noise - model_input + # flow matching loss (reversed) + target = model_input - noise if args.with_prior_preservation: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. @@ -1858,11 +1665,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = ( - itertools.chain(transformer.parameters(), text_encoder_one.parameters()) - if args.train_text_encoder - else transformer.parameters() - ) + params_to_clip = transformer.parameters() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() @@ -1910,37 +1713,22 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if accelerator.is_main_process: if args.validation_prompt is not None and epoch % args.validation_epochs == 0: # create pipeline - if not args.train_text_encoder: - text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three, text_encoder_cls_four) - text_encoder_one.to(weight_dtype) - text_encoder_two.to(weight_dtype) - text_encoder_three.to(weight_dtype) - text_encoder_four.to(weight_dtype) - pipeline = HiDreamImagePipeline.from_pretrained( args.pretrained_model_name_or_path, - vae=vae, - text_encoder=unwrap_model(text_encoder_one), - text_encoder_2=unwrap_model(text_encoder_two), - text_encoder_3=unwrap_model(text_encoder_three), - text_encoder_4=unwrap_model(text_encoder_four), - transformer=unwrap_model(transformer), + transformer=accelerator.unwrap_model(transformer), revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, ) - pipeline_args = {"prompt": args.validation_prompt} + pipeline_args = {"prompt": args.validation_prompt, "system_prompt": args.system_prompt} images = log_validation( pipeline=pipeline, args=args, accelerator=accelerator, pipeline_args=pipeline_args, epoch=epoch, - torch_dtype=weight_dtype, ) - if not args.train_text_encoder: - del text_encoder_one, text_encoder_two - free_memory() + free_memory() images = None del pipeline @@ -1955,16 +1743,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): transformer = transformer.to(weight_dtype) transformer_lora_layers = get_peft_model_state_dict(transformer) - if args.train_text_encoder: - text_encoder_one = unwrap_model(text_encoder_one) - text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32)) - else: - text_encoder_lora_layers = None - HiDreamImagePipeline.save_lora_weights( save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers, - text_encoder_lora_layers=text_encoder_lora_layers, ) # Final inference @@ -1980,8 +1761,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # run inference images = [] - if args.validation_prompt and args.num_validation_images > 0: - pipeline_args = {"prompt": args.validation_prompt} + if (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt): + prompt_to_use = args.validation_prompt if args.validation_prompt else args.final_validation_prompt + args.num_validation_images = args.num_validation_images if args.num_validation_images else 1 + pipeline_args = {"prompt": prompt_to_use, "system_prompt": args.system_prompt} images = log_validation( pipeline=pipeline, args=args, @@ -1989,17 +1772,17 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): pipeline_args=pipeline_args, epoch=epoch, is_final_validation=True, - torch_dtype=weight_dtype, ) if args.push_to_hub: + validation_prpmpt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt save_model_card( repo_id, images=images, base_model=args.pretrained_model_name_or_path, - train_text_encoder=args.train_text_encoder, instance_prompt=args.instance_prompt, - validation_prompt=args.validation_prompt, + system_prompt=args.system_prompt, + validation_prompt=validation_prpmpt, repo_folder=args.output_dir, ) upload_folder( From 5ecf0ed88b0387b4ceda2ab94f0be36fd6f742b7 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 10 Apr 2025 12:39:30 +0300 Subject: [PATCH 04/76] initial commit --- .../train_dreambooth_lora_hidream.py | 129 +++++++++++------- 1 file changed, 82 insertions(+), 47 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 9744edfa3330..0404411ecc0b 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1028,11 +1028,6 @@ def encode_prompt( ): prompt = [prompt] if isinstance(prompt, str) else prompt - if hasattr(text_encoders[0], "module"): - dtype = text_encoders[0].module.dtype - else: - dtype = text_encoders[0].dtype - pooled_prompt_embeds_1 = _encode_prompt_with_clip( text_encoder=text_encoders[0], tokenizer=tokenizers[0], @@ -1179,21 +1174,50 @@ def main(args): exist_ok=True, ).repo_id - # Load the tokenizer - tokenizer = AutoTokenizer.from_pretrained( + # Load the tokenizers + tokenizer_one = CLIPTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, ) + tokenizer_two = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + ) + tokenizer_three = T5TokenizerFast.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_3", + revision=args.revision, + ) + + tokenizer_four = PreTrainedTokenizerFast.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_4", + revision=args.revision, + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + text_encoder_cls_three = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3" + ) + text_encoder_cls_four = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_4" + ) # Load scheduler and models noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision ) noise_scheduler_copy = copy.deepcopy(noise_scheduler) - text_encoder = Gemma2Model.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant - ) + text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three, text_encoder_cls_four) + vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", @@ -1207,7 +1231,10 @@ def main(args): # We only train the additional adapter LoRA layers transformer.requires_grad_(False) vae.requires_grad_(False) - text_encoder.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + text_encoder_three.requires_grad_(False) + text_encoder_four.requires_grad_(False) # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. @@ -1226,17 +1253,10 @@ def main(args): # keep VAE in FP32 to ensure numerical stability. vae.to(dtype=torch.float32) transformer.to(accelerator.device, dtype=weight_dtype) - # because Gemma2 is particularly suited for bfloat16. - text_encoder.to(dtype=torch.bfloat16) - - # Initialize a text encoding pipeline and keep it to CPU for now. - text_encoding_pipeline = HiDreamImagePipeline.from_pretrained( - args.pretrained_model_name_or_path, - vae=None, - transformer=None, - text_encoder=text_encoder, - tokenizer=tokenizer, - ) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + text_encoder_three.to(accelerator.device, dtype=weight_dtype) + text_encoder_four.to(accelerator.device, dtype=weight_dtype) if args.gradient_checkpointing: transformer.enable_gradient_checkpointing() @@ -1417,36 +1437,34 @@ def load_model_hook(models, input_dir): num_workers=args.dataloader_num_workers, ) - def compute_text_embeddings(prompt, text_encoding_pipeline): - text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) + tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three, tokenizer_four] + text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four] + def compute_text_embeddings(prompt, text_encoders, tokenizers): with torch.no_grad(): - prompt_embeds, prompt_attention_mask, _, _ = text_encoding_pipeline.encode_prompt( - prompt, - max_sequence_length=args.max_sequence_length, - system_prompt=args.system_prompt, + prompt_embeds, pooled_prompt_embeds = encode_prompt( + text_encoders, tokenizers, prompt, args.max_sequence_length ) - if args.offload: - text_encoding_pipeline = text_encoding_pipeline.to("cpu") - prompt_embeds = prompt_embeds.to(transformer.dtype) - return prompt_embeds, prompt_attention_mask + prompt_embeds = prompt_embeds.to(accelerator.device) + pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) + return prompt_embeds, pooled_prompt_embeds # If no type of tuning is done on the text_encoder and custom instance prompts are NOT # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid # the redundant encoding. if not train_dataset.custom_instance_prompts: - instance_prompt_hidden_states, instance_prompt_attention_mask = compute_text_embeddings( - args.instance_prompt, text_encoding_pipeline + instance_prompt_hidden_states, instance_pooled_prompt_embeds, = compute_text_embeddings( + args.instance_prompt, text_encoders, tokenizers ) # Handle class prompt for prior-preservation. if args.with_prior_preservation: - class_prompt_hidden_states, class_prompt_attention_mask = compute_text_embeddings( - args.class_prompt, text_encoding_pipeline + class_prompt_hidden_states, class_pooled_prompt_embeds, = compute_text_embeddings( + args.class_prompt, text_encoders, tokenizers ) # Clear the memory here if not train_dataset.custom_instance_prompts: - del text_encoder, tokenizer + del text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four, tokenizer_one, tokenizer_two,tokenizer_three, tokenizer_four free_memory() # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), @@ -1454,10 +1472,10 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): # have to pass them to the dataloader. if not train_dataset.custom_instance_prompts: prompt_embeds = instance_prompt_hidden_states - prompt_attention_mask = instance_prompt_attention_mask + pooled_prompt_embeds = instance_pooled_prompt_embeds if args.with_prior_preservation: prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) - prompt_attention_mask = torch.cat([prompt_attention_mask, class_prompt_attention_mask], dim=0) + pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) vae_config_scaling_factor = vae.config.scaling_factor vae_config_shift_factor = vae.config.shift_factor @@ -1506,7 +1524,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - tracker_name = "dreambooth-lumina2-lora" + tracker_name = "dreambooth-hidream-lora" accelerator.init_trackers(tracker_name, config=vars(args)) # Train! @@ -1580,7 +1598,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): with accelerator.accumulate(models_to_accumulate): # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: - prompt_embeds, prompt_attention_mask = compute_text_embeddings(prompts, text_encoding_pipeline) + prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(prompts, text_encoders, tokenizers) # Convert images to latent space if args.cache_latents: @@ -1594,6 +1612,24 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor model_input = model_input.to(dtype=weight_dtype) + if model_input.shape[-2] != model_input.shape[-1]: + B, C, H, W = model_input.shape + pH, pW = H // transformer.config.patch_size, W // transformer.config.patch_size + + img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1) + img_ids = torch.zeros(pH, pW, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :] + img_ids = img_ids.reshape(pH * pW, -1) + img_ids_pad = torch.zeros(self.transformer.max_seq, 3) + img_ids_pad[: pH * pW, :] = img_ids + + img_sizes = img_sizes.unsqueeze(0).to(model_input.device) + img_ids = img_ids_pad.unsqueeze(0).to(model_input.device) + + else: + img_sizes = img_ids = None + # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) bsz = model_input.shape[0] @@ -1612,22 +1648,21 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Add noise according to flow matching. # zt = (1 - texp) * x + texp * z1 - # Lumina2 reverses the lerp i.e., sigma of 1.0 should mean `model_input` sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) - noisy_model_input = (1.0 - sigmas) * noise + sigmas * model_input + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise # Predict the noise residual - # scale the timesteps (reversal not needed as we used a reverse lerp above already) - timesteps = timesteps / noise_scheduler.config.num_train_timesteps model_pred = transformer( hidden_states=noisy_model_input, encoder_hidden_states=prompt_embeds.repeat(len(prompts), 1, 1) if not train_dataset.custom_instance_prompts else prompt_embeds, - encoder_attention_mask=prompt_attention_mask.repeat(len(prompts), 1) + pooled_embeds=pooled_prompt_embeds.repeat(len(prompts), 1) if not train_dataset.custom_instance_prompts - else prompt_attention_mask, + else pooled_prompt_embeds, timestep=timesteps, + img_sizes=img_sizes, + img_ids=img_ids, return_dict=False, )[0] From 911c30e9a1d10c73cd8404d23fd9dae879be95cf Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 10 Apr 2025 12:43:09 +0300 Subject: [PATCH 05/76] initial commit --- .../dreambooth/train_dreambooth_lora_hidream.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 0404411ecc0b..69ac69379beb 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -85,7 +85,6 @@ def save_model_card( images=None, base_model: str = None, instance_prompt=None, - system_prompt=None, validation_prompt=None, repo_folder=None, ): @@ -113,8 +112,6 @@ def save_model_card( You should use `{instance_prompt}` to trigger the image generation. -The following `system_prompt` was also used used during training (ignore if `None`): {system_prompt}. - ## Download model [Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. @@ -324,12 +321,7 @@ def parse_args(input_args=None): default=256, help="Maximum sequence length to use with with the Gemma2 model", ) - parser.add_argument( - "--system_prompt", - type=str, - default=None, - help="System prompt to use during inference to give the Gemma2 model certain characteristics.", - ) + parser.add_argument( "--validation_prompt", type=str, @@ -382,7 +374,7 @@ def parse_args(input_args=None): parser.add_argument( "--output_dir", type=str, - default="lumina2-dreambooth-lora", + default="hidream-dreambooth-lora", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") @@ -1755,7 +1747,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): variant=args.variant, torch_dtype=weight_dtype, ) - pipeline_args = {"prompt": args.validation_prompt, "system_prompt": args.system_prompt} + pipeline_args = {"prompt": args.validation_prompt} images = log_validation( pipeline=pipeline, args=args, @@ -1799,7 +1791,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt): prompt_to_use = args.validation_prompt if args.validation_prompt else args.final_validation_prompt args.num_validation_images = args.num_validation_images if args.num_validation_images else 1 - pipeline_args = {"prompt": prompt_to_use, "system_prompt": args.system_prompt} + pipeline_args = {"prompt": prompt_to_use, "num_images_per_prompt": args.num_validation_images} images = log_validation( pipeline=pipeline, args=args, @@ -1816,7 +1808,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): images=images, base_model=args.pretrained_model_name_or_path, instance_prompt=args.instance_prompt, - system_prompt=args.system_prompt, validation_prompt=validation_prpmpt, repo_folder=args.output_dir, ) From b7fffee077c92aea136f1e185c9700bcbc6a1318 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 10 Apr 2025 12:46:19 +0300 Subject: [PATCH 06/76] initial commit --- examples/dreambooth/README_hidream.md | 127 ++++++++++++++++++++++++++ 1 file changed, 127 insertions(+) diff --git a/examples/dreambooth/README_hidream.md b/examples/dreambooth/README_hidream.md index e69de29bb2d1..82359b4785d6 100644 --- a/examples/dreambooth/README_hidream.md +++ b/examples/dreambooth/README_hidream.md @@ -0,0 +1,127 @@ +# DreamBooth training example for HiDream Image + +[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject. + +The `train_dreambooth_lora_hidream.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/). + + +This will also allow us to push the trained model parameters to the Hugging Face Hub platform. + +## Running locally with PyTorch + +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: + +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install -e . +``` + +Then cd in the `examples/dreambooth` folder and run +```bash +pip install -r requirements_sana.txt +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +Or for a default accelerate configuration without answering questions about your environment + +```bash +accelerate config default +``` + +Or if your environment doesn't support an interactive shell (e.g., a notebook) + +```python +from accelerate.utils import write_basic_config +write_basic_config() +``` + +When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. +Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment. + + +### Dog toy example + +Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example. + +Let's first download it locally: + +```python +from huggingface_hub import snapshot_download + +local_dir = "./dog" +snapshot_download( + "diffusers/dog-example", + local_dir=local_dir, repo_type="dataset", + ignore_patterns=".gitattributes", +) +``` + +This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform. + +Now, we can launch training using: + +```bash +export MODEL_NAME="HiDream-ai/HiDream-I1-Dev" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="trained-hidream-lora" + +accelerate launch train_dreambooth_lora_hidream.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --mixed_precision="bf16" \ + --instance_prompt="a photo of sks dog" \ + --resolution=1024 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --use_8bit_adam \ + --learning_rate=1e-4 \ + --report_to="wandb" \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=500 \ + --validation_prompt="A photo of sks dog in a bucket" \ + --validation_epochs=25 \ + --seed="0" \ + --push_to_hub +``` + +For using `push_to_hub`, make you're logged into your Hugging Face account: + +```bash +huggingface-cli login +``` + +To better track our training experiments, we're using the following flags in the command above: + +* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login ` before training if you haven't done it before. +* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected. + +## Notes + +Additionally, we welcome you to explore the following CLI arguments: + +* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only. +* `--system_prompt`: A custom system prompt to provide additional personality to the model. +* `--max_sequence_length`: Maximum sequence length to use for text embeddings. + + +We provide several options for optimizing memory optimization: + +* `--offload`: When enabled, we will offload the text encoder and VAE to CPU, when they are not used. +* `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done. +* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library. + +Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/) of the `HiDreamImagePipeline` to know more about the model. From 02de3ce023f479f2c2a53be20060e3705e7e6865 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Mon, 14 Apr 2025 14:16:53 +0300 Subject: [PATCH 07/76] Update examples/dreambooth/train_dreambooth_lora_hidream.py Co-authored-by: Bagheera <59658056+bghira@users.noreply.github.com> --- examples/dreambooth/train_dreambooth_lora_hidream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 69ac69379beb..56eeb15d5c4f 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1662,7 +1662,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # and instead post-weight the loss weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) - # flow matching loss (reversed) + # flow matching loss via reverse ODE (HiDream uses an inverted flow field as compared to the more typical `noise - model_input` target) target = model_input - noise if args.with_prior_preservation: From 5257b468fd7355e2098e5ab72dfe3bd598566163 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 14 Apr 2025 14:44:44 +0300 Subject: [PATCH 08/76] move prompt embeds, pooled embeds outside --- examples/dreambooth/train_dreambooth_lora_hidream.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 56eeb15d5c4f..bb18ef28a949 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1591,7 +1591,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(prompts, text_encoders, tokenizers) - + else: + prompt_embeds = prompt_embeds.repeat(len(prompts), 1, 1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(len(prompts), 1) # Convert images to latent space if args.cache_latents: model_input = latents_cache[step].sample() @@ -1646,12 +1648,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Predict the noise residual model_pred = transformer( hidden_states=noisy_model_input, - encoder_hidden_states=prompt_embeds.repeat(len(prompts), 1, 1) - if not train_dataset.custom_instance_prompts - else prompt_embeds, - pooled_embeds=pooled_prompt_embeds.repeat(len(prompts), 1) - if not train_dataset.custom_instance_prompts - else pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + pooled_embeds=pooled_prompt_embeds, timestep=timesteps, img_sizes=img_sizes, img_ids=img_ids, From e9b4ad20cefe48a68e75aceacf448d51b207e8c7 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Mon, 14 Apr 2025 15:01:31 +0300 Subject: [PATCH 09/76] Update examples/dreambooth/train_dreambooth_lora_hidream.py Co-authored-by: hlky --- examples/dreambooth/train_dreambooth_lora_hidream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 56eeb15d5c4f..e0b3b8b26078 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -991,7 +991,7 @@ def _encode_prompt_with_clip( if text_input_ids is None: raise ValueError("text_input_ids must be provided when the tokenizer is not specified") - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + prompt_embeds = text_encoder(text_input_ids.to(device)) if hasattr(text_encoder, "module"): dtype = text_encoder.module.dtype From 677bab170edfb5742d9d1233fb3cedbc0b899b66 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Mon, 14 Apr 2025 15:01:44 +0300 Subject: [PATCH 10/76] Update examples/dreambooth/train_dreambooth_lora_hidream.py Co-authored-by: hlky --- examples/dreambooth/train_dreambooth_lora_hidream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index e0b3b8b26078..ea2c275401da 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1216,7 +1216,7 @@ def main(args): revision=args.revision, variant=args.variant, ) - transformer = Lumina2Transformer2DModel.from_pretrained( + transformer = HiDreamImageTransformer2DModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant ) From 31aa0a2d6d29a09426d36d1a59e9d591c631579b Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 14 Apr 2025 15:04:18 +0300 Subject: [PATCH 11/76] fix import --- examples/dreambooth/train_dreambooth_lora_hidream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 4d411d02a9ce..91bd70d99ad4 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -42,7 +42,7 @@ from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm -from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast +from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast, PreTrainedTokenizerFast import diffusers from diffusers import ( From de1654a9b5826d055e44aea31e3f8e18a879bb68 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 14 Apr 2025 16:00:39 +0300 Subject: [PATCH 12/76] fix import and tokenizer 4, text encoder 4 loading --- .../train_dreambooth_lora_hidream.py | 34 +++++++++---------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 91bd70d99ad4..08783cdcb05d 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -42,7 +42,7 @@ from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm -from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast, PreTrainedTokenizerFast +from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast, PreTrainedTokenizerFast, LlamaForCausalLM import diffusers from diffusers import ( @@ -146,7 +146,7 @@ def save_model_card( model_card = populate_model_card(model_card, tags=tags) model_card.save(os.path.join(repo_folder, "README.md")) -def load_text_encoders(class_one, class_two, class_three, class_four): +def load_text_encoders(class_one, class_two, class_three): text_encoder_one = class_one.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) @@ -156,9 +156,11 @@ def load_text_encoders(class_one, class_two, class_three, class_four): text_encoder_three = class_three.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant ) - text_encoder_four = class_four.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_4", revision=args.revision, variant=args.variant - ) + text_encoder_four = LlamaForCausalLM.from_pretrained( + "meta-llama/Meta-Llama-3.1-8B-Instruct", + output_hidden_states=True, + output_attentions=True, + torch_dtype=torch.bfloat16,) return text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four def log_validation( @@ -211,18 +213,14 @@ def import_model_class_from_model_name_or_path( pretrained_model_name_or_path, subfolder=subfolder, revision=revision ) model_class = text_encoder_config.architectures[0] - if model_class == "CLIPTextModelWithProjection": - from transformers import CLIPTextModel + if model_class == "CLIPTextModelWithProjection" or model_class == "CLIPTextModel": + from transformers import CLIPTextModelWithProjection - return CLIPTextModel + return CLIPTextModelWithProjection elif model_class == "T5EncoderModel": from transformers import T5EncoderModel return T5EncoderModel - elif model_class == "LlamaForCausalLM": - from transformers import LlamaForCausalLM - - return LlamaForCausalLM else: raise ValueError(f"{model_class} is not supported.") @@ -1184,8 +1182,7 @@ def main(args): ) tokenizer_four = PreTrainedTokenizerFast.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="tokenizer_4", + "meta-llama/Meta-Llama-3.1-8B-Instruct", revision=args.revision, ) @@ -1199,16 +1196,13 @@ def main(args): text_encoder_cls_three = import_model_class_from_model_name_or_path( args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3" ) - text_encoder_cls_four = import_model_class_from_model_name_or_path( - args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_4" - ) # Load scheduler and models noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision ) noise_scheduler_copy = copy.deepcopy(noise_scheduler) - text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three, text_encoder_cls_four) + text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three) vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, @@ -1740,6 +1734,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # create pipeline pipeline = HiDreamImagePipeline.from_pretrained( args.pretrained_model_name_or_path, + # tokenizer_4=tokenizer_4, + # text_encoder_4=text_encoder_4, transformer=accelerator.unwrap_model(transformer), revision=args.revision, variant=args.variant, @@ -1777,6 +1773,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Load previous pipeline pipeline = HiDreamImagePipeline.from_pretrained( args.pretrained_model_name_or_path, + # tokenizer_4=tokenizer_4, + # text_encoder_4=text_encoder_4, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, From 33385c997279edcf4a0ab2762be9fb0b16238aa2 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 14 Apr 2025 16:17:37 +0300 Subject: [PATCH 13/76] te --- examples/dreambooth/train_dreambooth_lora_hidream.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 08783cdcb05d..675b95b9cb27 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -895,7 +895,7 @@ def _encode_prompt_with_llama( if attention_mask is None: raise ValueError("text_input_ids must be provided when the tokenizer is not specified") - outputs = self.text_encoder_4( + outputs = text_encoder( text_input_ids.to(device), attention_mask=attention_mask.to(device), output_hidden_states=True, @@ -1185,6 +1185,7 @@ def main(args): "meta-llama/Meta-Llama-3.1-8B-Instruct", revision=args.revision, ) + tokenizer_four.pad_token = tokenizer_four.eos_token # import correct text encoder classes text_encoder_cls_one = import_model_class_from_model_name_or_path( From d993e161dc41a5ab05bf02fc5c88aedef21f7211 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 14 Apr 2025 17:09:25 +0300 Subject: [PATCH 14/76] prompt embeds --- examples/dreambooth/train_dreambooth_lora_hidream.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 675b95b9cb27..15c7e20f798d 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -907,8 +907,8 @@ def _encode_prompt_with_llama( else: dtype = text_encoder.dtype - prompt_embeds = outputs.hidden_states[1:].to(dtype=dtype, device=device) - prompt_embeds = torch.stack(prompt_embeds, dim=0) + prompt_embeds = outputs.hidden_states[1:] + prompt_embeds = torch.stack(prompt_embeds, dim=0).to(dtype=dtype, device=device) _, _, seq_len, dim = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method @@ -1060,6 +1060,8 @@ def encode_prompt( attention_mask=attention_mask_list[1] if attention_mask_list else None, ) + print("t5_prompt_embeds",t5_prompt_embeds.shape) + print("llama3_prompt_embeds",llama3_prompt_embeds.shape) prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds] return prompt_embeds, pooled_prompt_embeds @@ -1431,7 +1433,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): prompt_embeds, pooled_prompt_embeds = encode_prompt( text_encoders, tokenizers, prompt, args.max_sequence_length ) - prompt_embeds = prompt_embeds.to(accelerator.device) + prompt_embeds[0] = prompt_embeds[0].to(accelerator.device) + prompt_embeds[1] = prompt_embeds[1].to(accelerator.device) pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) return prompt_embeds, pooled_prompt_embeds @@ -1587,7 +1590,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if train_dataset.custom_instance_prompts: prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(prompts, text_encoders, tokenizers) else: - prompt_embeds = prompt_embeds.repeat(len(prompts), 1, 1) + prompt_embeds[0] = prompt_embeds[0].repeat(len(prompts), 1, 1) + prompt_embeds[1] = prompt_embeds[1].repeat(1, len(prompts), 1, 1) pooled_prompt_embeds = pooled_prompt_embeds.repeat(len(prompts), 1) # Convert images to latent space if args.cache_latents: From c296b6fe9f30ca9c653ad162dca4dce9b35c0a24 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 14 Apr 2025 17:11:10 +0300 Subject: [PATCH 15/76] fix naming --- examples/dreambooth/train_dreambooth_lora_hidream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 15c7e20f798d..0c0e48c50196 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1649,7 +1649,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): hidden_states=noisy_model_input, encoder_hidden_states=prompt_embeds, pooled_embeds=pooled_prompt_embeds, - timestep=timesteps, + timesteps=timesteps, img_sizes=img_sizes, img_ids=img_ids, return_dict=False, From aa6b6e28b864890f40501de9637a08f11c8704e9 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 14 Apr 2025 18:09:09 +0300 Subject: [PATCH 16/76] shapes --- examples/dreambooth/train_dreambooth_lora_hidream.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 0c0e48c50196..41e2e4938961 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1454,7 +1454,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Clear the memory here if not train_dataset.custom_instance_prompts: - del text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four, tokenizer_one, tokenizer_two,tokenizer_three, tokenizer_four + # delete tokenizers and text encoders except for llama (tokenizer & te four) + # as it's needed for inference with pipeline + del text_encoder_one, text_encoder_two, text_encoder_three, tokenizer_one, tokenizer_two,tokenizer_three + if not args.validation_prompt: + del tokenizer_four, text_encoder_four free_memory() # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), @@ -1739,8 +1743,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # create pipeline pipeline = HiDreamImagePipeline.from_pretrained( args.pretrained_model_name_or_path, - # tokenizer_4=tokenizer_4, - # text_encoder_4=text_encoder_4, + tokenizer_4=tokenizer_four, + text_encoder_4=text_encoder_four, transformer=accelerator.unwrap_model(transformer), revision=args.revision, variant=args.variant, From ecc1c18ee2af43752d5212d3bc50e636056511c7 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 14 Apr 2025 18:51:38 +0300 Subject: [PATCH 17/76] initial commit to add HiDreamImageLoraLoaderMixin --- src/diffusers/loaders/__init__.py | 1 + src/diffusers/loaders/lora_pipeline.py | 335 ++++++++++++++++++ .../hidream_image/pipeline_hidream_image.py | 3 +- 3 files changed, 338 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 3ba1bfacf3dd..2c3eefd4f42c 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -76,6 +76,7 @@ def text_encoder_attn_modules(text_encoder): "SanaLoraLoaderMixin", "Lumina2LoraLoaderMixin", "WanLoraLoaderMixin", + "HiDreamImageLoraLoaderMixin" ] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] _import_structure["ip_adapter"] = [ diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 2e241bc9ffad..6ed159ca6fe1 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -5395,6 +5395,341 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): """ super().unfuse_lora(components=components, **kwargs) +class HiDreamImageLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`HiDreamImageTransformer2DModel`]. Specific to [`HiDreamImagePipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + # conversion. + non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict) + if non_diffusers: + state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict) + + return state_dict + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`HiDreamImageTransformer2DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + """ + super().unfuse_lora(components=components, **kwargs) class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): def __init__(self, *args, **kwargs): diff --git a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py index e16dedb53674..fea9e64cdcef 100644 --- a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py +++ b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py @@ -13,6 +13,7 @@ ) from ...image_processor import VaeImageProcessor +from ...loaders import HiDreamImageLoraLoaderMixin from ...models import AutoencoderKL, HiDreamImageTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler from ...utils import is_torch_xla_available, logging @@ -151,7 +152,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class HiDreamImagePipeline(DiffusionPipeline): +class HiDreamImagePipeline(DiffusionPipeline, HiDreamImageLoraLoaderMixin): model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds"] From c439c89fdd7ea414034cb87996735f070315fc85 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 14 Apr 2025 18:52:20 +0300 Subject: [PATCH 18/76] fix init --- src/diffusers/loaders/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 2c3eefd4f42c..b579262282f1 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -117,6 +117,7 @@ def text_encoder_attn_modules(text_encoder): StableDiffusionLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, WanLoraLoaderMixin, + HiDreamImageLoraLoaderMixin, ) from .single_file import FromSingleFileMixin from .textual_inversion import TextualInversionLoaderMixin From 22e9ae807e3e3326d6e052920e364c1a660d5e7a Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 14 Apr 2025 19:56:04 +0300 Subject: [PATCH 19/76] add tests --- .../test_dreambooth_lora_hidream.py | 206 ++++++++++++++++++ 1 file changed, 206 insertions(+) create mode 100644 examples/dreambooth/test_dreambooth_lora_hidream.py diff --git a/examples/dreambooth/test_dreambooth_lora_hidream.py b/examples/dreambooth/test_dreambooth_lora_hidream.py new file mode 100644 index 000000000000..add81767d646 --- /dev/null +++ b/examples/dreambooth/test_dreambooth_lora_hidream.py @@ -0,0 +1,206 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import sys +import tempfile + +import safetensors + + +sys.path.append("..") +from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 + + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) + + +class DreamBoothLoRAHiDreamImage(ExamplesTestsAccelerate): + instance_data_dir = "docs/source/en/imgs" + pretrained_model_name_or_path = "hf-internal-testing/tiny-hidream-i1-pipe" + script_path = "examples/dreambooth/train_dreambooth_lora_hidream.py" + transformer_layer_type = "double_stream_blocks.0.block.attn1.to_k" + + def test_dreambooth_lora_hidream(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --resolution 32 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --max_sequence_length 16 + """.split() + + test_args.extend(["--instance_prompt", ""]) + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_latent_caching(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --resolution 32 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --cache_latents + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --max_sequence_length 16 + """.split() + + test_args.extend(["--instance_prompt", ""]) + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_layers(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --resolution 32 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --cache_latents + --learning_rate 5.0e-04 + --scale_lr + --lora_layers {self.transformer_layer_type} + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --max_sequence_length 16 + """.split() + + test_args.extend(["--instance_prompt", ""]) + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. In this test, we only params of + # `self.transformer_layer_type` should be in the state dict. + starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --resolution=32 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=6 + --checkpoints_total_limit=2 + --checkpointing_steps=2 + --max_sequence_length 16 + """.split() + + test_args.extend(["--instance_prompt", ""]) + run_command(self._launch_args + test_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --resolution=32 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=4 + --checkpointing_steps=2 + --max_sequence_length 166 + """.split() + + test_args.extend(["--instance_prompt", ""]) + run_command(self._launch_args + test_args) + + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"}) + + resume_run_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --resolution=32 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=8 + --checkpointing_steps=2 + --resume_from_checkpoint=checkpoint-4 + --checkpoints_total_limit=2 + --max_sequence_length 16 + """.split() + + resume_run_args.extend(["--instance_prompt", ""]) + run_command(self._launch_args + resume_run_args) + + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"}) From 3653dcc65188a779216e2fc2e4bc0109d09ef8ef Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 14 Apr 2025 20:48:43 +0300 Subject: [PATCH 20/76] loader --- docs/source/en/api/loaders/lora.md | 5 +++++ src/diffusers/loaders/peft.py | 1 + 2 files changed, 6 insertions(+) diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md index 58611a61c25d..5e4bc1969ff4 100644 --- a/docs/source/en/api/loaders/lora.md +++ b/docs/source/en/api/loaders/lora.md @@ -24,6 +24,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi - [`SanaLoraLoaderMixin`] provides similar functions for [Sana](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana). - [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video). - [`Lumina2LoraLoaderMixin`] provides similar functions for [Lumina2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2). +- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream) - [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`]. - [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more. @@ -73,6 +74,10 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse [[autodoc]] loaders.lora_pipeline.Lumina2LoraLoaderMixin +## HiDreamImageLoraLoaderMixin + +[[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin + ## AmusedLoraLoaderMixin [[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 9165c46f3c78..5b4e2ec63ef8 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -55,6 +55,7 @@ "Lumina2Transformer2DModel": lambda model_cls, weights: weights, "WanTransformer3DModel": lambda model_cls, weights: weights, "CogView4Transformer2DModel": lambda model_cls, weights: weights, + "HiDreamImageTransformer2DModel": lambda model_cls, weights: weights, } From fcf6eaa9b9912e45cb643b393f720836129d749e Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 15 Apr 2025 09:51:30 +0300 Subject: [PATCH 21/76] fix model input --- .../dreambooth/train_dreambooth_lora_hidream.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 41e2e4938961..745b1afc0fcb 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -48,6 +48,7 @@ from diffusers import ( AutoencoderKL, FlowMatchEulerDiscreteScheduler, + UniPCMultistepScheduler, HiDreamImagePipeline, HiDreamImageTransformer2DModel, ) @@ -1454,7 +1455,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Clear the memory here if not train_dataset.custom_instance_prompts: - # delete tokenizers and text encoders except for llama (tokenizer & te four) + # delete tokenizers and text ecnoders except for llama (tokenizer & te four) # as it's needed for inference with pipeline del text_encoder_one, text_encoder_two, text_encoder_three, tokenizer_one, tokenizer_two,tokenizer_three if not args.validation_prompt: @@ -1646,8 +1647,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Add noise according to flow matching. # zt = (1 - texp) * x + texp * z1 sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) - noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise - + noisy_model_input = (1.0 - sigmas) * noise + sigmas * model_input # Predict the noise residual model_pred = transformer( hidden_states=noisy_model_input, @@ -1780,10 +1780,17 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Final inference # Load previous pipeline + tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") + text_encoder_4 = LlamaForCausalLM.from_pretrained( + "meta-llama/Meta-Llama-3.1-8B-Instruct", + output_hidden_states=True, + output_attentions=True, + torch_dtype=torch.bfloat16, + ) pipeline = HiDreamImagePipeline.from_pretrained( args.pretrained_model_name_or_path, - # tokenizer_4=tokenizer_4, - # text_encoder_4=text_encoder_4, + tokenizer_4=tokenizer_4, + text_encoder_4=text_encoder_4, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, From 0fdc7dd49c17b88a8d62eb2ac97dad224c9ab501 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 15 Apr 2025 09:57:29 +0300 Subject: [PATCH 22/76] add code example to readme --- .../train_dreambooth_lora_hidream.py | 30 +++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 745b1afc0fcb..f164751037e0 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -108,7 +108,6 @@ def save_model_card( The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [HiDream Image diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_hidream.md). - ## Trigger words You should use `{instance_prompt}` to trigger the image generation. @@ -120,7 +119,34 @@ def save_model_card( ## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) ```py -TODO + >>> import torch + >>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM + >>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline + + >>> scheduler = UniPCMultistepScheduler( + ... flow_shift=3.0, prediction_type="flow_prediction", use_flow_sigmas=True + ... ) + + >>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") + >>> text_encoder_4 = LlamaForCausalLM.from_pretrained( + ... "meta-llama/Meta-Llama-3.1-8B-Instruct", + ... output_hidden_states=True, + ... output_attentions=True, + ... torch_dtype=torch.bfloat16, + ... ) + + >>> pipe = HiDreamImagePipeline.from_pretrained( + ... "HiDream-ai/HiDream-I1-Full", + ... scheduler=scheduler, + ... tokenizer_4=tokenizer_4, + ... text_encoder_4=text_encoder_4, + ... torch_dtype=torch.bfloat16, + ... ) + >>> pipe.enable_model_cpu_offload() + >>> pipe.load_lora_weights(f"{repo_id}") + >>> image = pipe(f"{instance_prompt}").images[0] + + ``` For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) From 62f2f15612750753c73e6f9a669459a109c5b791 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 15 Apr 2025 11:44:27 +0300 Subject: [PATCH 23/76] fix default max length of text encoders --- examples/dreambooth/train_dreambooth_lora_hidream.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index f164751037e0..16f8f5493343 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -343,8 +343,8 @@ def parse_args(input_args=None): parser.add_argument( "--max_sequence_length", type=int, - default=256, - help="Maximum sequence length to use with with the Gemma2 model", + default=128, + help="Maximum sequence length to use with t5 and llama encoders", ) parser.add_argument( @@ -1087,8 +1087,6 @@ def encode_prompt( attention_mask=attention_mask_list[1] if attention_mask_list else None, ) - print("t5_prompt_embeds",t5_prompt_embeds.shape) - print("llama3_prompt_embeds",llama3_prompt_embeds.shape) prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds] return prompt_embeds, pooled_prompt_embeds From 3cb1a4ce50afb607de7fce05bb1d7da3783caaea Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 15 Apr 2025 12:20:19 +0300 Subject: [PATCH 24/76] prints --- examples/dreambooth/train_dreambooth_lora_hidream.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 16f8f5493343..21e56c3e3702 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1673,6 +1673,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) noisy_model_input = (1.0 - sigmas) * noise + sigmas * model_input # Predict the noise residual + + print("noisy_model_input", noisy_model_input.shape) + print("prompt_embeds", prompt_embeds[0].shape, prompt_embeds[1].shape) + print("pooled_prompt_embeds", pooled_prompt_embeds.shape) + model_pred = transformer( hidden_states=noisy_model_input, encoder_hidden_states=prompt_embeds, @@ -1682,7 +1687,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): img_ids=img_ids, return_dict=False, )[0] - + print("model_pred", model_pred.shape) # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) From 82bcd4433ee73a81d99527ac1a7cdd0384154f2a Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 15 Apr 2025 19:37:39 +0300 Subject: [PATCH 25/76] nullify training cond in unpatchify for temp fix to incompatible shaping of transformer output during training --- src/diffusers/models/transformers/transformer_hidream_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py index 04622a7e04b2..b6022a5410f0 100644 --- a/src/diffusers/models/transformers/transformer_hidream_image.py +++ b/src/diffusers/models/transformers/transformer_hidream_image.py @@ -662,7 +662,7 @@ def __init__( self.gradient_checkpointing = False def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]: - if is_training: + if is_training and False: # temporary!!! B, S, F = x.shape C = F // (self.config.patch_size * self.config.patch_size) x = ( From 75aa8bdcec25882966fdedfe4b26747f6af5c6de Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 16 Apr 2025 12:02:37 +0300 Subject: [PATCH 26/76] smol fix --- examples/dreambooth/train_dreambooth_lora_hidream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 21e56c3e3702..924d8eaa91e0 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1643,7 +1643,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None] img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :] img_ids = img_ids.reshape(pH * pW, -1) - img_ids_pad = torch.zeros(self.transformer.max_seq, 3) + img_ids_pad = torch.zeros(transformer.max_seq, 3) img_ids_pad[: pH * pW, :] = img_ids img_sizes = img_sizes.unsqueeze(0).to(model_input.device) From 47e861fc23ba0d092426e8743cd4339522f9e6d9 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 16 Apr 2025 12:26:29 +0300 Subject: [PATCH 27/76] unpatchify --- .../transformers/transformer_hidream_image.py | 33 +++++++++++++++---- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py index b6022a5410f0..36f1b39df95f 100644 --- a/src/diffusers/models/transformers/transformer_hidream_image.py +++ b/src/diffusers/models/transformers/transformer_hidream_image.py @@ -662,13 +662,32 @@ def __init__( self.gradient_checkpointing = False def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]: - if is_training and False: # temporary!!! - B, S, F = x.shape - C = F // (self.config.patch_size * self.config.patch_size) - x = ( - x.reshape(B, S, self.config.patch_size, self.config.patch_size, C) - .permute(0, 4, 1, 2, 3) - .reshape(B, C, S, self.config.patch_size * self.config.patch_size) + if is_training: + # Assuming img_sizes contains [[pH, pW]] for each item in the batch. + # For simplicity in training, often all batches have the same size. + # We'll assume img_sizes[0] gives the target patch dimensions. + # If training with variable sizes, this needs more careful handling per item. + pH, pW = img_sizes[0] # Get target patch height/width + expected_S = pH * pW + # Ensure sequence length S matches expected H*W before rearranging + # This might require padding/truncating x if training uses fixed max_seq + current_S = x.shape[1] + if current_S > expected_S: + x = x[:, :expected_S, :] # Use only the relevant part of the sequence + elif current_S < expected_S: + # This case is less likely if padding happens earlier, but handle defensively + raise ValueError( + f"Sequence length {current_S} is less than expected {expected_S} ({pH}x{pW}) during unpatchify.") + + # Original incorrect line: + # x = einops.rearrange(x, 'B S (p1 p2 C) -> B C S (p1 p2)', p1=self.config.patch_size, p2=self.config.patch_size) + + # Corrected line using einops and H, W: + x = einops.rearrange( + x, + 'B (H W) (p1 p2 C) -> B C (H p1) (W p2)', + H=pH, W=pW, + p1=self.config.patch_size, p2=self.config.patch_size ) else: x_arr = [] From a461d3858d04111f11fadd49b5164ce7c81746bf Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 16 Apr 2025 15:29:37 +0300 Subject: [PATCH 28/76] unpatchify --- .../models/transformers/transformer_hidream_image.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py index 36f1b39df95f..e61d2f02b640 100644 --- a/src/diffusers/models/transformers/transformer_hidream_image.py +++ b/src/diffusers/models/transformers/transformer_hidream_image.py @@ -663,10 +663,6 @@ def __init__( def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]: if is_training: - # Assuming img_sizes contains [[pH, pW]] for each item in the batch. - # For simplicity in training, often all batches have the same size. - # We'll assume img_sizes[0] gives the target patch dimensions. - # If training with variable sizes, this needs more careful handling per item. pH, pW = img_sizes[0] # Get target patch height/width expected_S = pH * pW # Ensure sequence length S matches expected H*W before rearranging @@ -679,9 +675,6 @@ def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_train raise ValueError( f"Sequence length {current_S} is less than expected {expected_S} ({pH}x{pW}) during unpatchify.") - # Original incorrect line: - # x = einops.rearrange(x, 'B S (p1 p2 C) -> B C S (p1 p2)', p1=self.config.patch_size, p2=self.config.patch_size) - # Corrected line using einops and H, W: x = einops.rearrange( x, From b31b59563a2dc8a785697464ed4a5e195a31c9ef Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 16 Apr 2025 18:08:42 +0300 Subject: [PATCH 29/76] fix validation --- .../train_dreambooth_lora_hidream.py | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 924d8eaa91e0..07287b19f829 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -196,22 +196,35 @@ def log_validation( accelerator, pipeline_args, epoch, + torch_dtype, is_final_validation=False, ): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - - pipeline = pipeline.to(accelerator.device) + pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() - with autocast_ctx: - images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] + # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast + with torch.no_grad(): + prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds= pipeline.encode_prompt( + pipeline_args["prompt"], prompt_2=pipeline_args["prompt"], prompt_3=pipeline_args["prompt"], prompt_4=pipeline_args["prompt"] + ) + images = [] + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, generator=generator + ).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" @@ -222,7 +235,7 @@ def log_validation( tracker.log( { phase_name: [ - wandb.Image(image, caption=f"{i}: {pipeline_args['prompt']}") for i, image in enumerate(images) + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) ] } ) @@ -1481,7 +1494,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): if not train_dataset.custom_instance_prompts: # delete tokenizers and text ecnoders except for llama (tokenizer & te four) # as it's needed for inference with pipeline - del text_encoder_one, text_encoder_two, text_encoder_three, tokenizer_one, tokenizer_two,tokenizer_three + del text_encoder_one, text_encoder_two, text_encoder_three, tokenizer_one, tokenizer_two, tokenizer_three if not args.validation_prompt: del tokenizer_four, text_encoder_four free_memory() @@ -1643,7 +1656,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None] img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :] img_ids = img_ids.reshape(pH * pW, -1) - img_ids_pad = torch.zeros(transformer.max_seq, 3) + img_ids_pad = torch.zeros(self.transformer.max_seq, 3) img_ids_pad[: pH * pW, :] = img_ids img_sizes = img_sizes.unsqueeze(0).to(model_input.device) @@ -1785,6 +1798,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): args=args, accelerator=accelerator, pipeline_args=pipeline_args, + torch_dtype=weight_dtype, epoch=epoch, ) free_memory() @@ -1840,6 +1854,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): pipeline_args=pipeline_args, epoch=epoch, is_final_validation=True, + torch_dtype=weight_dtype, ) if args.push_to_hub: From 466c9c03d4af69981f69b596419ecbcc699b4d6e Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 17 Apr 2025 10:14:52 +0300 Subject: [PATCH 30/76] flip pred and loss --- .../train_dreambooth_lora_hidream.py | 174 +++++++++--------- 1 file changed, 91 insertions(+), 83 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 07287b19f829..615eac0f7eea 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -68,7 +68,6 @@ from diffusers.utils.import_utils import is_torch_npu_available from diffusers.utils.torch_utils import is_compiled_module - if is_wandb_available(): import wandb @@ -82,12 +81,12 @@ def save_model_card( - repo_id: str, - images=None, - base_model: str = None, - instance_prompt=None, - validation_prompt=None, - repo_folder=None, + repo_id: str, + images=None, + base_model: str = None, + instance_prompt=None, + validation_prompt=None, + repo_folder=None, ): widget_dict = [] if images is not None: @@ -122,11 +121,11 @@ def save_model_card( >>> import torch >>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM >>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline - + >>> scheduler = UniPCMultistepScheduler( ... flow_shift=3.0, prediction_type="flow_prediction", use_flow_sigmas=True ... ) - + >>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") >>> text_encoder_4 = LlamaForCausalLM.from_pretrained( ... "meta-llama/Meta-Llama-3.1-8B-Instruct", @@ -134,7 +133,7 @@ def save_model_card( ... output_attentions=True, ... torch_dtype=torch.bfloat16, ... ) - + >>> pipe = HiDreamImagePipeline.from_pretrained( ... "HiDream-ai/HiDream-I1-Full", ... scheduler=scheduler, @@ -145,8 +144,8 @@ def save_model_card( >>> pipe.enable_model_cpu_offload() >>> pipe.load_lora_weights(f"{repo_id}") >>> image = pipe(f"{instance_prompt}").images[0] - - + + ``` For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) @@ -173,6 +172,7 @@ def save_model_card( model_card = populate_model_card(model_card, tags=tags) model_card.save(os.path.join(repo_folder, "README.md")) + def load_text_encoders(class_one, class_two, class_three): text_encoder_one = class_one.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant @@ -184,20 +184,21 @@ def load_text_encoders(class_one, class_two, class_three): args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant ) text_encoder_four = LlamaForCausalLM.from_pretrained( - "meta-llama/Meta-Llama-3.1-8B-Instruct", + "meta-llama/Meta-Llama-3.1-8B-Instruct", output_hidden_states=True, output_attentions=True, - torch_dtype=torch.bfloat16,) + torch_dtype=torch.bfloat16, ) return text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four + def log_validation( - pipeline, - args, - accelerator, - pipeline_args, - epoch, - torch_dtype, - is_final_validation=False, + pipeline, + args, + accelerator, + pipeline_args, + epoch, + torch_dtype, + is_final_validation=False, ): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" @@ -212,8 +213,9 @@ def log_validation( # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast with torch.no_grad(): - prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds= pipeline.encode_prompt( - pipeline_args["prompt"], prompt_2=pipeline_args["prompt"], prompt_3=pipeline_args["prompt"], prompt_4=pipeline_args["prompt"] + prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipeline.encode_prompt( + pipeline_args["prompt"], prompt_2=pipeline_args["prompt"], prompt_3=pipeline_args["prompt"], + prompt_4=pipeline_args["prompt"] ) images = [] for _ in range(args.num_validation_images): @@ -246,8 +248,9 @@ def log_validation( return images + def import_model_class_from_model_name_or_path( - pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" ): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder=subfolder, revision=revision @@ -264,6 +267,7 @@ def import_model_class_from_model_name_or_path( else: raise ValueError(f"{model_class} is not supported.") + def parse_args(input_args=None): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( @@ -321,8 +325,8 @@ def parse_args(input_args=None): type=str, default="image", help="The column of the dataset containing the target image. By " - "default, the standard Image Dataset maps out 'file_name' " - "to 'image'.", + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", ) parser.add_argument( "--caption_column", @@ -570,7 +574,7 @@ def parse_args(input_args=None): type=float, default=None, help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " - "uses the value of square root of beta2. Ignored if optimizer is adamW", + "uses the value of square root of beta2. Ignored if optimizer is adamW", ) parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") @@ -601,7 +605,7 @@ def parse_args(input_args=None): type=bool, default=True, help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " - "Ignored if optimizer is adamW", + "Ignored if optimizer is adamW", ) parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") @@ -708,15 +712,15 @@ class DreamBoothDataset(Dataset): """ def __init__( - self, - instance_data_root, - instance_prompt, - class_prompt, - class_data_root=None, - class_num=None, - size=1024, - repeats=1, - center_crop=False, + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, ): self.size = size self.center_crop = center_crop @@ -903,15 +907,16 @@ def __getitem__(self, index): example["index"] = index return example + def _encode_prompt_with_llama( - text_encoder, - tokenizer, - max_sequence_length=128, - prompt=None, - num_images_per_prompt=1, - device=None, - text_input_ids=None, - attention_mask=None, + text_encoder, + tokenizer, + max_sequence_length=128, + prompt=None, + num_images_per_prompt=1, + device=None, + text_input_ids=None, + attention_mask=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) @@ -958,14 +963,14 @@ def _encode_prompt_with_llama( def _encode_prompt_with_t5( - text_encoder, - tokenizer, - max_sequence_length=128, - prompt=None, - num_images_per_prompt=1, - device=None, - text_input_ids=None, - attention_mask=None, + text_encoder, + tokenizer, + max_sequence_length=128, + prompt=None, + num_images_per_prompt=1, + device=None, + text_input_ids=None, + attention_mask=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) @@ -1004,13 +1009,13 @@ def _encode_prompt_with_t5( def _encode_prompt_with_clip( - text_encoder, - tokenizer, - prompt: str, - max_sequence_length=128, - device=None, - text_input_ids=None, - num_images_per_prompt: int = 1, + text_encoder, + tokenizer, + prompt: str, + max_sequence_length=128, + device=None, + text_input_ids=None, + num_images_per_prompt: int = 1, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) @@ -1047,14 +1052,14 @@ def _encode_prompt_with_clip( def encode_prompt( - text_encoders, - tokenizers, - prompt: str, - max_sequence_length, - device=None, - num_images_per_prompt: int = 1, - text_input_ids_list=None, - attention_mask_list=None, + text_encoders, + tokenizers, + prompt: str, + max_sequence_length, + device=None, + num_images_per_prompt: int = 1, + text_input_ids_list=None, + attention_mask_list=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt @@ -1104,6 +1109,7 @@ def encode_prompt( return prompt_embeds, pooled_prompt_embeds + def main(args): if args.report_to == "wandb" and args.hub_token is not None: raise ValueError( @@ -1181,7 +1187,7 @@ def main(args): pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): images = pipeline(example["prompt"]).images @@ -1243,7 +1249,9 @@ def main(args): args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision ) noise_scheduler_copy = copy.deepcopy(noise_scheduler) - text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three) + text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(text_encoder_cls_one, + text_encoder_cls_two, + text_encoder_cls_three) vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, @@ -1371,7 +1379,7 @@ def load_model_hook(models, input_dir): if args.scale_lr: args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Make sure the trainable params are in float32. @@ -1466,6 +1474,7 @@ def load_model_hook(models, input_dir): tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three, tokenizer_four] text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four] + def compute_text_embeddings(prompt, text_encoders, tokenizers): with torch.no_grad(): prompt_embeds, pooled_prompt_embeds = encode_prompt( @@ -1684,12 +1693,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Add noise according to flow matching. # zt = (1 - texp) * x + texp * z1 sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) - noisy_model_input = (1.0 - sigmas) * noise + sigmas * model_input + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise # Predict the noise residual - print("noisy_model_input", noisy_model_input.shape) - print("prompt_embeds", prompt_embeds[0].shape, prompt_embeds[1].shape) - print("pooled_prompt_embeds", pooled_prompt_embeds.shape) + # print("noisy_model_input", noisy_model_input.shape) + # print("prompt_embeds", prompt_embeds[0].shape, prompt_embeds[1].shape) + # print("pooled_prompt_embeds", pooled_prompt_embeds.shape) model_pred = transformer( hidden_states=noisy_model_input, @@ -1699,15 +1708,14 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): img_sizes=img_sizes, img_ids=img_ids, return_dict=False, - )[0] - print("model_pred", model_pred.shape) + )[0] * -1 + # print("model_pred", model_pred.shape) # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # flow matching loss via reverse ODE (HiDream uses an inverted flow field as compared to the more typical `noise - model_input` target) - target = model_input - noise - + target = noise - model_input if args.with_prior_preservation: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) @@ -1825,10 +1833,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Load previous pipeline tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") text_encoder_4 = LlamaForCausalLM.from_pretrained( - "meta-llama/Meta-Llama-3.1-8B-Instruct", - output_hidden_states=True, - output_attentions=True, - torch_dtype=torch.bfloat16, + "meta-llama/Meta-Llama-3.1-8B-Instruct", + output_hidden_states=True, + output_attentions=True, + torch_dtype=torch.bfloat16, ) pipeline = HiDreamImagePipeline.from_pretrained( args.pretrained_model_name_or_path, From 6043d9d19f852778757d9b4eafa706fc906f783e Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 17 Apr 2025 17:44:41 +0300 Subject: [PATCH 31/76] fix shift!!! --- examples/dreambooth/train_dreambooth_lora_hidream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 615eac0f7eea..a3c6d6d4a0b0 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1246,7 +1246,7 @@ def main(args): # Load scheduler and models noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( - args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision + args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision, shift=3.0 ) noise_scheduler_copy = copy.deepcopy(noise_scheduler) text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(text_encoder_cls_one, From 13e6f0da50d1d549ab2aeb24148832e53f21daf3 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 17 Apr 2025 18:10:10 +0300 Subject: [PATCH 32/76] revert unpatchify changes (for now) --- .../transformers/transformer_hidream_image.py | 24 +++++-------------- 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py index fc1f73de2583..30f6c3a34c9d 100644 --- a/src/diffusers/models/transformers/transformer_hidream_image.py +++ b/src/diffusers/models/transformers/transformer_hidream_image.py @@ -663,24 +663,12 @@ def __init__( def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]: if is_training: - pH, pW = img_sizes[0] # Get target patch height/width - expected_S = pH * pW - # Ensure sequence length S matches expected H*W before rearranging - # This might require padding/truncating x if training uses fixed max_seq - current_S = x.shape[1] - if current_S > expected_S: - x = x[:, :expected_S, :] # Use only the relevant part of the sequence - elif current_S < expected_S: - # This case is less likely if padding happens earlier, but handle defensively - raise ValueError( - f"Sequence length {current_S} is less than expected {expected_S} ({pH}x{pW}) during unpatchify.") - - # Corrected line using einops and H, W: - x = einops.rearrange( - x, - 'B (H W) (p1 p2 C) -> B C (H p1) (W p2)', - H=pH, W=pW, - p1=self.config.patch_size, p2=self.config.patch_size + B, S, F = x.shape + C = F // (self.config.patch_size * self.config.patch_size) + x = ( + x.reshape(B, S, self.config.patch_size, self.config.patch_size, C) + .permute(0, 4, 1, 2, 3) + .reshape(B, C, S, self.config.patch_size * self.config.patch_size) ) else: x_arr = [] From ae39434dbe293c0c1457f3b5ae3867253c079767 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 18 Apr 2025 13:48:52 +0300 Subject: [PATCH 33/76] smol fix --- examples/dreambooth/train_dreambooth_lora_hidream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index a3c6d6d4a0b0..a9adbce82c02 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1665,7 +1665,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None] img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :] img_ids = img_ids.reshape(pH * pW, -1) - img_ids_pad = torch.zeros(self.transformer.max_seq, 3) + img_ids_pad = torch.zeros(transformer.max_seq, 3) img_ids_pad[: pH * pW, :] = img_ids img_sizes = img_sizes.unsqueeze(0).to(model_input.device) From c8932ede14586ff2b7953d3352440f854b62c1cd Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 18 Apr 2025 10:50:44 +0000 Subject: [PATCH 34/76] Apply style fixes --- .../train_dreambooth_lora_hidream.py | 183 ++++++++++-------- src/diffusers/loaders/__init__.py | 4 +- src/diffusers/loaders/lora_pipeline.py | 2 + 3 files changed, 101 insertions(+), 88 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index a9adbce82c02..959f3f2176a4 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -42,13 +42,12 @@ from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm -from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast, PreTrainedTokenizerFast, LlamaForCausalLM +from transformers import CLIPTokenizer, LlamaForCausalLM, PretrainedConfig, PreTrainedTokenizerFast, T5TokenizerFast import diffusers from diffusers import ( AutoencoderKL, FlowMatchEulerDiscreteScheduler, - UniPCMultistepScheduler, HiDreamImagePipeline, HiDreamImageTransformer2DModel, ) @@ -68,6 +67,7 @@ from diffusers.utils.import_utils import is_torch_npu_available from diffusers.utils.torch_utils import is_compiled_module + if is_wandb_available(): import wandb @@ -81,12 +81,12 @@ def save_model_card( - repo_id: str, - images=None, - base_model: str = None, - instance_prompt=None, - validation_prompt=None, - repo_folder=None, + repo_id: str, + images=None, + base_model: str = None, + instance_prompt=None, + validation_prompt=None, + repo_folder=None, ): widget_dict = [] if images is not None: @@ -187,18 +187,19 @@ def load_text_encoders(class_one, class_two, class_three): "meta-llama/Meta-Llama-3.1-8B-Instruct", output_hidden_states=True, output_attentions=True, - torch_dtype=torch.bfloat16, ) + torch_dtype=torch.bfloat16, + ) return text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four def log_validation( - pipeline, - args, - accelerator, - pipeline_args, - epoch, - torch_dtype, - is_final_validation=False, + pipeline, + args, + accelerator, + pipeline_args, + epoch, + torch_dtype, + is_final_validation=False, ): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" @@ -213,9 +214,13 @@ def log_validation( # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast with torch.no_grad(): - prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipeline.encode_prompt( - pipeline_args["prompt"], prompt_2=pipeline_args["prompt"], prompt_3=pipeline_args["prompt"], - prompt_4=pipeline_args["prompt"] + prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = ( + pipeline.encode_prompt( + pipeline_args["prompt"], + prompt_2=pipeline_args["prompt"], + prompt_3=pipeline_args["prompt"], + prompt_4=pipeline_args["prompt"], + ) ) images = [] for _ in range(args.num_validation_images): @@ -224,7 +229,8 @@ def log_validation( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, generator=generator + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + generator=generator, ).images[0] images.append(image) @@ -250,7 +256,7 @@ def log_validation( def import_model_class_from_model_name_or_path( - pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" ): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder=subfolder, revision=revision @@ -325,8 +331,8 @@ def parse_args(input_args=None): type=str, default="image", help="The column of the dataset containing the target image. By " - "default, the standard Image Dataset maps out 'file_name' " - "to 'image'.", + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", ) parser.add_argument( "--caption_column", @@ -574,7 +580,7 @@ def parse_args(input_args=None): type=float, default=None, help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " - "uses the value of square root of beta2. Ignored if optimizer is adamW", + "uses the value of square root of beta2. Ignored if optimizer is adamW", ) parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") @@ -605,7 +611,7 @@ def parse_args(input_args=None): type=bool, default=True, help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " - "Ignored if optimizer is adamW", + "Ignored if optimizer is adamW", ) parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") @@ -712,15 +718,15 @@ class DreamBoothDataset(Dataset): """ def __init__( - self, - instance_data_root, - instance_prompt, - class_prompt, - class_data_root=None, - class_num=None, - size=1024, - repeats=1, - center_crop=False, + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, ): self.size = size self.center_crop = center_crop @@ -909,14 +915,14 @@ def __getitem__(self, index): def _encode_prompt_with_llama( - text_encoder, - tokenizer, - max_sequence_length=128, - prompt=None, - num_images_per_prompt=1, - device=None, - text_input_ids=None, - attention_mask=None, + text_encoder, + tokenizer, + max_sequence_length=128, + prompt=None, + num_images_per_prompt=1, + device=None, + text_input_ids=None, + attention_mask=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) @@ -963,14 +969,14 @@ def _encode_prompt_with_llama( def _encode_prompt_with_t5( - text_encoder, - tokenizer, - max_sequence_length=128, - prompt=None, - num_images_per_prompt=1, - device=None, - text_input_ids=None, - attention_mask=None, + text_encoder, + tokenizer, + max_sequence_length=128, + prompt=None, + num_images_per_prompt=1, + device=None, + text_input_ids=None, + attention_mask=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) @@ -1009,13 +1015,13 @@ def _encode_prompt_with_t5( def _encode_prompt_with_clip( - text_encoder, - tokenizer, - prompt: str, - max_sequence_length=128, - device=None, - text_input_ids=None, - num_images_per_prompt: int = 1, + text_encoder, + tokenizer, + prompt: str, + max_sequence_length=128, + device=None, + text_input_ids=None, + num_images_per_prompt: int = 1, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) @@ -1052,14 +1058,14 @@ def _encode_prompt_with_clip( def encode_prompt( - text_encoders, - tokenizers, - prompt: str, - max_sequence_length, - device=None, - num_images_per_prompt: int = 1, - text_input_ids_list=None, - attention_mask_list=None, + text_encoders, + tokenizers, + prompt: str, + max_sequence_length, + device=None, + num_images_per_prompt: int = 1, + text_input_ids_list=None, + attention_mask_list=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt @@ -1187,7 +1193,7 @@ def main(args): pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): images = pipeline(example["prompt"]).images @@ -1249,9 +1255,9 @@ def main(args): args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision, shift=3.0 ) noise_scheduler_copy = copy.deepcopy(noise_scheduler) - text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(text_encoder_cls_one, - text_encoder_cls_two, - text_encoder_cls_three) + text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders( + text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three + ) vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, @@ -1348,7 +1354,7 @@ def load_model_hook(models, input_dir): lora_state_dict = HiDreamImagePipeline.lora_state_dict(input_dir) transformer_state_dict = { - f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.") } transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") @@ -1379,7 +1385,7 @@ def load_model_hook(models, input_dir): if args.scale_lr: args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Make sure the trainable params are in float32. @@ -1489,15 +1495,17 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid # the redundant encoding. if not train_dataset.custom_instance_prompts: - instance_prompt_hidden_states, instance_pooled_prompt_embeds, = compute_text_embeddings( - args.instance_prompt, text_encoders, tokenizers - ) + ( + instance_prompt_hidden_states, + instance_pooled_prompt_embeds, + ) = compute_text_embeddings(args.instance_prompt, text_encoders, tokenizers) # Handle class prompt for prior-preservation. if args.with_prior_preservation: - class_prompt_hidden_states, class_pooled_prompt_embeds, = compute_text_embeddings( - args.class_prompt, text_encoders, tokenizers - ) + ( + class_prompt_hidden_states, + class_pooled_prompt_embeds, + ) = compute_text_embeddings(args.class_prompt, text_encoders, tokenizers) # Clear the memory here if not train_dataset.custom_instance_prompts: @@ -1700,15 +1708,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # print("prompt_embeds", prompt_embeds[0].shape, prompt_embeds[1].shape) # print("pooled_prompt_embeds", pooled_prompt_embeds.shape) - model_pred = transformer( - hidden_states=noisy_model_input, - encoder_hidden_states=prompt_embeds, - pooled_embeds=pooled_prompt_embeds, - timesteps=timesteps, - img_sizes=img_sizes, - img_ids=img_ids, - return_dict=False, - )[0] * -1 + model_pred = ( + transformer( + hidden_states=noisy_model_input, + encoder_hidden_states=prompt_embeds, + pooled_embeds=pooled_prompt_embeds, + timesteps=timesteps, + img_sizes=img_sizes, + img_ids=img_ids, + return_dict=False, + )[0] + * -1 + ) # print("model_pred", model_pred.shape) # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 7bb61dc8e477..84c6d9f32c66 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -77,7 +77,7 @@ def text_encoder_attn_modules(text_encoder): "SanaLoraLoaderMixin", "Lumina2LoraLoaderMixin", "WanLoraLoaderMixin", - "HiDreamImageLoraLoaderMixin" + "HiDreamImageLoraLoaderMixin", ] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] _import_structure["ip_adapter"] = [ @@ -109,6 +109,7 @@ def text_encoder_attn_modules(text_encoder): CogVideoXLoraLoaderMixin, CogView4LoraLoaderMixin, FluxLoraLoaderMixin, + HiDreamImageLoraLoaderMixin, HunyuanVideoLoraLoaderMixin, LoraLoaderMixin, LTXVideoLoraLoaderMixin, @@ -119,7 +120,6 @@ def text_encoder_attn_modules(text_encoder): StableDiffusionLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, WanLoraLoaderMixin, - HiDreamImageLoraLoaderMixin, ) from .single_file import FromSingleFileMixin from .textual_inversion import TextualInversionLoaderMixin diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 551e7f84422b..43a6a6d1ce46 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -5359,6 +5359,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): """ super().unfuse_lora(components=components, **kwargs) + class HiDreamImageLoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`HiDreamImageTransformer2DModel`]. Specific to [`HiDreamImagePipeline`]. @@ -5695,6 +5696,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): """ super().unfuse_lora(components=components, **kwargs) + class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): def __init__(self, *args, **kwargs): deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead." From c8ac7d5faaa91c96136f31f83ad7e5a27483ec99 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 18 Apr 2025 15:35:26 +0300 Subject: [PATCH 35/76] workaround moe training --- .../transformers/transformer_hidream_image.py | 26 +++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py index 30f6c3a34c9d..eb46b1dcc3b0 100644 --- a/src/diffusers/models/transformers/transformer_hidream_image.py +++ b/src/diffusers/models/transformers/transformer_hidream_image.py @@ -275,7 +275,7 @@ def __call__( # Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py class MoEGate(nn.Module): - def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01): + def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01, _force_inference_output=False): super().__init__() self.top_k = num_activated_experts self.n_routed_experts = num_routed_experts @@ -289,6 +289,9 @@ def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux self.gating_dim = embed_dim self.weight = nn.Parameter(torch.randn(self.n_routed_experts, self.gating_dim) / embed_dim**0.5) + self._force_inference_output = _force_inference_output + + def forward(self, hidden_states): bsz, seq_len, h = hidden_states.shape # print(bsz, seq_len, h) @@ -309,7 +312,7 @@ def forward(self, hidden_states): topk_weight = topk_weight / denominator ### expert-level computation auxiliary loss - if self.training and self.alpha > 0.0: + if self.training and self.alpha > 0.0 and not self._force_inference_output: scores_for_aux = scores aux_topk = self.top_k # always compute aux loss based on the naive greedy topk method @@ -341,14 +344,19 @@ def __init__( hidden_dim: int, num_routed_experts: int, num_activated_experts: int, + _force_inference_output: bool = False, ): super().__init__() self.shared_experts = HiDreamImageFeedForwardSwiGLU(dim, hidden_dim // 2) self.experts = nn.ModuleList( [HiDreamImageFeedForwardSwiGLU(dim, hidden_dim) for i in range(num_routed_experts)] ) + self._force_inference_output = _force_inference_output self.gate = MoEGate( - embed_dim=dim, num_routed_experts=num_routed_experts, num_activated_experts=num_activated_experts + embed_dim=dim, + num_routed_experts=num_routed_experts, + num_activated_experts=num_activated_experts, + _force_inference_output=_force_inference_output ) self.num_activated_experts = num_activated_experts @@ -359,7 +367,8 @@ def forward(self, x): topk_idx, topk_weight, aux_loss = self.gate(x) x = x.view(-1, x.shape[-1]) flat_topk_idx = topk_idx.view(-1) - if self.training: + print("forward 2: self.config._force_inference_output", self._force_inference_output) + if self.training and not self._force_inference_output: x = x.repeat_interleave(self.num_activated_experts, dim=0) y = torch.empty_like(x, dtype=wtype) for i, expert in enumerate(self.experts): @@ -413,6 +422,7 @@ def __init__( attention_head_dim: int, num_routed_experts: int = 4, num_activated_experts: int = 2, + _force_inference_output: bool = False, ): super().__init__() self.num_attention_heads = num_attention_heads @@ -436,6 +446,7 @@ def __init__( hidden_dim=4 * dim, num_routed_experts=num_routed_experts, num_activated_experts=num_activated_experts, + _force_inference_output=_force_inference_output ) else: self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim) @@ -480,6 +491,7 @@ def __init__( attention_head_dim: int, num_routed_experts: int = 4, num_activated_experts: int = 2, + _force_inference_output: bool = False, ): super().__init__() self.num_attention_heads = num_attention_heads @@ -504,6 +516,7 @@ def __init__( hidden_dim=4 * dim, num_routed_experts=num_routed_experts, num_activated_experts=num_activated_experts, + _force_inference_output=_force_inference_output ) else: self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim) @@ -606,6 +619,7 @@ def __init__( axes_dims_rope: Tuple[int, int] = (32, 32), max_resolution: Tuple[int, int] = (128, 128), llama_layers: List[int] = None, + force_inference_output: bool = False ): super().__init__() self.out_channels = out_channels or in_channels @@ -629,6 +643,7 @@ def __init__( attention_head_dim=attention_head_dim, num_routed_experts=num_routed_experts, num_activated_experts=num_activated_experts, + _force_inference_output=force_inference_output ) ) for _ in range(num_layers) @@ -644,6 +659,7 @@ def __init__( attention_head_dim=attention_head_dim, num_routed_experts=num_routed_experts, num_activated_experts=num_activated_experts, + _force_inference_output=force_inference_output ) ) for _ in range(num_single_layers) @@ -662,7 +678,7 @@ def __init__( self.gradient_checkpointing = False def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]: - if is_training: + if is_training and not self.config.force_inference_output: B, S, F = x.shape C = F // (self.config.patch_size * self.config.patch_size) x = ( From bf7ace6d864f25931d6311cbd9446560bdfabb31 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 18 Apr 2025 15:41:20 +0300 Subject: [PATCH 36/76] workaround moe training --- examples/dreambooth/train_dreambooth_lora_hidream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 959f3f2176a4..1535ddf75a47 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1266,7 +1266,7 @@ def main(args): variant=args.variant, ) transformer = HiDreamImageTransformer2DModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant, force_inference_output=True ) # We only train the additional adapter LoRA layers From 2fcb17ddf78cc48e4c9656a85f4d34dffc652e3a Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 18 Apr 2025 15:41:44 +0300 Subject: [PATCH 37/76] remove prints --- src/diffusers/models/transformers/transformer_hidream_image.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py index eb46b1dcc3b0..4ab6c733aaf5 100644 --- a/src/diffusers/models/transformers/transformer_hidream_image.py +++ b/src/diffusers/models/transformers/transformer_hidream_image.py @@ -294,7 +294,6 @@ def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux def forward(self, hidden_states): bsz, seq_len, h = hidden_states.shape - # print(bsz, seq_len, h) ### compute gating score hidden_states = hidden_states.view(-1, h) logits = F.linear(hidden_states, self.weight, None) @@ -367,7 +366,6 @@ def forward(self, x): topk_idx, topk_weight, aux_loss = self.gate(x) x = x.view(-1, x.shape[-1]) flat_topk_idx = topk_idx.view(-1) - print("forward 2: self.config._force_inference_output", self._force_inference_output) if self.training and not self._force_inference_output: x = x.repeat_interleave(self.num_activated_experts, dim=0) y = torch.empty_like(x, dtype=wtype) From 80f13beb48b1270a0580edcd01a8e8e4714917fd Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 18 Apr 2025 16:03:34 +0300 Subject: [PATCH 38/76] to reduce some memory, keep vae in `weight_dtype` same as we have for flux (as it's the same vae) https://github.com/huggingface/diffusers/blob/bbd0c161b55ba2234304f1e6325832dd69c60565/examples/dreambooth/train_dreambooth_lora_flux.py#L1207 --- examples/dreambooth/train_dreambooth_lora_hidream.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 1535ddf75a47..32071f2f6ffa 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1291,8 +1291,7 @@ def main(args): "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." ) - # keep VAE in FP32 to ensure numerical stability. - vae.to(dtype=torch.float32) + vae.to(dtype=weight_dtype) transformer.to(accelerator.device, dtype=weight_dtype) text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) From b8039c94ae1f7f109e7f00b9d7efe16dd1f6c5a2 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 18 Apr 2025 16:41:31 +0300 Subject: [PATCH 39/76] refactor to align with HiDream refactor --- .../train_dreambooth_lora_hidream.py | 54 +++++++++++-------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 32071f2f6ffa..0515ebf646f9 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -214,7 +214,13 @@ def log_validation( # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast with torch.no_grad(): - prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = ( + + ( prompt_embeds_t5, + negative_prompt_embeds_t5, + prompt_embeds_llama3, + negative_prompt_embeds_llama3, + pooled_prompt_embeds, + negative_pooled_prompt_embeds ) = ( pipeline.encode_prompt( pipeline_args["prompt"], prompt_2=pipeline_args["prompt"], @@ -226,8 +232,10 @@ def log_validation( for _ in range(args.num_validation_images): with autocast_ctx: image = pipeline( - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_t5=prompt_embeds_t5, + prompt_embeds_llama3=prompt_embeds_llama3, + negative_prompt_embeds_t5=negative_prompt_embeds_t5, + negative_prompt_embeds_llama3=negative_prompt_embeds_llama3, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, generator=generator, @@ -1111,9 +1119,8 @@ def encode_prompt( attention_mask=attention_mask_list[1] if attention_mask_list else None, ) - prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds] - return prompt_embeds, pooled_prompt_embeds + return t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds def main(args): @@ -1482,27 +1489,29 @@ def load_model_hook(models, input_dir): def compute_text_embeddings(prompt, text_encoders, tokenizers): with torch.no_grad(): - prompt_embeds, pooled_prompt_embeds = encode_prompt( + t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds = encode_prompt( text_encoders, tokenizers, prompt, args.max_sequence_length ) - prompt_embeds[0] = prompt_embeds[0].to(accelerator.device) - prompt_embeds[1] = prompt_embeds[1].to(accelerator.device) + t5_prompt_embeds = t5_prompt_embeds.to(accelerator.device) + llama3_prompt_embeds = llama3_prompt_embeds.to(accelerator.device) pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) - return prompt_embeds, pooled_prompt_embeds + return t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds # If no type of tuning is done on the text_encoder and custom instance prompts are NOT # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid # the redundant encoding. if not train_dataset.custom_instance_prompts: ( - instance_prompt_hidden_states, + instance_prompt_hidden_states_t5, + instance_prompt_hidden_states_llama3, instance_pooled_prompt_embeds, ) = compute_text_embeddings(args.instance_prompt, text_encoders, tokenizers) # Handle class prompt for prior-preservation. if args.with_prior_preservation: ( - class_prompt_hidden_states, + class_prompt_hidden_states_t5, + class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, ) = compute_text_embeddings(args.class_prompt, text_encoders, tokenizers) @@ -1519,10 +1528,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # pack the statically computed variables appropriately here. This is so that we don't # have to pass them to the dataloader. if not train_dataset.custom_instance_prompts: - prompt_embeds = instance_prompt_hidden_states + t5_prompt_embeds = instance_prompt_hidden_states_t5 + llama3_prompt_embeds = instance_prompt_hidden_states_llama3 pooled_prompt_embeds = instance_pooled_prompt_embeds if args.with_prior_preservation: - prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + t5_prompt_embeds = torch.cat([instance_prompt_hidden_states_t5, class_prompt_hidden_states_t5], dim=0) + llama3_prompt_embeds = torch.cat([instance_prompt_hidden_states_llama3, class_prompt_hidden_states_llama3], dim=0) pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) vae_config_scaling_factor = vae.config.scaling_factor @@ -1646,10 +1657,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): with accelerator.accumulate(models_to_accumulate): # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: - prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(prompts, text_encoders, tokenizers) + t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(prompts, text_encoders, tokenizers) else: - prompt_embeds[0] = prompt_embeds[0].repeat(len(prompts), 1, 1) - prompt_embeds[1] = prompt_embeds[1].repeat(1, len(prompts), 1, 1) + t5_prompt_embeds = t5_prompt_embeds.repeat(len(prompts), 1, 1) + llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, len(prompts), 1, 1) pooled_prompt_embeds = pooled_prompt_embeds.repeat(len(prompts), 1) # Convert images to latent space if args.cache_latents: @@ -1703,18 +1714,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise # Predict the noise residual - # print("noisy_model_input", noisy_model_input.shape) - # print("prompt_embeds", prompt_embeds[0].shape, prompt_embeds[1].shape) - # print("pooled_prompt_embeds", pooled_prompt_embeds.shape) - model_pred = ( transformer( hidden_states=noisy_model_input, - encoder_hidden_states=prompt_embeds, + encoder_hidden_states_t5=prompt_embeds_t5, + encoder_hidden_states_llama3=prompt_embeds_llama3, pooled_embeds=pooled_prompt_embeds, timesteps=timesteps, - img_sizes=img_sizes, - img_ids=img_ids, + # img_sizes=img_sizes, + # img_ids=img_ids, return_dict=False, )[0] * -1 From c331597b4b35b0f38fb462677acdbb9684b523d7 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 18 Apr 2025 16:45:06 +0300 Subject: [PATCH 40/76] refactor to align with HiDream refactor --- .../train_dreambooth_lora_hidream.py | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 0515ebf646f9..61d6533b8b7c 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1674,24 +1674,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor model_input = model_input.to(dtype=weight_dtype) - if model_input.shape[-2] != model_input.shape[-1]: - B, C, H, W = model_input.shape - pH, pW = H // transformer.config.patch_size, W // transformer.config.patch_size - - img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1) - img_ids = torch.zeros(pH, pW, 3) - img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None] - img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :] - img_ids = img_ids.reshape(pH * pW, -1) - img_ids_pad = torch.zeros(transformer.max_seq, 3) - img_ids_pad[: pH * pW, :] = img_ids - - img_sizes = img_sizes.unsqueeze(0).to(model_input.device) - img_ids = img_ids_pad.unsqueeze(0).to(model_input.device) - - else: - img_sizes = img_ids = None - # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) bsz = model_input.shape[0] @@ -1721,8 +1703,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): encoder_hidden_states_llama3=prompt_embeds_llama3, pooled_embeds=pooled_prompt_embeds, timesteps=timesteps, - # img_sizes=img_sizes, - # img_ids=img_ids, return_dict=False, )[0] * -1 From c32ccccc1a060061f67c6585f631957384361542 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 18 Apr 2025 16:49:00 +0300 Subject: [PATCH 41/76] refactor to align with HiDream refactor --- examples/dreambooth/train_dreambooth_lora_hidream.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 61d6533b8b7c..55a08221cf75 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1699,15 +1699,14 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_pred = ( transformer( hidden_states=noisy_model_input, - encoder_hidden_states_t5=prompt_embeds_t5, - encoder_hidden_states_llama3=prompt_embeds_llama3, + encoder_hidden_states_t5=t5_prompt_embeds, + encoder_hidden_states_llama3=llama3_prompt_embeds, pooled_embeds=pooled_prompt_embeds, timesteps=timesteps, return_dict=False, )[0] * -1 ) - # print("model_pred", model_pred.shape) # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) From 6e070b8bc7c4db476f43b7e61cfa5193c0c87703 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 18 Apr 2025 17:33:45 +0300 Subject: [PATCH 42/76] add support for cpu offloading of text encoders --- .../train_dreambooth_lora_hidream.py | 241 ++---------------- 1 file changed, 26 insertions(+), 215 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 55a08221cf75..c2ad06730a0d 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -42,7 +42,7 @@ from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm -from transformers import CLIPTokenizer, LlamaForCausalLM, PretrainedConfig, PreTrainedTokenizerFast, T5TokenizerFast +from transformers import CLIPTokenizer, LlamaForCausalLM, PretrainedConfig, PreTrainedTokenizerFast, T5Tokenizer import diffusers from diffusers import ( @@ -922,207 +922,6 @@ def __getitem__(self, index): return example -def _encode_prompt_with_llama( - text_encoder, - tokenizer, - max_sequence_length=128, - prompt=None, - num_images_per_prompt=1, - device=None, - text_input_ids=None, - attention_mask=None, -): - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if tokenizer is not None: - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=min(max_sequence_length, tokenizer.model_max_length), - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - attention_mask = text_inputs.attention_mask - - else: - if text_input_ids is None: - raise ValueError("text_input_ids must be provided when the tokenizer is not specified") - if attention_mask is None: - raise ValueError("text_input_ids must be provided when the tokenizer is not specified") - - outputs = text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask.to(device), - output_hidden_states=True, - output_attentions=True, - ) - - if hasattr(text_encoder, "module"): - dtype = text_encoder.module.dtype - else: - dtype = text_encoder.dtype - - prompt_embeds = outputs.hidden_states[1:] - prompt_embeds = torch.stack(prompt_embeds, dim=0).to(dtype=dtype, device=device) - _, _, seq_len, dim = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, 1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim) - return prompt_embeds - - -def _encode_prompt_with_t5( - text_encoder, - tokenizer, - max_sequence_length=128, - prompt=None, - num_images_per_prompt=1, - device=None, - text_input_ids=None, - attention_mask=None, -): - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if tokenizer is not None: - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=min(max_sequence_length, tokenizer.model_max_length), - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - attention_mask = text_inputs.attention_mask - else: - if text_input_ids is None: - raise ValueError("text_input_ids must be provided when the tokenizer is not specified") - if attention_mask is None: - raise ValueError("text_input_ids must be provided when the tokenizer is not specified") - - prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0] - - if hasattr(text_encoder, "module"): - dtype = text_encoder.module.dtype - else: - dtype = text_encoder.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - -def _encode_prompt_with_clip( - text_encoder, - tokenizer, - prompt: str, - max_sequence_length=128, - device=None, - text_input_ids=None, - num_images_per_prompt: int = 1, -): - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if tokenizer is not None: - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=min(max_sequence_length, 218), - truncation=True, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - else: - if text_input_ids is None: - raise ValueError("text_input_ids must be provided when the tokenizer is not specified") - - prompt_embeds = text_encoder(text_input_ids.to(device)) - - if hasattr(text_encoder, "module"): - dtype = text_encoder.module.dtype - else: - dtype = text_encoder.dtype - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - - -def encode_prompt( - text_encoders, - tokenizers, - prompt: str, - max_sequence_length, - device=None, - num_images_per_prompt: int = 1, - text_input_ids_list=None, - attention_mask_list=None, -): - prompt = [prompt] if isinstance(prompt, str) else prompt - - pooled_prompt_embeds_1 = _encode_prompt_with_clip( - text_encoder=text_encoders[0], - tokenizer=tokenizers[0], - prompt=prompt, - device=device if device is not None else text_encoders[0].device, - num_images_per_prompt=num_images_per_prompt, - text_input_ids=text_input_ids_list[0] if text_input_ids_list else None, - ) - - pooled_prompt_embeds_2 = _encode_prompt_with_clip( - text_encoder=text_encoders[1], - tokenizer=tokenizers[1], - prompt=prompt, - device=device if device is not None else text_encoders[1].device, - num_images_per_prompt=num_images_per_prompt, - text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, - ) - - pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1) - - t5_prompt_embeds = _encode_prompt_with_t5( - text_encoder=text_encoders[2], - tokenizer=tokenizers[2], - max_sequence_length=max_sequence_length, - prompt=prompt, - num_images_per_prompt=num_images_per_prompt, - device=device if device is not None else text_encoders[2].device, - text_input_ids=text_input_ids_list[2] if text_input_ids_list else None, - attention_mask=attention_mask_list[0] if attention_mask_list else None, - ) - - llama3_prompt_embeds = _encode_prompt_with_llama( - text_encoder=text_encoders[3], - tokenizer=tokenizers[3], - max_sequence_length=max_sequence_length, - prompt=prompt, - num_images_per_prompt=num_images_per_prompt, - device=device if device is not None else text_encoders[3].device, - text_input_ids=text_input_ids_list[3] if text_input_ids_list else None, - attention_mask=attention_mask_list[1] if attention_mask_list else None, - ) - - - return t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds - - def main(args): if args.report_to == "wandb" and args.hub_token is not None: raise ValueError( @@ -1234,7 +1033,7 @@ def main(args): subfolder="tokenizer_2", revision=args.revision, ) - tokenizer_three = T5TokenizerFast.from_pretrained( + tokenizer_three = T5Tokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer_3", revision=args.revision, @@ -1305,6 +1104,21 @@ def main(args): text_encoder_three.to(accelerator.device, dtype=weight_dtype) text_encoder_four.to(accelerator.device, dtype=weight_dtype) + # Initialize a text encoding pipeline and keep it to CPU for now. + text_encoding_pipeline = HiDreamImagePipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=None, + transformer=None, + text_encoder=text_encoder_one, + tokenizer=tokenizer_one, + text_encoder_2=text_encoder_two, + tokenizer_2= tokenizer_two, + text_encoder_3= text_encoder_three, + tokenizer_3= tokenizer_three, + text_encoder_4= text_encoder_four, + tokenizer_4= tokenizer_four, + ) + if args.gradient_checkpointing: transformer.enable_gradient_checkpointing() @@ -1484,17 +1298,14 @@ def load_model_hook(models, input_dir): num_workers=args.dataloader_num_workers, ) - tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three, tokenizer_four] - text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four] - - def compute_text_embeddings(prompt, text_encoders, tokenizers): + def compute_text_embeddings(prompt, text_encoding_pipeline): + text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) with torch.no_grad(): - t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds = encode_prompt( - text_encoders, tokenizers, prompt, args.max_sequence_length + t5_prompt_embeds,_, llama3_prompt_embeds,_, pooled_prompt_embeds,_ = text_encoding_pipeline.encode_prompt( + prompt=prompt, max_sequence_length=args.max_sequence_length ) - t5_prompt_embeds = t5_prompt_embeds.to(accelerator.device) - llama3_prompt_embeds = llama3_prompt_embeds.to(accelerator.device) - pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) + if args.offload: + text_encoding_pipeline = text_encoding_pipeline.to("cpu") return t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds # If no type of tuning is done on the text_encoder and custom instance prompts are NOT @@ -1505,7 +1316,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): instance_prompt_hidden_states_t5, instance_prompt_hidden_states_llama3, instance_pooled_prompt_embeds, - ) = compute_text_embeddings(args.instance_prompt, text_encoders, tokenizers) + ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline) # Handle class prompt for prior-preservation. if args.with_prior_preservation: @@ -1513,7 +1324,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, - ) = compute_text_embeddings(args.class_prompt, text_encoders, tokenizers) + ) = compute_text_embeddings(args.class_prompt, text_encoding_pipeline) # Clear the memory here if not train_dataset.custom_instance_prompts: @@ -1657,7 +1468,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): with accelerator.accumulate(models_to_accumulate): # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: - t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(prompts, text_encoders, tokenizers) + t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(prompts, text_encoding_pipeline) else: t5_prompt_embeds = t5_prompt_embeds.repeat(len(prompts), 1, 1) llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, len(prompts), 1, 1) From d77e42a1678bf082d9846dfa7986beb7e2ea689c Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 18 Apr 2025 14:36:07 +0000 Subject: [PATCH 43/76] Apply style fixes --- .../train_dreambooth_lora_hidream.py | 53 +++++++++++-------- .../transformers/transformer_hidream_image.py | 22 +++++--- 2 files changed, 44 insertions(+), 31 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index c2ad06730a0d..838389d8ea32 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -214,19 +214,18 @@ def log_validation( # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast with torch.no_grad(): - - ( prompt_embeds_t5, - negative_prompt_embeds_t5, - prompt_embeds_llama3, - negative_prompt_embeds_llama3, - pooled_prompt_embeds, - negative_pooled_prompt_embeds ) = ( - pipeline.encode_prompt( - pipeline_args["prompt"], - prompt_2=pipeline_args["prompt"], - prompt_3=pipeline_args["prompt"], - prompt_4=pipeline_args["prompt"], - ) + ( + prompt_embeds_t5, + negative_prompt_embeds_t5, + prompt_embeds_llama3, + negative_prompt_embeds_llama3, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = pipeline.encode_prompt( + pipeline_args["prompt"], + prompt_2=pipeline_args["prompt"], + prompt_3=pipeline_args["prompt"], + prompt_4=pipeline_args["prompt"], ) images = [] for _ in range(args.num_validation_images): @@ -1072,7 +1071,11 @@ def main(args): variant=args.variant, ) transformer = HiDreamImageTransformer2DModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant, force_inference_output=True + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + force_inference_output=True, ) # We only train the additional adapter LoRA layers @@ -1112,11 +1115,11 @@ def main(args): text_encoder=text_encoder_one, tokenizer=tokenizer_one, text_encoder_2=text_encoder_two, - tokenizer_2= tokenizer_two, - text_encoder_3= text_encoder_three, - tokenizer_3= tokenizer_three, - text_encoder_4= text_encoder_four, - tokenizer_4= tokenizer_four, + tokenizer_2=tokenizer_two, + text_encoder_3=text_encoder_three, + tokenizer_3=tokenizer_three, + text_encoder_4=text_encoder_four, + tokenizer_4=tokenizer_four, ) if args.gradient_checkpointing: @@ -1301,8 +1304,8 @@ def load_model_hook(models, input_dir): def compute_text_embeddings(prompt, text_encoding_pipeline): text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) with torch.no_grad(): - t5_prompt_embeds,_, llama3_prompt_embeds,_, pooled_prompt_embeds,_ = text_encoding_pipeline.encode_prompt( - prompt=prompt, max_sequence_length=args.max_sequence_length + t5_prompt_embeds, _, llama3_prompt_embeds, _, pooled_prompt_embeds, _ = ( + text_encoding_pipeline.encode_prompt(prompt=prompt, max_sequence_length=args.max_sequence_length) ) if args.offload: text_encoding_pipeline = text_encoding_pipeline.to("cpu") @@ -1344,7 +1347,9 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): pooled_prompt_embeds = instance_pooled_prompt_embeds if args.with_prior_preservation: t5_prompt_embeds = torch.cat([instance_prompt_hidden_states_t5, class_prompt_hidden_states_t5], dim=0) - llama3_prompt_embeds = torch.cat([instance_prompt_hidden_states_llama3, class_prompt_hidden_states_llama3], dim=0) + llama3_prompt_embeds = torch.cat( + [instance_prompt_hidden_states_llama3, class_prompt_hidden_states_llama3], dim=0 + ) pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) vae_config_scaling_factor = vae.config.scaling_factor @@ -1468,7 +1473,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): with accelerator.accumulate(models_to_accumulate): # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: - t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(prompts, text_encoding_pipeline) + t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds = compute_text_embeddings( + prompts, text_encoding_pipeline + ) else: t5_prompt_embeds = t5_prompt_embeds.repeat(len(prompts), 1, 1) llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, len(prompts), 1, 1) diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py index 4ab6c733aaf5..bb3f93b7b95c 100644 --- a/src/diffusers/models/transformers/transformer_hidream_image.py +++ b/src/diffusers/models/transformers/transformer_hidream_image.py @@ -275,7 +275,14 @@ def __call__( # Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py class MoEGate(nn.Module): - def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01, _force_inference_output=False): + def __init__( + self, + embed_dim, + num_routed_experts=4, + num_activated_experts=2, + aux_loss_alpha=0.01, + _force_inference_output=False, + ): super().__init__() self.top_k = num_activated_experts self.n_routed_experts = num_routed_experts @@ -291,7 +298,6 @@ def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux self._force_inference_output = _force_inference_output - def forward(self, hidden_states): bsz, seq_len, h = hidden_states.shape ### compute gating score @@ -355,7 +361,7 @@ def __init__( embed_dim=dim, num_routed_experts=num_routed_experts, num_activated_experts=num_activated_experts, - _force_inference_output=_force_inference_output + _force_inference_output=_force_inference_output, ) self.num_activated_experts = num_activated_experts @@ -444,7 +450,7 @@ def __init__( hidden_dim=4 * dim, num_routed_experts=num_routed_experts, num_activated_experts=num_activated_experts, - _force_inference_output=_force_inference_output + _force_inference_output=_force_inference_output, ) else: self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim) @@ -514,7 +520,7 @@ def __init__( hidden_dim=4 * dim, num_routed_experts=num_routed_experts, num_activated_experts=num_activated_experts, - _force_inference_output=_force_inference_output + _force_inference_output=_force_inference_output, ) else: self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim) @@ -617,7 +623,7 @@ def __init__( axes_dims_rope: Tuple[int, int] = (32, 32), max_resolution: Tuple[int, int] = (128, 128), llama_layers: List[int] = None, - force_inference_output: bool = False + force_inference_output: bool = False, ): super().__init__() self.out_channels = out_channels or in_channels @@ -641,7 +647,7 @@ def __init__( attention_head_dim=attention_head_dim, num_routed_experts=num_routed_experts, num_activated_experts=num_activated_experts, - _force_inference_output=force_inference_output + _force_inference_output=force_inference_output, ) ) for _ in range(num_layers) @@ -657,7 +663,7 @@ def __init__( attention_head_dim=attention_head_dim, num_routed_experts=num_routed_experts, num_activated_experts=num_activated_experts, - _force_inference_output=force_inference_output + _force_inference_output=force_inference_output, ) ) for _ in range(num_single_layers) From 5c8c33903fb259c2ef99bd82abf97d0c63beed54 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 18 Apr 2025 17:39:38 +0300 Subject: [PATCH 44/76] adjust lr and rank for train example --- examples/dreambooth/README_hidream.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/README_hidream.md b/examples/dreambooth/README_hidream.md index 82359b4785d6..e1ea302e521e 100644 --- a/examples/dreambooth/README_hidream.md +++ b/examples/dreambooth/README_hidream.md @@ -87,7 +87,8 @@ accelerate launch train_dreambooth_lora_hidream.py \ --train_batch_size=1 \ --gradient_accumulation_steps=4 \ --use_8bit_adam \ - --learning_rate=1e-4 \ + --rank=16 \ + --learning_rate=2e-4 \ --report_to="wandb" \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ From abfb389f2c31a04ceac8798bfd59507b9b868d66 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 18 Apr 2025 17:50:19 +0300 Subject: [PATCH 45/76] fix copies --- src/diffusers/loaders/lora_pipeline.py | 44 ++++++++++---------------- 1 file changed, 16 insertions(+), 28 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 43a6a6d1ce46..025b9b5987f2 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -5472,8 +5472,12 @@ def lora_state_dict( # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs - ): + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See @@ -5490,6 +5494,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -5515,17 +5521,20 @@ def load_lora_weights( self.load_lora_into_transformer( state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + transformer=getattr(self, self.transformer_name) if not hasattr(self, + "transformer") else self.transformer, adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False - ): + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, + hotswap: bool = False + ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -5542,29 +5551,8 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( From a5fe6be763524f826052fc65bd26326891e501fb Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 18 Apr 2025 14:54:41 +0000 Subject: [PATCH 46/76] Apply style fixes --- src/diffusers/loaders/lora_pipeline.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 025b9b5987f2..7153e459810a 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -5473,11 +5473,11 @@ def lora_state_dict( # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, - ): + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See @@ -5521,8 +5521,7 @@ def load_lora_weights( self.load_lora_into_transformer( state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, - "transformer") else self.transformer, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, @@ -5532,9 +5531,8 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, - hotswap: bool = False - ): + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. From 2798d402a19f40f01743c69be222d5dff1505d52 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Sat, 19 Apr 2025 09:04:58 +0300 Subject: [PATCH 47/76] update README --- examples/dreambooth/README_hidream.md | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/examples/dreambooth/README_hidream.md b/examples/dreambooth/README_hidream.md index e1ea302e521e..1e147bcad120 100644 --- a/examples/dreambooth/README_hidream.md +++ b/examples/dreambooth/README_hidream.md @@ -71,6 +71,10 @@ snapshot_download( This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform. Now, we can launch training using: +> [!NOTE] +> The following training configuration prioritizes lower memory consumption by using gradient checkpointing, 8-bit Adam, latent caching and no validation. +> Additionally, when provided with 'instance_prompt' only and no 'caption_column' (used for custom prompts for each image) +> text embeddings are pre-computed to save memory. ```bash export MODEL_NAME="HiDream-ai/HiDream-I1-Dev" @@ -92,8 +96,9 @@ accelerate launch train_dreambooth_lora_hidream.py \ --report_to="wandb" \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ - --max_train_steps=500 \ - --validation_prompt="A photo of sks dog in a bucket" \ + --max_train_steps=1000 \ + --cache_latents \ + --gradient_checkpointing \ --validation_epochs=25 \ --seed="0" \ --push_to_hub @@ -115,9 +120,7 @@ To better track our training experiments, we're using the following flags in the Additionally, we welcome you to explore the following CLI arguments: * `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only. -* `--system_prompt`: A custom system prompt to provide additional personality to the model. -* `--max_sequence_length`: Maximum sequence length to use for text embeddings. - +* `--rank`: The rank of the LoRA layers. The higher the rank, the more parameters are trained. The default is 16. We provide several options for optimizing memory optimization: From ab960c2aadbfc0c174f4d6c9736eceded7216ee0 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Sat, 19 Apr 2025 09:08:28 +0300 Subject: [PATCH 48/76] update README --- examples/dreambooth/README_hidream.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/README_hidream.md b/examples/dreambooth/README_hidream.md index 1e147bcad120..f9315d52bc09 100644 --- a/examples/dreambooth/README_hidream.md +++ b/examples/dreambooth/README_hidream.md @@ -72,7 +72,8 @@ This will also allow us to push the trained LoRA parameters to the Hugging Face Now, we can launch training using: > [!NOTE] -> The following training configuration prioritizes lower memory consumption by using gradient checkpointing, 8-bit Adam, latent caching and no validation. +> The following training configuration prioritizes lower memory consumption by using gradient checkpointing, +> 8-bit Adam optimizer, latent caching, offloading, no validation. > Additionally, when provided with 'instance_prompt' only and no 'caption_column' (used for custom prompts for each image) > text embeddings are pre-computed to save memory. @@ -127,5 +128,5 @@ We provide several options for optimizing memory optimization: * `--offload`: When enabled, we will offload the text encoder and VAE to CPU, when they are not used. * `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done. * `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library. - +* `--instance_prompt` and no `--caption_column`: when only an instance prompt is provided, we will pre-compute the text embeddings and remove the text encoders from memory once done. Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/) of the `HiDreamImagePipeline` to know more about the model. From 52d94213534d8429b1bd5ed217c1db13a0cb6ecc Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Sat, 19 Apr 2025 09:08:37 +0300 Subject: [PATCH 49/76] update README --- examples/dreambooth/README_hidream.md | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/dreambooth/README_hidream.md b/examples/dreambooth/README_hidream.md index f9315d52bc09..a60cf7aa238b 100644 --- a/examples/dreambooth/README_hidream.md +++ b/examples/dreambooth/README_hidream.md @@ -129,4 +129,5 @@ We provide several options for optimizing memory optimization: * `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done. * `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library. * `--instance_prompt` and no `--caption_column`: when only an instance prompt is provided, we will pre-compute the text embeddings and remove the text encoders from memory once done. + Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/) of the `HiDreamImagePipeline` to know more about the model. From fc5eb48fa8af784a2cb3d12ea1f1fdf052226008 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Sat, 19 Apr 2025 09:09:49 +0300 Subject: [PATCH 50/76] fix license --- examples/dreambooth/train_dreambooth_lora_hidream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 838389d8ea32..4a1729dda054 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -153,7 +153,7 @@ def save_model_card( model_card = load_or_create_model_card( repo_id_or_path=repo_id, from_training=True, - license="apache-2.0", + license="mit", base_model=base_model, prompt=instance_prompt, model_description=model_description, From a012914833cd21d2bfad58728cbd2ef9aa2ae05d Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Sat, 19 Apr 2025 09:11:30 +0300 Subject: [PATCH 51/76] keep prompt2,3,4 as None in validation --- examples/dreambooth/train_dreambooth_lora_hidream.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 4a1729dda054..e0275b8370a5 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -223,9 +223,6 @@ def log_validation( negative_pooled_prompt_embeds, ) = pipeline.encode_prompt( pipeline_args["prompt"], - prompt_2=pipeline_args["prompt"], - prompt_3=pipeline_args["prompt"], - prompt_4=pipeline_args["prompt"], ) images = [] for _ in range(args.num_validation_images): From d5b9eccf7e6d17bd077f2b00824086e87287012e Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Sat, 19 Apr 2025 09:12:06 +0300 Subject: [PATCH 52/76] remove reverse ode comment --- examples/dreambooth/train_dreambooth_lora_hidream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index e0275b8370a5..76349beff214 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1526,7 +1526,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # and instead post-weight the loss weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) - # flow matching loss via reverse ODE (HiDream uses an inverted flow field as compared to the more typical `noise - model_input` target) + target = noise - model_input if args.with_prior_preservation: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. From f04a13ac65723c0455108e4c1257236d38bd575b Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Sat, 19 Apr 2025 09:13:15 +0300 Subject: [PATCH 53/76] Update examples/dreambooth/train_dreambooth_lora_hidream.py Co-authored-by: Sayak Paul --- .../train_dreambooth_lora_hidream.py | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 76349beff214..a667cb895142 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1510,18 +1510,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise # Predict the noise residual - - model_pred = ( - transformer( - hidden_states=noisy_model_input, - encoder_hidden_states_t5=t5_prompt_embeds, - encoder_hidden_states_llama3=llama3_prompt_embeds, - pooled_embeds=pooled_prompt_embeds, - timesteps=timesteps, - return_dict=False, - )[0] - * -1 - ) + model_pred = transformer( + hidden_states=noisy_model_input, + encoder_hidden_states_t5=t5_prompt_embeds, + encoder_hidden_states_llama3=llama3_prompt_embeds, + pooled_embeds=pooled_prompt_embeds, + timesteps=timesteps, + return_dict=False, + )[0] + model_pred = model_pred * -1 # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) From 0b75081cac66886ddad2e33b4ecf7bbde14d12f7 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Sat, 19 Apr 2025 09:20:00 +0300 Subject: [PATCH 54/76] Update examples/dreambooth/train_dreambooth_lora_hidream.py Co-authored-by: Sayak Paul --- examples/dreambooth/train_dreambooth_lora_hidream.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index a667cb895142..5bd244f824b0 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1097,7 +1097,10 @@ def main(args): "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." ) - vae.to(dtype=weight_dtype) + if not args.offload: + vae.to(dtype=weight_dtype, device=accelerator.device) + else: + vae.to(dtype=weight_dtype) transformer.to(accelerator.device, dtype=weight_dtype) text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) From 4db988f9e39ce596a6a1d5b7fbf80c8f459a8894 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Sat, 19 Apr 2025 09:24:54 +0300 Subject: [PATCH 55/76] vae offload change --- examples/dreambooth/train_dreambooth_lora_hidream.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 5bd244f824b0..98a871f24a42 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1098,7 +1098,7 @@ def main(args): ) if not args.offload: - vae.to(dtype=weight_dtype, device=accelerator.device) + vae.to(dtype=weight_dtype, device=accelerator.device) else: vae.to(dtype=weight_dtype) transformer.to(accelerator.device, dtype=weight_dtype) @@ -1356,7 +1356,8 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): vae_config_shift_factor = vae.config.shift_factor if args.cache_latents: latents_cache = [] - vae = vae.to(accelerator.device) + if not args.offload: + vae = vae.to(accelerator.device) for batch in tqdm(train_dataloader, desc="Caching latents"): with torch.no_grad(): batch["pixel_values"] = batch["pixel_values"].to( @@ -1484,7 +1485,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.cache_latents: model_input = latents_cache[step].sample() else: - vae = vae.to(accelerator.device) + if args.offload: + vae = vae.to(accelerator.device) pixel_values = batch["pixel_values"].to(dtype=vae.dtype) model_input = vae.encode(pixel_values).latent_dist.sample() if args.offload: From 13192a330937c6bf459ec72d1858b0cdc9988a96 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Sat, 19 Apr 2025 09:30:24 +0300 Subject: [PATCH 56/76] fix text encoder offloading --- .../dreambooth/train_dreambooth_lora_hidream.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 98a871f24a42..83c4bddb9022 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1099,13 +1099,17 @@ def main(args): if not args.offload: vae.to(dtype=weight_dtype, device=accelerator.device) + text_encoder_one.to(dtype=weight_dtype, device=accelerator.device) + text_encoder_two.to(dtype=weight_dtype, device=accelerator.device) + text_encoder_three.to(dtype=weight_dtype, device=accelerator.device) + text_encoder_four.to(dtype=weight_dtype, device=accelerator.device) else: vae.to(dtype=weight_dtype) + text_encoder_one.to(dtype=weight_dtype) + text_encoder_two.to(dtype=weight_dtype) + text_encoder_three.to(dtype=weight_dtype) + text_encoder_four.to(dtype=weight_dtype) transformer.to(accelerator.device, dtype=weight_dtype) - text_encoder_one.to(accelerator.device, dtype=weight_dtype) - text_encoder_two.to(accelerator.device, dtype=weight_dtype) - text_encoder_three.to(accelerator.device, dtype=weight_dtype) - text_encoder_four.to(accelerator.device, dtype=weight_dtype) # Initialize a text encoding pipeline and keep it to CPU for now. text_encoding_pipeline = HiDreamImagePipeline.from_pretrained( @@ -1302,12 +1306,13 @@ def load_model_hook(models, input_dir): ) def compute_text_embeddings(prompt, text_encoding_pipeline): - text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) + if args.offload: + text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) with torch.no_grad(): t5_prompt_embeds, _, llama3_prompt_embeds, _, pooled_prompt_embeds, _ = ( text_encoding_pipeline.encode_prompt(prompt=prompt, max_sequence_length=args.max_sequence_length) ) - if args.offload: + if args.offload: # back to cpu text_encoding_pipeline = text_encoding_pipeline.to("cpu") return t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds From 408dfdbe773f832a53e2be2082142735e3f980b0 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Sat, 19 Apr 2025 06:33:44 +0000 Subject: [PATCH 57/76] Apply style fixes --- examples/dreambooth/train_dreambooth_lora_hidream.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 83c4bddb9022..47e9a5bf61b6 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1312,7 +1312,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): t5_prompt_embeds, _, llama3_prompt_embeds, _, pooled_prompt_embeds, _ = ( text_encoding_pipeline.encode_prompt(prompt=prompt, max_sequence_length=args.max_sequence_length) ) - if args.offload: # back to cpu + if args.offload: # back to cpu text_encoding_pipeline = text_encoding_pipeline.to("cpu") return t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds @@ -1533,7 +1533,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # and instead post-weight the loss weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) - target = noise - model_input if args.with_prior_preservation: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. From 3383446536a140181e7f59fd237f0ad50365488f Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Sun, 20 Apr 2025 15:58:57 +0300 Subject: [PATCH 58/76] cleaner to_kwargs --- .../train_dreambooth_lora_hidream.py | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 47e9a5bf61b6..378284ae1f27 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1097,18 +1097,14 @@ def main(args): "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." ) - if not args.offload: - vae.to(dtype=weight_dtype, device=accelerator.device) - text_encoder_one.to(dtype=weight_dtype, device=accelerator.device) - text_encoder_two.to(dtype=weight_dtype, device=accelerator.device) - text_encoder_three.to(dtype=weight_dtype, device=accelerator.device) - text_encoder_four.to(dtype=weight_dtype, device=accelerator.device) - else: - vae.to(dtype=weight_dtype) - text_encoder_one.to(dtype=weight_dtype) - text_encoder_two.to(dtype=weight_dtype) - text_encoder_three.to(dtype=weight_dtype) - text_encoder_four.to(dtype=weight_dtype) + to_kwargs = {"dtype": weight_dtype, "device": accelerator.device} if not args.offload else {"dtype": weight_dtype} + # flux vae is stable in bf16 so load it in weight_dtype to reduce memory + vae.to(**to_kwargs) + text_encoder_one.to(**to_kwargs) + text_encoder_two.to(**to_kwargs) + text_encoder_three.to(**to_kwargs) + text_encoder_four.to(**to_kwargs) + # we never offload the transformer to CPU, so we can just use the accelerator device transformer.to(accelerator.device, dtype=weight_dtype) # Initialize a text encoding pipeline and keep it to CPU for now. From 73ab2017f13b74c3a3a92212d015792adfad2b95 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Sun, 20 Apr 2025 16:12:34 +0300 Subject: [PATCH 59/76] fix module name in copied from --- src/diffusers/loaders/lora_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 7153e459810a..894da587c1c9 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -5529,7 +5529,7 @@ def load_lora_weights( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel def load_lora_into_transformer( cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): From a07ee59f1819e22744c8c611e855646d6e3866be Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 21 Apr 2025 11:42:34 +0300 Subject: [PATCH 60/76] add requirements --- examples/dreambooth/requirements_hidream.txt | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 examples/dreambooth/requirements_hidream.txt diff --git a/examples/dreambooth/requirements_hidream.txt b/examples/dreambooth/requirements_hidream.txt new file mode 100644 index 000000000000..4277a844621e --- /dev/null +++ b/examples/dreambooth/requirements_hidream.txt @@ -0,0 +1,8 @@ +accelerate>=1.4.0 +torchvision +transformers==4.50.0 +ftfy +tensorboard +Jinja2 +peft>=0.14.0 +sentencepiece \ No newline at end of file From a751cc2fbba89ace8d305b5b43b298266ee55b7c Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 21 Apr 2025 12:49:28 +0300 Subject: [PATCH 61/76] fix offloading --- examples/dreambooth/train_dreambooth_lora_hidream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 378284ae1f27..bab142fc733b 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1357,7 +1357,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): vae_config_shift_factor = vae.config.shift_factor if args.cache_latents: latents_cache = [] - if not args.offload: + if args.offload: vae = vae.to(accelerator.device) for batch in tqdm(train_dataloader, desc="Caching latents"): with torch.no_grad(): From 120b82146b9c71354d78c4f1a47ba4eb218a28d2 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 21 Apr 2025 13:10:27 +0300 Subject: [PATCH 62/76] fix offloading --- examples/dreambooth/train_dreambooth_lora_hidream.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index bab142fc733b..2b38829986db 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1370,6 +1370,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): del vae free_memory() + # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) From 9d71d3b38a17b569ecdad55b641d1d0fa4571dbe Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 21 Apr 2025 13:48:41 +0300 Subject: [PATCH 63/76] fix offloading --- examples/dreambooth/train_dreambooth_lora_hidream.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 2b38829986db..bab142fc733b 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1370,7 +1370,6 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): del vae free_memory() - # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) From 2798d77272f888e3baff5a0b08572c08a8968fff Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 21 Apr 2025 14:10:22 +0300 Subject: [PATCH 64/76] update transformers version in reqs --- examples/dreambooth/requirements_hidream.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/requirements_hidream.txt b/examples/dreambooth/requirements_hidream.txt index 4277a844621e..060ffd987a0e 100644 --- a/examples/dreambooth/requirements_hidream.txt +++ b/examples/dreambooth/requirements_hidream.txt @@ -1,6 +1,6 @@ accelerate>=1.4.0 torchvision -transformers==4.50.0 +transformers>=4.50.0 ftfy tensorboard Jinja2 From 363d29b601e6f871725766fbdd483d8154d85a26 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 21 Apr 2025 14:24:27 +0300 Subject: [PATCH 65/76] try AutoTokenizer --- examples/dreambooth/train_dreambooth_lora_hidream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index bab142fc733b..1763cd59ac15 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -42,7 +42,7 @@ from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm -from transformers import CLIPTokenizer, LlamaForCausalLM, PretrainedConfig, PreTrainedTokenizerFast, T5Tokenizer +from transformers import CLIPTokenizer, LlamaForCausalLM, PretrainedConfig, PreTrainedTokenizerFast, T5Tokenizer, AutoTokenizer import diffusers from diffusers import ( @@ -1029,7 +1029,7 @@ def main(args): subfolder="tokenizer_2", revision=args.revision, ) - tokenizer_three = T5Tokenizer.from_pretrained( + tokenizer_three = AutoTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer_3", revision=args.revision, From 8f24e8ce8e5d3bd2db81a4280db309b45de9030a Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 21 Apr 2025 14:28:58 +0300 Subject: [PATCH 66/76] try AutoTokenizer --- examples/dreambooth/train_dreambooth_lora_hidream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 1763cd59ac15..e28fa03f7ece 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -42,7 +42,7 @@ from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm -from transformers import CLIPTokenizer, LlamaForCausalLM, PretrainedConfig, PreTrainedTokenizerFast, T5Tokenizer, AutoTokenizer +from transformers import CLIPTokenizer, LlamaForCausalLM, PretrainedConfig, PreTrainedTokenizerFast, AutoTokenizer import diffusers from diffusers import ( From b31bdf09ce18785fdc95e32e9cfa36a2bfd70d77 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 21 Apr 2025 11:34:19 +0000 Subject: [PATCH 67/76] Apply style fixes --- examples/dreambooth/train_dreambooth_lora_hidream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index e28fa03f7ece..783cedb6d741 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -42,7 +42,7 @@ from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm -from transformers import CLIPTokenizer, LlamaForCausalLM, PretrainedConfig, PreTrainedTokenizerFast, AutoTokenizer +from transformers import AutoTokenizer, CLIPTokenizer, LlamaForCausalLM, PretrainedConfig, PreTrainedTokenizerFast import diffusers from diffusers import ( From 6c77651fb53a4857477cb6aea2bc817211994d5e Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 21 Apr 2025 14:38:01 +0300 Subject: [PATCH 68/76] empty commit --- tests/lora/test_lora_layers_hidream.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/lora/test_lora_layers_hidream.py diff --git a/tests/lora/test_lora_layers_hidream.py b/tests/lora/test_lora_layers_hidream.py new file mode 100644 index 000000000000..e69de29bb2d1 From 9b6ef439c357049751feaaad43869553124d70d4 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Mon, 21 Apr 2025 14:39:49 +0300 Subject: [PATCH 69/76] Delete tests/lora/test_lora_layers_hidream.py --- tests/lora/test_lora_layers_hidream.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tests/lora/test_lora_layers_hidream.py diff --git a/tests/lora/test_lora_layers_hidream.py b/tests/lora/test_lora_layers_hidream.py deleted file mode 100644 index e69de29bb2d1..000000000000 From fb3ac745197abd42310bca434848773aca98bb66 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 21 Apr 2025 14:52:52 +0300 Subject: [PATCH 70/76] change tokenizer_4 to load with AutoTokenizer as well --- examples/dreambooth/train_dreambooth_lora_hidream.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 783cedb6d741..36d9801e3494 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -42,7 +42,7 @@ from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm -from transformers import AutoTokenizer, CLIPTokenizer, LlamaForCausalLM, PretrainedConfig, PreTrainedTokenizerFast +from transformers import AutoTokenizer, CLIPTokenizer, LlamaForCausalLM, PretrainedConfig import diffusers from diffusers import ( @@ -1035,7 +1035,7 @@ def main(args): revision=args.revision, ) - tokenizer_four = PreTrainedTokenizerFast.from_pretrained( + tokenizer_four = AutoTokenizer.from_pretrained( "meta-llama/Meta-Llama-3.1-8B-Instruct", revision=args.revision, ) @@ -1645,7 +1645,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Final inference # Load previous pipeline - tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") + tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") text_encoder_4 = LlamaForCausalLM.from_pretrained( "meta-llama/Meta-Llama-3.1-8B-Instruct", output_hidden_states=True, From 82a40372138a6863cd85af343f818c6d3cabef6d Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 21 Apr 2025 15:34:46 +0300 Subject: [PATCH 71/76] make text_encoder_four and tokenizer_four configurable --- .../dreambooth/test_dreambooth_lora_hidream.py | 14 ++++++++++++++ .../dreambooth/train_dreambooth_lora_hidream.py | 17 +++++++++++++++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/test_dreambooth_lora_hidream.py b/examples/dreambooth/test_dreambooth_lora_hidream.py index add81767d646..95f48ccd7d6d 100644 --- a/examples/dreambooth/test_dreambooth_lora_hidream.py +++ b/examples/dreambooth/test_dreambooth_lora_hidream.py @@ -35,6 +35,8 @@ class DreamBoothLoRAHiDreamImage(ExamplesTestsAccelerate): instance_data_dir = "docs/source/en/imgs" pretrained_model_name_or_path = "hf-internal-testing/tiny-hidream-i1-pipe" + text_encoder_4_path = "hf-internal-testing/tiny-random-LlamaForCausalLM" + tokenizer_4_path = "hf-internal-testing/tiny-random-LlamaForCausalLM" script_path = "examples/dreambooth/train_dreambooth_lora_hidream.py" transformer_layer_type = "double_stream_blocks.0.block.attn1.to_k" @@ -43,6 +45,8 @@ def test_dreambooth_lora_hidream(self): test_args = f""" {self.script_path} --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path} + --pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path} --instance_data_dir {self.instance_data_dir} --resolution 32 --train_batch_size 1 @@ -76,6 +80,8 @@ def test_dreambooth_lora_latent_caching(self): test_args = f""" {self.script_path} --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path} + --pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path} --instance_data_dir {self.instance_data_dir} --resolution 32 --train_batch_size 1 @@ -110,6 +116,8 @@ def test_dreambooth_lora_layers(self): test_args = f""" {self.script_path} --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path} + --pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path} --instance_data_dir {self.instance_data_dir} --resolution 32 --train_batch_size 1 @@ -146,6 +154,8 @@ def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit(self): test_args = f""" {self.script_path} --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path} + --pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path} --instance_data_dir={self.instance_data_dir} --output_dir={tmpdir} --resolution=32 @@ -170,6 +180,8 @@ def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit_removes_m test_args = f""" {self.script_path} --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path} + --pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path} --instance_data_dir={self.instance_data_dir} --output_dir={tmpdir} --resolution=32 @@ -188,6 +200,8 @@ def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit_removes_m resume_run_args = f""" {self.script_path} --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path} + --pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path} --instance_data_dir={self.instance_data_dir} --output_dir={tmpdir} --resolution=32 diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 36d9801e3494..ec2b1c3b6aeb 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -184,7 +184,7 @@ def load_text_encoders(class_one, class_two, class_three): args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant ) text_encoder_four = LlamaForCausalLM.from_pretrained( - "meta-llama/Meta-Llama-3.1-8B-Instruct", + args.pretrained_text_encoder_4_name_or_path, output_hidden_states=True, output_attentions=True, torch_dtype=torch.bfloat16, @@ -287,6 +287,18 @@ def parse_args(input_args=None): required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) + parser.add_argument( + "--pretrained_tokenizer_4_name_or_path", + type=str, + default="meta-llama/Meta-Llama-3.1-8B-Instruct", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_text_encoder_4_name_or_path", + type=str, + default="meta-llama/Meta-Llama-3.1-8B-Instruct", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) parser.add_argument( "--revision", type=str, @@ -1036,7 +1048,7 @@ def main(args): ) tokenizer_four = AutoTokenizer.from_pretrained( - "meta-llama/Meta-Llama-3.1-8B-Instruct", + args.pretrained_tokenizer_4_name_or_path, revision=args.revision, ) tokenizer_four.pad_token = tokenizer_four.eos_token @@ -1646,6 +1658,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Final inference # Load previous pipeline tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") + tokenizer_4.pad_token = tokenizer_4.eos_token text_encoder_4 = LlamaForCausalLM.from_pretrained( "meta-llama/Meta-Llama-3.1-8B-Instruct", output_hidden_states=True, From 9de10cbfb2504424eeeb47909d0a1c97b4d24da4 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 21 Apr 2025 15:48:00 +0300 Subject: [PATCH 72/76] save model card --- .../train_dreambooth_lora_hidream.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index ec2b1c3b6aeb..4409932a4bd9 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1657,10 +1657,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Final inference # Load previous pipeline - tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") + tokenizer_4 = AutoTokenizer.from_pretrained(args.pretrained_tokenizer_4_name_or_path) tokenizer_4.pad_token = tokenizer_4.eos_token text_encoder_4 = LlamaForCausalLM.from_pretrained( - "meta-llama/Meta-Llama-3.1-8B-Instruct", + args.pretrained_text_encoder_4_name_or_path, output_hidden_states=True, output_attentions=True, torch_dtype=torch.bfloat16, @@ -1692,16 +1692,17 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): torch_dtype=weight_dtype, ) + validation_prpmpt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + instance_prompt=args.instance_prompt, + validation_prompt=validation_prpmpt, + repo_folder=args.output_dir, + ) + if args.push_to_hub: - validation_prpmpt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt - save_model_card( - repo_id, - images=images, - base_model=args.pretrained_model_name_or_path, - instance_prompt=args.instance_prompt, - validation_prompt=validation_prpmpt, - repo_folder=args.output_dir, - ) upload_folder( repo_id=repo_id, folder_path=args.output_dir, From 418f6a3c5ca0607c66e132f08f53fab027f93652 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 21 Apr 2025 16:00:54 +0300 Subject: [PATCH 73/76] save model card --- examples/dreambooth/train_dreambooth_lora_hidream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 4409932a4bd9..be9e0a04f6e1 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1694,7 +1694,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): validation_prpmpt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt save_model_card( - repo_id, + (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id, images=images, base_model=args.pretrained_model_name_or_path, instance_prompt=args.instance_prompt, From 36c1adae10459ce4aeb4f2b6d27589af1d072d0c Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 21 Apr 2025 17:09:34 +0300 Subject: [PATCH 74/76] revert T5 --- examples/dreambooth/train_dreambooth_lora_hidream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index be9e0a04f6e1..72e458a72abf 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -42,7 +42,7 @@ from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm -from transformers import AutoTokenizer, CLIPTokenizer, LlamaForCausalLM, PretrainedConfig +from transformers import AutoTokenizer, CLIPTokenizer, LlamaForCausalLM, PretrainedConfig, T5Tokenizer import diffusers from diffusers import ( @@ -1041,7 +1041,7 @@ def main(args): subfolder="tokenizer_2", revision=args.revision, ) - tokenizer_three = AutoTokenizer.from_pretrained( + tokenizer_three = T5Tokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer_3", revision=args.revision, From bd399b19373aa3699937bd454416f15fe3d247a1 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 21 Apr 2025 17:58:08 +0300 Subject: [PATCH 75/76] fix test --- examples/dreambooth/test_dreambooth_lora_hidream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/test_dreambooth_lora_hidream.py b/examples/dreambooth/test_dreambooth_lora_hidream.py index 95f48ccd7d6d..3f48c3095f3f 100644 --- a/examples/dreambooth/test_dreambooth_lora_hidream.py +++ b/examples/dreambooth/test_dreambooth_lora_hidream.py @@ -189,7 +189,7 @@ def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit_removes_m --gradient_accumulation_steps=1 --max_train_steps=4 --checkpointing_steps=2 - --max_sequence_length 166 + --max_sequence_length 16 """.split() test_args.extend(["--instance_prompt", ""]) From ed97dba00a7d673ede2c833ab159d026ce3a7d00 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 22 Apr 2025 10:59:20 +0300 Subject: [PATCH 76/76] remove non diffusers lumina2 conversion --- src/diffusers/loaders/lora_pipeline.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 894da587c1c9..fb2cdf6ce304 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -5370,6 +5370,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -5463,11 +5464,6 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - # conversion. - non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict) - if non_diffusers: - state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict) - return state_dict # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights