Port LoRA to new classification API #7849
Merged
Commits (11)
f251722 LoRA classification API (jazzhaiku)
323d409 Make ruff happy (jazzhaiku)
0d75c99 Caching (jazzhaiku)
c619348 Extract ModelOnDisk to its own module (jazzhaiku)
c276c1c Comment (jazzhaiku)
c25f6d1 Merge branch 'main' into lora-classification (jazzhaiku)
40c53ab Guard (jazzhaiku)
965753b Ruff formatting (jazzhaiku)
f6c2ee5 Merge branch 'main' into lora-classification (jazzhaiku)
9868c3b Merge branch 'main' into lora-classification (jazzhaiku)
b31c102 Merge branch 'main' into lora-classification (jazzhaiku)
@@ -0,0 +1,96 @@
from pathlib import Path
from typing import Any, Optional, TypeAlias

import safetensors.torch
import torch
from picklescan.scanner import scan_file_path

from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.util.silence_warnings import SilenceWarnings

StateDict: TypeAlias = dict[str | int, Any]  # When are the keys int?


class ModelOnDisk:
    """A utility class representing a model stored on disk."""

    def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
        self.path = path
        if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
            self.name = path.stem
        else:
            self.name = path.name
        self.hash_algo = hash_algo
        # Having a cache helps users of ModelOnDisk (i.e. configs) to save state
        # This prevents redundant computations during matching and parsing
        self.cache = {"_CACHED_STATE_DICTS": {}}

    def hash(self) -> str:
        return ModelHash(algorithm=self.hash_algo).hash(self.path)

    def size(self) -> int:
        if self.path.is_file():
            return self.path.stat().st_size
        return sum(file.stat().st_size for file in self.path.rglob("*"))

    def component_paths(self) -> set[Path]:
        if self.path.is_file():
            return {self.path}
        extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
        return {f for f in self.path.rglob("*") if f.suffix in extensions}

    def repo_variant(self) -> Optional[ModelRepoVariant]:
        if self.path.is_file():
            return None

        weight_files = list(self.path.glob("**/*.safetensors"))
        weight_files.extend(list(self.path.glob("**/*.bin")))
        for x in weight_files:
            if ".fp16" in x.suffixes:
                return ModelRepoVariant.FP16
            if "openvino_model" in x.name:
                return ModelRepoVariant.OpenVINO
            if "flax_model" in x.name:
                return ModelRepoVariant.Flax
            if x.suffix == ".onnx":
                return ModelRepoVariant.ONNX
        return ModelRepoVariant.Default

    def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
        sd_cache = self.cache["_CACHED_STATE_DICTS"]

        if path in sd_cache:
            return sd_cache[path]

        if not path:
            components = list(self.component_paths())
            match components:
                case []:
                    raise ValueError("No weight files found for this model")
                case [p]:
                    path = p
                case ps if len(ps) >= 2:
                    raise ValueError(
                        f"Multiple weight files found for this model: {ps}. "
                        f"Please specify the intended file using the 'path' argument"
                    )

        with SilenceWarnings():
            if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
                scan_result = scan_file_path(path)
                if scan_result.infected_files != 0 or scan_result.scan_err:
                    raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
                checkpoint = torch.load(path, map_location="cpu")
                assert isinstance(checkpoint, dict)
            elif path.suffix.endswith(".gguf"):
                checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
            elif path.suffix.endswith(".safetensors"):
                checkpoint = safetensors.torch.load_file(path)
            else:
                raise ValueError(f"Unrecognized model extension: {path.suffix}")

        state_dict = checkpoint.get("state_dict", checkpoint)
        sd_cache[path] = state_dict
        return state_dict
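
For context, here is a minimal usage sketch of the new ModelOnDisk helper. The module import path, the model file path, and the LoRA key check are illustrative assumptions and are not part of this diff.

# Assumed import path for the extracted module; adjust to wherever ModelOnDisk lives in the tree.
from pathlib import Path
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk

# Hypothetical single-file LoRA checkpoint -- substitute a real path on your system.
lora_path = Path("/models/loras/example_lora.safetensors")

mod = ModelOnDisk(lora_path)
print(mod.name)            # "example_lora" -- stem is used because the suffix is a known weight extension
print(mod.size())          # file size in bytes
print(mod.repo_variant())  # None for single-file models

# The first call loads the file; the second is served from the instance's state-dict cache.
sd = mod.load_state_dict(lora_path)
sd_again = mod.load_state_dict(lora_path)
assert sd is sd_again

# Illustrative check for LoRA-style keys; the exact naming convention is an assumption here.
print(any("lora_down" in str(k) or "lora_up" in str(k) for k in sd))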