Skip to content

Commit 133a7fd

Browse files
Model classification api (#7742)
## Summary The _goal_ of this PR is to make it easier to add an new config type. This _scope_ of this PR is to integrate the API and does not include adding new configs (outside tests) or porting existing ones. One of the glaring issues of the existing *legacy probe* is that the logic for each type is spread across multiple classes and intertwined with the other configs. This means that adding a new config type (or modifying an existing one) is complex and error prone. This PR attempts to remedy this by providing a new API for adding configs that: - Is backwards compatible with the existing probe. - Encapsulates fields and logic in a single class, keeping things self-contained and easy to modify safely. Below is a minimal toy example illustrating the proposed new structure: ```python class MinimalConfigExample(ModelConfigBase): type: ModelType = ModelType.Main format: ModelFormat = ModelFormat.Checkpoint fun_quote: str @classmethod def matches(cls, mod: ModelOnDisk) -> bool: return mod.path.suffix == ".json" @classmethod def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: with open(mod.path, "r") as f: contents = json.load(f) return { "fun_quote": contents["quote"], "base": BaseModelType.Any, } ``` To create a new config type, one needs to inherit from `ModelConfigBase` and implement its interface. The code falls back to the legacy model probe for existing models using the old API. This allows us to incrementally port the configs one by one. ## Related Issues / Discussions <!--WHEN APPLICABLE: List any related issues or discussions on github or discord. If this PR closes an issue, please use the "Closes #1234" format, so that the issue will be automatically closed when the PR merges.--> ## QA Instructions <!--WHEN APPLICABLE: Describe how you have tested the changes in this PR. Provide enough detail that a reviewer can reproduce your tests.--> ## Merge Plan <!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like DB schemas, may need some care when merging. For example, a careful rebase by the change author, timing to not interfere with a pending release, or a message to contributors on discord after merging.--> ## Checklist - [x] _The PR has a short but descriptive title, suitable for a changelog_ - [x] _Tests added / updated (if applicable)_ - [x] _Documentation added / updated (if applicable)_ - [ ] _Updated `What's New` copy (if doing a release after this PR)_
2 parents 1f86320 + e61c5a3 commit 133a7fd

File tree

10 files changed

+938
-800
lines changed

10 files changed

+938
-800
lines changed

invokeai/app/services/model_install/model_install_default.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,11 @@
3838
AnyModelConfig,
3939
CheckpointConfigBase,
4040
InvalidModelConfigException,
41+
ModelConfigBase,
4142
ModelRepoVariant,
4243
ModelSourceType,
4344
)
45+
from invokeai.backend.model_manager.legacy_probe import ModelProbe
4446
from invokeai.backend.model_manager.metadata import (
4547
AnyModelRepoMetadata,
4648
HuggingFaceMetadataFetch,
@@ -49,7 +51,6 @@
4951
RemoteModelFile,
5052
)
5153
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
52-
from invokeai.backend.model_manager.probe import ModelProbe
5354
from invokeai.backend.model_manager.search import ModelSearch
5455
from invokeai.backend.util import InvokeAILogger
5556
from invokeai.backend.util.catch_sigint import catch_sigint
@@ -182,9 +183,7 @@ def install_path(
182183
) -> str: # noqa D102
183184
model_path = Path(model_path)
184185
config = config or ModelRecordChanges()
185-
info: AnyModelConfig = ModelProbe.probe(
186-
Path(model_path), config.model_dump(), hash_algo=self._app_config.hashing_algorithm
187-
) # type: ignore
186+
info: AnyModelConfig = self._probe(Path(model_path), config) # type: ignore
188187

189188
if preferred_name := config.name:
190189
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
@@ -644,12 +643,22 @@ def _move_model(self, old_path: Path, new_path: Path) -> Path:
644643
move(old_path, new_path)
645644
return new_path
646645

646+
def _probe(self, model_path: Path, config: Optional[ModelRecordChanges] = None):
647+
config = config or ModelRecordChanges()
648+
hash_algo = self._app_config.hashing_algorithm
649+
fields = config.model_dump()
650+
651+
try:
652+
return ModelConfigBase.classify(model_path=model_path, hash_algo=hash_algo, **fields)
653+
except InvalidModelConfigException:
654+
return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore
655+
647656
def _register(
648657
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
649658
) -> str:
650659
config = config or ModelRecordChanges()
651660

652-
info = info or ModelProbe.probe(model_path, config.model_dump(), hash_algo=self._app_config.hashing_algorithm) # type: ignore
661+
info = info or self._probe(model_path, config)
653662

654663
model_path = model_path.resolve()
655664

invokeai/backend/model_manager/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
SchedulerPredictionType,
1414
SubModelType,
1515
)
16+
from invokeai.backend.model_manager.legacy_probe import ModelProbe
1617
from invokeai.backend.model_manager.load import LoadedModel
17-
from invokeai.backend.model_manager.probe import ModelProbe
1818
from invokeai.backend.model_manager.search import ModelSearch
1919

2020
__all__ = [

0 commit comments

Comments
 (0)