From c314a98863fe37f1810cc4544956cae5be4ba8e2 Mon Sep 17 00:00:00 2001
From: Mary Hipp
Date: Mon, 18 Aug 2025 13:16:35 -0400
Subject: [PATCH 1/6] save transformer only for peft outputs

---
 .../_shared/flux/lora_checkpoint_utils.py | 40 +++++--------------
 .../pipelines/flux/lora/train.py          |  7 ++--
 2 files changed, 13 insertions(+), 34 deletions(-)

diff --git a/src/invoke_training/_shared/flux/lora_checkpoint_utils.py b/src/invoke_training/_shared/flux/lora_checkpoint_utils.py
index 786ad2aa..823d6e52 100644
--- a/src/invoke_training/_shared/flux/lora_checkpoint_utils.py
+++ b/src/invoke_training/_shared/flux/lora_checkpoint_utils.py
@@ -59,42 +59,20 @@
 }
 
 
-def save_flux_peft_checkpoint(
+def save_flux_peft_checkpoint_single_file(
     checkpoint_dir: Path | str,
     transformer: peft.PeftModel | None,
-    text_encoder_1: peft.PeftModel | None,
-    text_encoder_2: peft.PeftModel | None,
 ):
-    models = {}
-    if transformer is not None:
-        models[FLUX_PEFT_TRANSFORMER_KEY] = transformer
-    if text_encoder_1 is not None:
-        models[FLUX_PEFT_TEXT_ENCODER_1_KEY] = text_encoder_1
-    if text_encoder_2 is not None:
-        models[FLUX_PEFT_TEXT_ENCODER_2_KEY] = text_encoder_2
-
-    save_multi_model_peft_checkpoint(checkpoint_dir=checkpoint_dir, models=models)
-
+    assert isinstance(transformer, peft.PeftModel)
 
-def load_flux_peft_checkpoint(
-    checkpoint_dir: Path | str,
-    transformer: FluxTransformer2DModel,
-    text_encoder_1: CLIPTextModel,
-    text_encoder_2: CLIPTextModel,
-    is_trainable: bool = False,
-):
-    models = load_multi_model_peft_checkpoint(
-        checkpoint_dir=checkpoint_dir,
-        models={
-            FLUX_PEFT_TRANSFORMER_KEY: transformer,
-            FLUX_PEFT_TEXT_ENCODER_1_KEY: text_encoder_1,
-            FLUX_PEFT_TEXT_ENCODER_2_KEY: text_encoder_2,
-        },
-        is_trainable=is_trainable,
-        raise_if_subdir_missing=False,
-    )
+    if (
+        hasattr(transformer, "config")
+        and isinstance(transformer.config, dict)
+        and "_name_or_path" not in transformer.config
+    ):
+        transformer.config["_name_or_path"] = None
 
-    return models[FLUX_PEFT_TRANSFORMER_KEY], models[FLUX_PEFT_TEXT_ENCODER_1_KEY], models[FLUX_PEFT_TEXT_ENCODER_2_KEY]
+    transformer.save_pretrained(str(checkpoint_dir))
 
 
 def save_flux_kohya_checkpoint(
diff --git a/src/invoke_training/pipelines/flux/lora/train.py b/src/invoke_training/pipelines/flux/lora/train.py
index 1ae48b71..59493459 100644
--- a/src/invoke_training/pipelines/flux/lora/train.py
+++ b/src/invoke_training/pipelines/flux/lora/train.py
@@ -33,7 +33,7 @@
 from invoke_training._shared.flux.encoding_utils import encode_prompt
 from invoke_training._shared.flux.lora_checkpoint_utils import (
     save_flux_kohya_checkpoint,
-    save_flux_peft_checkpoint,
+    save_flux_peft_checkpoint_single_file,
 )
 from invoke_training._shared.flux.model_loading_utils import load_models_flux
 from invoke_training._shared.flux.validation import generate_validation_images_flux
@@ -63,9 +63,10 @@ def _save_flux_lora_checkpoint(
 
     if lora_checkpoint_format == "invoke_peft":
         model_type = ModelType.FLUX_LORA_PEFT
-        save_flux_peft_checkpoint(
-            Path(save_path), transformer=transformer, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2
+        save_flux_peft_checkpoint_single_file(
+            Path(save_path), transformer=transformer
         )
+
     elif lora_checkpoint_format == "kohya":
         model_type = ModelType.FLUX_LORA_KOHYA
         save_flux_kohya_checkpoint(

From bee1db1d595fc71c34b8a73c0cb874320da309e4 Mon Sep 17 00:00:00 2001
From: Mary Hipp
Date: Mon, 18 Aug 2025 13:23:11 -0400
Subject: [PATCH 2/6] formatting

---
 src/invoke_training/_shared/flux/lora_checkpoint_utils.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/src/invoke_training/_shared/flux/lora_checkpoint_utils.py b/src/invoke_training/_shared/flux/lora_checkpoint_utils.py
index 823d6e52..6f4876c4 100644
--- a/src/invoke_training/_shared/flux/lora_checkpoint_utils.py
+++ b/src/invoke_training/_shared/flux/lora_checkpoint_utils.py
@@ -4,13 +4,9 @@
 
 import peft
 import torch
-from diffusers import FluxTransformer2DModel
-from transformers import CLIPTextModel
 
 from invoke_training._shared.checkpoints.lora_checkpoint_utils import (
     _convert_peft_state_dict_to_kohya_state_dict,
-    load_multi_model_peft_checkpoint,
-    save_multi_model_peft_checkpoint,
 )
 from invoke_training._shared.checkpoints.serialization import save_state_dict
 
@@ -64,7 +60,6 @@ def save_flux_peft_checkpoint_single_file(
     transformer: peft.PeftModel | None,
 ):
     assert isinstance(transformer, peft.PeftModel)
-
     if (
         hasattr(transformer, "config")
         and isinstance(transformer.config, dict)
         and "_name_or_path" not in transformer.config

From 188163df30a50dd33b42968a5fff6d5b5546ad16 Mon Sep 17 00:00:00 2001
From: Mary Hipp
Date: Mon, 18 Aug 2025 14:13:43 -0400
Subject: [PATCH 3/6] delete everything but safetensors

---
 .../_shared/flux/lora_checkpoint_utils.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/src/invoke_training/_shared/flux/lora_checkpoint_utils.py b/src/invoke_training/_shared/flux/lora_checkpoint_utils.py
index 6f4876c4..7fa6882a 100644
--- a/src/invoke_training/_shared/flux/lora_checkpoint_utils.py
+++ b/src/invoke_training/_shared/flux/lora_checkpoint_utils.py
@@ -1,5 +1,6 @@
 # ruff: noqa: N806
 import os
+import shutil
 from pathlib import Path
 
 import peft
@@ -69,6 +70,17 @@ def save_flux_peft_checkpoint_single_file(
 
     transformer.save_pretrained(str(checkpoint_dir))
+
+    # Remove anything that is not the base safetensors file
+    checkpoint_path = Path(checkpoint_dir)
+    for child in checkpoint_path.iterdir():
+        if child.is_dir():
+            shutil.rmtree(child, ignore_errors=True)
+        elif child.is_file() and child.name != "adapter_model.safetensors":
+            try:
+                child.unlink()
+            except OSError:
+                pass
 
 
 def save_flux_kohya_checkpoint(
     checkpoint_path: Path,

From 3eb106576632d773e957965b6d0896587823cfcf Mon Sep 17 00:00:00 2001
From: Mary Hipp
Date: Mon, 18 Aug 2025 15:18:59 -0400
Subject: [PATCH 4/6] massage peft output into a single transformer

---
 .../_shared/flux/lora_checkpoint_utils.py | 33 ++++++++++++-------
 .../pipelines/flux/lora/train.py          |  3 +-
 2 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/src/invoke_training/_shared/flux/lora_checkpoint_utils.py b/src/invoke_training/_shared/flux/lora_checkpoint_utils.py
index 7fa6882a..33c66629 100644
--- a/src/invoke_training/_shared/flux/lora_checkpoint_utils.py
+++ b/src/invoke_training/_shared/flux/lora_checkpoint_utils.py
@@ -1,6 +1,7 @@
 # ruff: noqa: N806
 import os
 import shutil
+import tempfile
 from pathlib import Path
 
 import peft
@@ -68,18 +69,26 @@ def save_flux_peft_checkpoint_single_file(
     ):
         transformer.config["_name_or_path"] = None
 
-    transformer.save_pretrained(str(checkpoint_dir))
-
-    # Remove anything that is not the base safetensors file
-    checkpoint_path = Path(checkpoint_dir)
-    for child in checkpoint_path.iterdir():
-        if child.is_dir():
-            shutil.rmtree(child, ignore_errors=True)
-        elif child.is_file() and child.name != "adapter_model.safetensors":
-            try:
-                child.unlink()
-            except OSError:
-                pass
+    # Normalize output path and ensure parent exists when saving as file
+    out_path = Path(checkpoint_dir)
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        tmp_path = Path(tmpdir)
+        # Save PEFT adapter into temporary directory
+        transformer.save_pretrained(str(tmp_path))
+
+        # Move adapter_model.safetensors out of the temp dir to the requested location
+        src_file = tmp_path / "adapter_model.safetensors"
+        if not src_file.exists():
+            raise FileNotFoundError(
+                f"Expected adapter file not found in temporary directory: {src_file}"
+            )
+
+        # Always rename/move to exactly the path provided by checkpoint_dir
+        dest_file = out_path
+        dest_file.parent.mkdir(parents=True, exist_ok=True)
+
+        shutil.move(str(src_file), str(dest_file))
 
 
 def save_flux_kohya_checkpoint(
diff --git a/src/invoke_training/pipelines/flux/lora/train.py b/src/invoke_training/pipelines/flux/lora/train.py
index 59493459..5464cd00 100644
--- a/src/invoke_training/pipelines/flux/lora/train.py
+++ b/src/invoke_training/pipelines/flux/lora/train.py
@@ -531,7 +531,8 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
         base_dir=ckpt_dir,
         prefix="checkpoint",
         max_checkpoints=config.max_checkpoints,
-        extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
+        # extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
+        extension=".safetensors"  # we are going to massage the peft model for this in save_flux_peft_checkpoint_single_file
     )
 
     # Train!

From e620d0cad08a0f46ec28ab69bd66e686546ad278 Mon Sep 17 00:00:00 2001
From: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com>
Date: Mon, 18 Aug 2025 16:33:00 -0400
Subject: [PATCH 5/6] alpha fixes

---
 src/invoke_training/pipelines/flux/lora/train.py | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/src/invoke_training/pipelines/flux/lora/train.py b/src/invoke_training/pipelines/flux/lora/train.py
index 5464cd00..3b9b5368 100644
--- a/src/invoke_training/pipelines/flux/lora/train.py
+++ b/src/invoke_training/pipelines/flux/lora/train.py
@@ -63,10 +63,8 @@ def _save_flux_lora_checkpoint(
 
     if lora_checkpoint_format == "invoke_peft":
         model_type = ModelType.FLUX_LORA_PEFT
-        save_flux_peft_checkpoint_single_file(
-            Path(save_path), transformer=transformer
-        )
-
+        save_flux_peft_checkpoint_single_file(Path(save_path), transformer=transformer)
+
     elif lora_checkpoint_format == "kohya":
         model_type = ModelType.FLUX_LORA_KOHYA
         save_flux_kohya_checkpoint(
@@ -425,8 +423,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
     if config.train_transformer:
         transformer_lora_config = peft.LoraConfig(
             r=config.lora_rank_dim,
-            # TODO(ryand): Diffusers uses lora_alpha=config.lora_rank_dim. Is that preferred?
-            lora_alpha=1.0,
+            lora_alpha=config.lora_rank_dim,
             target_modules=config.flux_lora_target_modules,
         )
         transformer = inject_lora_layers(transformer, transformer_lora_config, lr=config.transformer_learning_rate)
@@ -434,7 +431,7 @@
     if config.train_text_encoder:
         text_encoder_lora_config = peft.LoraConfig(
             r=config.lora_rank_dim,
-            lora_alpha=1.0,
+            lora_alpha=config.lora_rank_dim,
             # init_lora_weights="gaussian",
             target_modules=config.text_encoder_lora_target_modules,
         )
@@ -532,7 +529,7 @@
         prefix="checkpoint",
         max_checkpoints=config.max_checkpoints,
         # extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
-        extension=".safetensors"  # we are going to massage the peft model for this in save_flux_peft_checkpoint_single_file
+        extension=".safetensors",  # we are going to massage the peft model for this in save_flux_peft_checkpoint_single_file
     )
 
     # Train!

From 0fc93b44f6b2d29ae4e73795c8eb90b1fd9e50f6 Mon Sep 17 00:00:00 2001
From: Mary Hipp
Date: Mon, 18 Aug 2025 18:40:11 -0400
Subject: [PATCH 6/6] ruff

---
 src/invoke_training/_shared/flux/lora_checkpoint_utils.py | 6 ++----
 src/invoke_training/pipelines/flux/lora/train.py          | 4 ++--
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/src/invoke_training/_shared/flux/lora_checkpoint_utils.py b/src/invoke_training/_shared/flux/lora_checkpoint_utils.py
index 33c66629..bd52fe7b 100644
--- a/src/invoke_training/_shared/flux/lora_checkpoint_utils.py
+++ b/src/invoke_training/_shared/flux/lora_checkpoint_utils.py
@@ -80,10 +80,8 @@ def save_flux_peft_checkpoint_single_file(
         # Move adapter_model.safetensors out of the temp dir to the requested location
         src_file = tmp_path / "adapter_model.safetensors"
         if not src_file.exists():
-            raise FileNotFoundError(
-                f"Expected adapter file not found in temporary directory: {src_file}"
-            )
-
+            raise FileNotFoundError(f"Expected adapter file not found in temporary directory: {src_file}")
+
         # Always rename/move to exactly the path provided by checkpoint_dir
         dest_file = out_path
         dest_file.parent.mkdir(parents=True, exist_ok=True)
diff --git a/src/invoke_training/pipelines/flux/lora/train.py b/src/invoke_training/pipelines/flux/lora/train.py
index 3b9b5368..17653fb5 100644
--- a/src/invoke_training/pipelines/flux/lora/train.py
+++ b/src/invoke_training/pipelines/flux/lora/train.py
@@ -528,8 +528,8 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
         base_dir=ckpt_dir,
         prefix="checkpoint",
         max_checkpoints=config.max_checkpoints,
-        # extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
-        extension=".safetensors",  # we are going to massage the peft model for this in save_flux_peft_checkpoint_single_file
+        # we are going to massage the peft model for this in save_flux_peft_checkpoint_single_file
+        extension=".safetensors",
     )
 
     # Train!
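
Reviewer note: the sketch below is one way to exercise the new single-file output outside the training loop. It is not part of the patches: TinyBlock, the rank/alpha values, and the checkpoint filename are made-up stand-ins; only save_flux_peft_checkpoint_single_file and the adapter_model.safetensors handling come from the diffs above.

# Illustrative sketch only, assuming invoke_training, peft, torch, and safetensors are installed.
import tempfile
from pathlib import Path

import peft
import torch
from safetensors.torch import load_file

from invoke_training._shared.flux.lora_checkpoint_utils import save_flux_peft_checkpoint_single_file


class TinyBlock(torch.nn.Module):
    """Hypothetical stand-in for the FLUX transformer; any torch module works for this exercise."""

    def __init__(self):
        super().__init__()
        self.to_q = torch.nn.Linear(8, 8)
        self.to_k = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.to_q(x) + self.to_k(x)


# Wrap the stand-in module with LoRA adapters, mirroring the alpha == rank choice from patch 5.
lora_config = peft.LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k"])
model = peft.get_peft_model(TinyBlock(), lora_config)

with tempfile.TemporaryDirectory() as tmp:
    # The training loop passes a tracker-generated path ending in ".safetensors" (see patch 4).
    ckpt_path = Path(tmp) / "checkpoint-00000001.safetensors"
    save_flux_peft_checkpoint_single_file(ckpt_path, transformer=model)

    # The checkpoint is now a single safetensors file holding only the PEFT adapter weights.
    state_dict = load_file(str(ckpt_path))
    print(len(state_dict), sorted(state_dict)[:2])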