# Support FLUX OneTrainer LoRA formats (incl. DoRA) (#7590)
## Summary

This PR adds support for the FLUX LoRA model format produced by
OneTrainer.

Specifically, this PR adds:
- Support for DoRA patches
- Support for patch models that modify the FLUX T5 encoder
- Probing / loading support for OneTrainer models
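
For context, "probing" here means inspecting a checkpoint's state-dict keys to guess which trainer/format produced it. Below is a rough, hypothetical sketch of the idea (the key prefixes are assumptions chosen for illustration; the real check added by this PR is `is_state_dict_likely_in_flux_onetrainer_format`, which appears later in the diff):

```python
# Hypothetical sketch of state-dict probing -- not the actual implementation.
# The key prefixes are assumptions, used only to illustrate the general approach.
def looks_like_onetrainer_flux_lora(state_dict: dict) -> bool:
    flux_lora_prefixes = ("lora_transformer_", "lora_te1_", "lora_te2_")
    return any(key.startswith(flux_lora_prefixes) for key in state_dict)
```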

## Known limitations

- DoRA patches cannot currently be applied to base weights that are
quantized with `bitsandbytes`. The DoRA algorithm requires access to the
original model weight in order to compute the patch diff, and the
bitsandbytes quantization layers make this difficult (a sketch of the
underlying math follows this list). DoRA patches can be applied to
non-quantized and GGUF-quantized layers without issue.
- This PR results in a slight speed regression for one particular
inference combination: a quantized base model plus a LoRA with diffusers
keys (i.e. one that uses the `MergedLayerPatch`). Now that more LoRA
formats use the `MergedLayerPatch`, maintaining this optimization was no
longer worth the effort. The regression is from ~1.7 it/s to ~1.4 it/s.
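
To illustrate the first limitation: DoRA re-normalizes the combined (original + low-rank) weight, so the patch diff cannot be computed from the patch parameters alone. The sketch below is a minimal illustration of the math for a plain linear layer (names and shapes are assumptions, not this PR's code); with bitsandbytes, `w_orig` is not readily accessible, which is what blocks direct patching:

```python
import torch

# Minimal DoRA sketch (illustration only, not the PR's implementation).
def dora_patch_diff(
    w_orig: torch.Tensor,          # [out_features, in_features] original base weight
    lora_up: torch.Tensor,         # [out_features, rank]
    lora_down: torch.Tensor,       # [rank, in_features]
    dora_magnitude: torch.Tensor,  # [out_features, 1] learned per-channel magnitude
) -> torch.Tensor:
    # Direction = original weight plus the low-rank update.
    direction = w_orig + lora_up @ lora_down
    # DoRA re-normalizes the direction per output channel and rescales it by the
    # learned magnitude. This step is why the original weight must be available.
    w_patched = dora_magnitude * direction / direction.norm(dim=1, keepdim=True)
    # The diff that gets added to (or merged into) the base weight.
    return w_patched - w_orig
```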

## Future Notes

- We may want to consider dropping support for bitsandbytes
quantization. It is very difficult to maintain compatibility across
features like partial loading and LoRA patching.
- At a future time, we should refactor the LoRA parsing logic to be more
generalized rather than handling each format independently.
- There are some redundant device casts and dequantizations in
`autocast_linear_forward_sidecar_patches(...)` (and its sub-calls).
Optimizing this is left for future work.

## Related Issues / Discussions

- This PR should address a handful of the LoRAs reported in
#7131 (specifically, most of
the `envy*` LoRAs).
- This PR should address the example in
#6912 (though the intended
effect of that LoRA is not totally clear, so it's hard to verify with
full confidence).

## QA Instructions


OneTrainer test models:
- https://civitai.com/models/844821/envy-flux-dark-watercolor-01?modelVersionId=945159 (DoRA, transformer only)
- https://civitai.com/models/836757/envy-flux-digital-brush-01?modelVersionId=936167 (hada, transformer only)
- ball_flux from #6912 (DoRA, transformer/clip/t5)

The following tests were repeated with each of the OneTrainer test
models:

- [x] Test with non-quantized base model
- [x] Test with GGUF-quantized base model
- [x] Test with BnB-quantized base model
- [x] Test with non-quantized base model that is partially-loaded onto
the GPU

Other regression test:

- [x] Test some SD1 LoRAs
- [x] Test some SDXL LoRAs
- [x] Test a variety of existing FLUX LoRA formats
- [x] Test a FLUX Control LoRA on all base model quantization formats. 

## Merge Plan

No special instructions.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
RyanJDick authored Jan 28, 2025
2 parents 9d2f8b4 + 229834a commit debcbd6
Showing 26 changed files with 2,899 additions and 131 deletions.
36 changes: 33 additions & 3 deletions invokeai/app/invocations/flux_lora_loader.py
@@ -8,7 +8,7 @@
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, TransformerField
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, T5EncoderField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType

@@ -21,14 +21,17 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
)
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
t5_encoder: Optional[T5EncoderField] = OutputField(
default=None, description=FieldDescriptions.t5_encoder, title="T5 Encoder"
)


@invocation(
"flux_lora_loader",
title="FLUX LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.1.0",
version="1.2.0",
classification=Classification.Prototype,
)
class FluxLoRALoaderInvocation(BaseInvocation):
@@ -50,6 +53,12 @@ class FluxLoRALoaderInvocation(BaseInvocation):
description=FieldDescriptions.clip,
input=Input.Connection,
)
t5_encoder: T5EncoderField | None = InputField(
default=None,
title="T5 Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)

def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
lora_key = self.lora.key
@@ -62,6 +71,8 @@ def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
if self.clip and any(lora.lora.key == lora_key for lora in self.clip.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to CLIP encoder.')
if self.t5_encoder and any(lora.lora.key == lora_key for lora in self.t5_encoder.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to T5 encoder.')

output = FluxLoRALoaderOutput()

@@ -82,6 +93,14 @@ def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
weight=self.weight,
)
)
if self.t5_encoder is not None:
output.t5_encoder = self.t5_encoder.model_copy(deep=True)
output.t5_encoder.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)

return output

@@ -91,7 +110,7 @@ def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
title="FLUX LoRA Collection Loader",
tags=["lora", "model", "flux"],
category="model",
version="1.1.0",
version="1.2.0",
classification=Classification.Prototype,
)
class FLUXLoRACollectionLoader(BaseInvocation):
@@ -113,6 +132,12 @@ class FLUXLoRACollectionLoader(BaseInvocation):
description=FieldDescriptions.clip,
input=Input.Connection,
)
t5_encoder: T5EncoderField | None = InputField(
default=None,
title="T5 Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)

def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
output = FluxLoRALoaderOutput()
@@ -140,4 +165,9 @@ def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(lora)

if self.t5_encoder is not None:
if output.t5_encoder is None:
output.t5_encoder = self.t5_encoder.model_copy(deep=True)
output.t5_encoder.loras.append(lora)

return output
4 changes: 2 additions & 2 deletions invokeai/app/invocations/flux_model_loader.py
@@ -40,7 +40,7 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.4",
version="1.0.5",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
@@ -87,7 +87,7 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder, loras=[]),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)
43 changes: 41 additions & 2 deletions invokeai/app/invocations/flux_text_encoder.py
@@ -19,7 +19,7 @@
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_T5_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo

@@ -71,13 +71,45 @@ def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
prompt = [self.prompt]

t5_encoder_info = context.models.load(self.t5_encoder.text_encoder)
t5_encoder_config = t5_encoder_info.config
assert t5_encoder_config is not None

with (
context.models.load(self.t5_encoder.text_encoder) as t5_text_encoder,
t5_encoder_info.model_on_device() as (cached_weights, t5_text_encoder),
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
ExitStack() as exit_stack,
):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))

# Determine if the model is quantized.
# If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
# slower inference than direct patching, but is agnostic to the quantization format.
if t5_encoder_config.format in [ModelFormat.T5Encoder, ModelFormat.Diffusers]:
model_is_quantized = False
elif t5_encoder_config.format in [
ModelFormat.BnbQuantizedLlmInt8b,
ModelFormat.BnbQuantizednf4b,
ModelFormat.GGUFQuantized,
]:
model_is_quantized = True
else:
raise ValueError(f"Unsupported model format: {t5_encoder_config.format}")

# Apply LoRA models to the T5 encoder.
# Note: We apply the LoRA after the encoder has been moved to its target device for faster patching.
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
model=t5_text_encoder,
patches=self._t5_lora_iterator(context),
prefix=FLUX_LORA_T5_PREFIX,
dtype=t5_text_encoder.dtype,
cached_weights=cached_weights,
force_sidecar_patching=model_is_quantized,
)
)

t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)

context.util.signal_progress("Running T5 encoder")
@@ -132,3 +164,10 @@ def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[Mode
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info

def _t5_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in self.t5_encoder.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info
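
For readers unfamiliar with the "sidecar" patching mentioned in the comments above: when the encoder is quantized, the LoRA weights are not merged into the base weights; instead, the low-rank path runs alongside the unmodified (possibly quantized) layer at forward time. The following is a generic sketch of that idea with assumed names — a simplification, not InvokeAI's `LayerPatcher` implementation:

```python
import torch

class LinearWithLoRASidecar(torch.nn.Module):
    """Generic sidecar-style LoRA wrapper (illustration only)."""

    def __init__(self, base: torch.nn.Module, lora_down: torch.Tensor, lora_up: torch.Tensor, weight: float = 1.0):
        super().__init__()
        self.base = base            # possibly quantized; its weights are never touched
        self.lora_down = lora_down  # [rank, in_features]
        self.lora_up = lora_up      # [out_features, rank]
        self.weight = weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Low-rank residual computed in full precision, then added to the base output.
        residual = torch.nn.functional.linear(torch.nn.functional.linear(x, self.lora_down), self.lora_up)
        return self.base(x) + self.weight * residual
```
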
1 change: 1 addition & 0 deletions invokeai/app/invocations/model.py
@@ -68,6 +68,7 @@ class CLIPField(BaseModel):
class T5EncoderField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")


class VAEField(BaseModel):
@@ -7,7 +7,6 @@
CustomModuleMixin,
)
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer

@@ -22,25 +21,6 @@ def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight:
return x


def concatenated_lora_forward(
input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a sidecar ConcatenatedLoRALayer."""
x_chunks: list[torch.Tensor] = []
for lora_layer in concatenated_lora_layer.lora_layers:
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
x_chunk *= lora_weight * lora_layer.scale()
x_chunks.append(x_chunk)

# TODO(ryand): Generalize to support concat_axis != 0.
assert concatenated_lora_layer.concat_axis == 0
x = torch.cat(x_chunks, dim=-1)
return x


def autocast_linear_forward_sidecar_patches(
orig_module: torch.nn.Linear, input: torch.Tensor, patches_and_weights: list[tuple[BaseLayerPatch, float]]
) -> torch.Tensor:
@@ -66,8 +46,6 @@ def autocast_linear_forward_sidecar_patches(
output += linear_lora_forward(orig_input, patch, patch_weight)
elif isinstance(patch, LoRALayer):
output += linear_lora_forward(input, patch, patch_weight)
elif isinstance(patch, ConcatenatedLoRALayer):
output += concatenated_lora_forward(input, patch, patch_weight)
else:
unprocessed_patches_and_weights.append((patch, patch_weight))

Expand Down
@@ -3,6 +3,8 @@
import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.param_shape_utils import get_param_shape
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor


class CustomModuleMixin:
@@ -42,6 +44,20 @@ def _aggregate_patch_parameters(
device: torch.device | None = None,
):
"""Helper function that aggregates the parameters from all patches into a single dict."""
# HACK(ryand): If the original parameters are in a quantized format whose weights can't be accessed, we replace
# them with dummy tensors on the 'meta' device. This allows patch layers to access the shapes of the original
# parameters. But, of course, any sub-layers that need to access the actual values of the parameters will fail.
for param_name in orig_params.keys():
param = orig_params[param_name]
if type(param) is torch.nn.Parameter and type(param.data) is torch.Tensor:
pass
elif type(param) is GGMLTensor:
# Move to device and dequantize here. Doing it in the patch layer can result in redundant casts /
# dequantizations.
orig_params[param_name] = param.to(device=device).get_dequantized_tensor()
else:
orig_params[param_name] = torch.empty(get_param_shape(param), device="meta")

params: dict[str, torch.Tensor] = {}

for patch, patch_weight in patches_and_weights:
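
The meta-device placeholder used in the hack above is a standard PyTorch mechanism: a tensor on the "meta" device carries shape and dtype metadata but no storage, so downstream code can inspect shapes even when the real values are locked inside a quantization wrapper. A minimal standalone illustration (the shape is hypothetical):

```python
import torch

# Stand-in for a weight whose values can't be accessed (e.g. bitsandbytes-quantized).
placeholder = torch.empty((3072, 3072), device="meta")

print(placeholder.shape)  # torch.Size([3072, 3072]) -- shape queries work fine
# placeholder.tolist()    # would raise: meta tensors have no underlying data to read
```
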
8 changes: 8 additions & 0 deletions invokeai/backend/model_manager/load/model_loaders/lora.py
@@ -31,6 +31,10 @@
is_state_dict_likely_in_flux_kohya_format,
lora_model_from_flux_kohya_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
is_state_dict_likely_in_flux_onetrainer_format,
lora_model_from_flux_onetrainer_state_dict,
)
from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format

@@ -84,8 +88,12 @@ def _load_model(
elif config.format == ModelFormat.LyCORIS:
if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict=state_dict):
model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
elif is_state_dict_likely_flux_control(state_dict=state_dict):
model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
else:
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
else:
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
6 changes: 5 additions & 1 deletion invokeai/backend/model_manager/probe.py
@@ -46,6 +46,9 @@
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
is_state_dict_likely_in_flux_kohya_format,
)
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
is_state_dict_likely_in_flux_onetrainer_format,
)
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
@@ -283,7 +286,7 @@ def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[C
return ModelType.Main
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
return ModelType.VAE
elif key.startswith(("lora_te_", "lora_unet_")):
elif key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")):
return ModelType.LoRA
# "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
@@ -632,6 +635,7 @@ def get_format(self) -> ModelFormat:
def get_base_type(self) -> BaseModelType:
if (
is_state_dict_likely_in_flux_kohya_format(self.checkpoint)
or is_state_dict_likely_in_flux_onetrainer_format(self.checkpoint)
or is_state_dict_likely_in_flux_diffusers_format(self.checkpoint)
or is_state_dict_likely_flux_control(self.checkpoint)
):
55 changes: 0 additions & 55 deletions invokeai/backend/patches/layers/concatenated_lora_layer.py

This file was deleted.
