Port LoRA to new classification API #7849

Merged (11 commits, Mar 31, 2025)
183 changes: 87 additions & 96 deletions invokeai/backend/model_manager/config.py
@@ -30,19 +30,18 @@
from pathlib import Path
from typing import ClassVar, Literal, Optional, TypeAlias, Union

import safetensors.torch
import torch
from picklescan.scanner import scan_file_path
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict

from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
AnyVariant,
BaseModelType,
ClipVariantType,
FluxLoRAFormat,
ModelFormat,
ModelRepoVariant,
ModelSourceType,
@@ -51,9 +50,8 @@
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.silence_warnings import SilenceWarnings

logger = logging.getLogger(__name__)

@@ -67,11 +65,6 @@ class InvalidModelConfigException(Exception):
DEFAULTS_PRECISION = Literal["fp16", "fp32"]


class FSLayout(Enum):
FILE = "file"
DIRECTORY = "directory"


class SubmodelDefinition(BaseModel):
path_or_prefix: str
model_type: ModelType
@@ -102,87 +95,6 @@ class ControlAdapterDefaultSettings(BaseModel):
model_config = ConfigDict(extra="forbid")


class ModelOnDisk:
"""A utility class representing a model stored on disk."""

def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
self.path = path
# TODO: Revisit checkpoint vs diffusers terminology
self.layout = FSLayout.DIRECTORY if path.is_dir() else FSLayout.FILE
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
self.name = path.stem
else:
self.name = path.name
self.hash_algo = hash_algo
self._state_dict_cache = {}

def hash(self) -> str:
return ModelHash(algorithm=self.hash_algo).hash(self.path)

def size(self) -> int:
if self.layout == FSLayout.FILE:
return self.path.stat().st_size
return sum(file.stat().st_size for file in self.path.rglob("*"))

def component_paths(self) -> set[Path]:
if self.layout == FSLayout.FILE:
return {self.path}
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
return {f for f in self.path.rglob("*") if f.suffix in extensions}

def repo_variant(self) -> Optional[ModelRepoVariant]:
if self.layout == FSLayout.FILE:
return None

weight_files = list(self.path.glob("**/*.safetensors"))
weight_files.extend(list(self.path.glob("**/*.bin")))
for x in weight_files:
if ".fp16" in x.suffixes:
return ModelRepoVariant.FP16
if "openvino_model" in x.name:
return ModelRepoVariant.OpenVINO
if "flax_model" in x.name:
return ModelRepoVariant.Flax
if x.suffix == ".onnx":
return ModelRepoVariant.ONNX
return ModelRepoVariant.Default

def load_state_dict(self, path: Optional[Path] = None) -> Dict[str | int, Any]:
if path in self._state_dict_cache:
return self._state_dict_cache[path]

if not path:
components = list(self.component_paths())
match components:
case []:
raise ValueError("No weight files found for this model")
case [p]:
path = p
case ps if len(ps) >= 2:
raise ValueError(
f"Multiple weight files found for this model: {ps}. "
f"Please specify the intended file using the 'path' argument"
)

with SilenceWarnings():
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
scan_result = scan_file_path(path)
if scan_result.infected_files != 0 or scan_result.scan_err:
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
checkpoint = torch.load(path, map_location="cpu")
assert isinstance(checkpoint, dict)
elif path.suffix.endswith(".gguf"):
checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
elif path.suffix.endswith(".safetensors"):
checkpoint = safetensors.torch.load_file(path)
else:
raise ValueError(f"Unrecognized model extension: {path.suffix}")

state_dict = checkpoint.get("state_dict", checkpoint)
self._state_dict_cache[path] = state_dict
return state_dict


class MatchSpeed(int, Enum):
"""Represents the estimated runtime speed of a config's 'matches' method."""

@@ -257,7 +169,7 @@ def classify(model_path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single",
Created to deprecate ModelProbe.probe
"""
candidates = ModelConfigBase._USING_CLASSIFY_API
sorted_by_match_speed = sorted(candidates, key=lambda cls: cls._MATCH_SPEED)
sorted_by_match_speed = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__))
mod = ModelOnDisk(model_path, hash_algo)

for config_cls in sorted_by_match_speed:
@@ -308,6 +220,9 @@ def cast_overrides(overrides: dict[str, Any]):
if "source_type" in overrides:
overrides["source_type"] = ModelSourceType(overrides["source_type"])

if "variant" in overrides:
overrides["variant"] = ModelVariantType(overrides["variant"])

@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
"""Creates an instance of this config or raises InvalidModelConfigException."""
@@ -367,6 +282,38 @@ class LoRAConfigBase(ABC, BaseModel):
type: Literal[ModelType.LoRA] = ModelType.LoRA
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)

@classmethod
def flux_lora_format(cls, mod: ModelOnDisk):
key = "FLUX_LORA_FORMAT"
if key in mod.cache:
return mod.cache[key]

from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict

sd = mod.load_state_dict(mod.path)
value = flux_format_from_state_dict(sd)
mod.cache[key] = value
return value

@classmethod
def base_model(cls, mod: ModelOnDisk) -> BaseModelType:
if cls.flux_lora_format(mod):
return BaseModelType.Flux

state_dict = mod.load_state_dict()
# If we've gotten here, we assume that the model is a Stable Diffusion model
token_vector_length = lora_token_vector_length(state_dict)
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024:
return BaseModelType.StableDiffusion2
elif token_vector_length == 1280:
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
elif token_vector_length == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException("Unknown LoRA type")


class T5EncoderConfigBase(ABC, BaseModel):
"""Base class for diffusers-style models."""
@@ -382,11 +329,40 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin,
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b


class LoRALyCORISConfig(LoRAConfigBase, LegacyProbeMixin, ModelConfigBase):
class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
"""Model config for LoRA/Lycoris models."""

format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS

@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.path.is_dir():
return False

# Avoid false positive match against ControlLoRA and Diffusers
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
return False

state_dict = mod.load_state_dict()
for key in state_dict.keys():
if type(key) is int:
continue

if key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")):
return True
# "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
if key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
return True

return False

@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
return {
"base": cls.base_model(mod),
}


class ControlAdapterConfigBase(ABC, BaseModel):
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
@@ -410,11 +386,26 @@ class ControlLoRADiffusersConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers


class LoRADiffusersConfig(LoRAConfigBase, LegacyProbeMixin, ModelConfigBase):
class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
"""Model config for LoRA/Diffusers models."""

format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.path.is_file():
return cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers

suffixes = ["bin", "safetensors"]
weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
return any(wf.exists() for wf in weight_files)

@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
return {
"base": cls.base_model(mod),
}


class VAECheckpointConfig(CheckpointConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for standalone VAE models."""
Expand Down Expand Up @@ -586,7 +577,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):

@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.layout == FSLayout.FILE:
if mod.path.is_file():
return False

config_path = mod.path / "config.json"
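Reviewer note: the body of the `classify` loop is collapsed in the diff above, so below is a minimal sketch of the flow it implements, written as if inside `config.py` (so `ModelConfigBase` and `InvalidModelConfigException` resolve). The construction step is an assumption inferred from the `matches`/`parse`/`from_model_on_disk` methods visible in this PR, not a copy of the merged body.

```python
from pathlib import Path

from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk


def classify_sketch(model_path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides):
    candidates = ModelConfigBase._USING_CLASSIFY_API
    # Fast matchers run first; the class-name tie-break added in this PR makes
    # the order deterministic when several classes share a _MATCH_SPEED.
    ordered = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__))
    mod = ModelOnDisk(model_path, hash_algo)
    for config_cls in ordered:
        if not config_cls.matches(mod):
            continue
        # from_model_on_disk merges cls.parse(mod) with the overrides and
        # validates; it raises InvalidModelConfigException on failure.
        return config_cls.from_model_on_disk(mod, **overrides)
    raise InvalidModelConfigException(f"Unable to classify model: {model_path}")
```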
96 changes: 96 additions & 0 deletions invokeai/backend/model_manager/model_on_disk.py
@@ -0,0 +1,96 @@
from pathlib import Path
from typing import Any, Optional, TypeAlias

import safetensors.torch
import torch
from picklescan.scanner import scan_file_path

from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.util.silence_warnings import SilenceWarnings

StateDict: TypeAlias = dict[str | int, Any] # When are the keys int?


class ModelOnDisk:
"""A utility class representing a model stored on disk."""

def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
self.path = path
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
self.name = path.stem
else:
self.name = path.name
self.hash_algo = hash_algo
# Having a cache helps users of ModelOnDisk (i.e. configs) to save state
# This prevents redundant computations during matching and parsing
self.cache = {"_CACHED_STATE_DICTS": {}}

def hash(self) -> str:
return ModelHash(algorithm=self.hash_algo).hash(self.path)

def size(self) -> int:
if self.path.is_file():
return self.path.stat().st_size
return sum(file.stat().st_size for file in self.path.rglob("*"))

def component_paths(self) -> set[Path]:
if self.path.is_file():
return {self.path}
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
return {f for f in self.path.rglob("*") if f.suffix in extensions}

def repo_variant(self) -> Optional[ModelRepoVariant]:
if self.path.is_file():
return None

weight_files = list(self.path.glob("**/*.safetensors"))
weight_files.extend(list(self.path.glob("**/*.bin")))
for x in weight_files:
if ".fp16" in x.suffixes:
return ModelRepoVariant.FP16
if "openvino_model" in x.name:
return ModelRepoVariant.OpenVINO
if "flax_model" in x.name:
return ModelRepoVariant.Flax
if x.suffix == ".onnx":
return ModelRepoVariant.ONNX
return ModelRepoVariant.Default

def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
sd_cache = self.cache["_CACHED_STATE_DICTS"]

if path in sd_cache:
return sd_cache[path]

if not path:
components = list(self.component_paths())
match components:
case []:
raise ValueError("No weight files found for this model")
case [p]:
path = p
case ps if len(ps) >= 2:
raise ValueError(
f"Multiple weight files found for this model: {ps}. "
f"Please specify the intended file using the 'path' argument"
)

with SilenceWarnings():
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
scan_result = scan_file_path(path)
if scan_result.infected_files != 0 or scan_result.scan_err:
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
checkpoint = torch.load(path, map_location="cpu")
assert isinstance(checkpoint, dict)
elif path.suffix.endswith(".gguf"):
checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
elif path.suffix.endswith(".safetensors"):
checkpoint = safetensors.torch.load_file(path)
else:
raise ValueError(f"Unrecognized model extension: {path.suffix}")

state_dict = checkpoint.get("state_dict", checkpoint)
sd_cache[path] = state_dict
return state_dict
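A quick usage sketch of the extracted `ModelOnDisk` helper (the path is hypothetical; `load_state_dict()` locates the weight file itself when the model has exactly one component):

```python
from pathlib import Path

from invokeai.backend.model_manager.model_on_disk import ModelOnDisk

mod = ModelOnDisk(Path("models/my_lora.safetensors"))  # hypothetical path
print(mod.name)    # "my_lora" -- recognized weight-file suffixes are stripped
print(mod.size())  # file size in bytes (a directory would sum all its files)

# Loads the checkpoint (pickle-based files are malware-scanned first) and
# caches it in mod.cache["_CACHED_STATE_DICTS"], so repeated calls are free.
state_dict = mod.load_state_dict()
```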
9 changes: 9 additions & 0 deletions invokeai/backend/model_manager/taxonomy.py
@@ -126,4 +126,13 @@ class ModelSourceType(str, Enum):
HFRepoID = "hf_repo_id"


class FluxLoRAFormat(str, Enum):
"""Flux LoRA formats."""

Diffusers = "flux.diffusers"
Kohya = "flux.kohya"
OneTrainer = "flux.onetrainer"
Control = "flux.control"


AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]
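The new enum feeds `LoRAConfigBase.flux_lora_format` above. A sketch of the relationship, assuming `flux_format_from_state_dict` returns `None` for non-FLUX state dicts (implied by the truthiness check in `base_model`) and using a hypothetical path:

```python
from pathlib import Path

from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict

sd = ModelOnDisk(Path("models/flux_lora.safetensors")).load_state_dict()  # hypothetical path
fmt = flux_format_from_state_dict(sd)
if fmt in (FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers):
    # LoRALyCORISConfig.matches() rejects these two formats so that the
    # ControlLoRA / Diffusers configs can claim the model instead.
    ...
```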