|
38 | 38 | AnyModelConfig,
|
39 | 39 | CheckpointConfigBase,
|
40 | 40 | InvalidModelConfigException,
|
| 41 | + ModelConfigBase, |
41 | 42 | ModelRepoVariant,
|
42 | 43 | ModelSourceType,
|
43 | 44 | )
|
| 45 | +from invokeai.backend.model_manager.legacy_probe import ModelProbe |
44 | 46 | from invokeai.backend.model_manager.metadata import (
|
45 | 47 | AnyModelRepoMetadata,
|
46 | 48 | HuggingFaceMetadataFetch,
|
|
49 | 51 | RemoteModelFile,
|
50 | 52 | )
|
51 | 53 | from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
|
52 |
| -from invokeai.backend.model_manager.probe import ModelProbe |
53 | 54 | from invokeai.backend.model_manager.search import ModelSearch
|
54 | 55 | from invokeai.backend.util import InvokeAILogger
|
55 | 56 | from invokeai.backend.util.catch_sigint import catch_sigint
|
@@ -182,9 +183,7 @@ def install_path(
|
182 | 183 | ) -> str: # noqa D102
|
183 | 184 | model_path = Path(model_path)
|
184 | 185 | config = config or ModelRecordChanges()
|
185 |
| - info: AnyModelConfig = ModelProbe.probe( |
186 |
| - Path(model_path), config.model_dump(), hash_algo=self._app_config.hashing_algorithm |
187 |
| - ) # type: ignore |
| 186 | + info: AnyModelConfig = self._probe(Path(model_path), config) # type: ignore |
188 | 187 |
|
189 | 188 | if preferred_name := config.name:
|
190 | 189 | preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
@@ -644,12 +643,22 @@ def _move_model(self, old_path: Path, new_path: Path) -> Path:
|
644 | 643 | move(old_path, new_path)
|
645 | 644 | return new_path
|
646 | 645 |
|
| 646 | + def _probe(self, model_path: Path, config: Optional[ModelRecordChanges] = None): |
| 647 | + config = config or ModelRecordChanges() |
| 648 | + hash_algo = self._app_config.hashing_algorithm |
| 649 | + fields = config.model_dump() |
| 650 | + |
| 651 | + try: |
| 652 | + return ModelConfigBase.classify(model_path=model_path, hash_algo=hash_algo, **fields) |
| 653 | + except InvalidModelConfigException: |
| 654 | + return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore |
| 655 | + |
647 | 656 | def _register(
|
648 | 657 | self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
|
649 | 658 | ) -> str:
|
650 | 659 | config = config or ModelRecordChanges()
|
651 | 660 |
|
652 |
| - info = info or ModelProbe.probe(model_path, config.model_dump(), hash_algo=self._app_config.hashing_algorithm) # type: ignore |
| 661 | + info = info or self._probe(model_path, config) |
653 | 662 |
|
654 | 663 | model_path = model_path.resolve()
|
655 | 664 |
|
|
0 commit comments