@@ -406,16 +406,94 @@ def base_model(cls, mod: ModelOnDisk) -> BaseModelType:
406406class T5EncoderConfigBase (ABC , BaseModel ):
407407 """Base class for diffusers-style models."""
408408
409+ base : Literal [BaseModelType .Any ] = BaseModelType .Any
409410 type : Literal [ModelType .T5Encoder ] = ModelType .T5Encoder
410411
412+ @classmethod
413+ def get_config (cls , mod : ModelOnDisk ) -> dict [str , Any ]:
414+ path = mod .path / "text_encoder_2" / "config.json"
415+ with open (path , "r" ) as file :
416+ return json .load (file )
417+
418+ @classmethod
419+ def parse (cls , mod : ModelOnDisk ) -> dict [str , Any ]:
420+ return {}
421+
411422
412- class T5EncoderConfig (T5EncoderConfigBase , LegacyProbeMixin , ModelConfigBase ):
423+ class T5EncoderConfig (T5EncoderConfigBase , ModelConfigBase ):
413424 format : Literal [ModelFormat .T5Encoder ] = ModelFormat .T5Encoder
414425
426+ @classmethod
427+ def matches (cls , mod : ModelOnDisk , ** overrides ) -> MatchCertainty :
428+ is_t5_type_override = overrides .get ("type" ) is ModelType .T5Encoder
429+ is_t5_format_override = overrides .get ("format" ) is ModelFormat .T5Encoder
430+
431+ if is_t5_type_override and is_t5_format_override :
432+ return MatchCertainty .OVERRIDE
433+
434+ if mod .path .is_file ():
435+ return MatchCertainty .NEVER
436+
437+ model_dir = mod .path / "text_encoder_2"
438+
439+ if not model_dir .exists ():
440+ return MatchCertainty .NEVER
441+
442+ try :
443+ config = cls .get_config (mod )
444+
445+ is_t5_encoder_model = get_class_name_from_config (config ) == "T5EncoderModel"
446+ is_t5_format = (model_dir / "model.safetensors.index.json" ).exists ()
415447
416- class T5EncoderBnbQuantizedLlmInt8bConfig (T5EncoderConfigBase , LegacyProbeMixin , ModelConfigBase ):
448+ if is_t5_encoder_model and is_t5_format :
449+ return MatchCertainty .EXACT
450+ except Exception :
451+ pass
452+
453+ return MatchCertainty .NEVER
454+
455+
456+ class T5EncoderBnbQuantizedLlmInt8bConfig (T5EncoderConfigBase , ModelConfigBase ):
417457 format : Literal [ModelFormat .BnbQuantizedLlmInt8b ] = ModelFormat .BnbQuantizedLlmInt8b
418458
459+ @classmethod
460+ def matches (cls , mod : ModelOnDisk , ** overrides ) -> MatchCertainty :
461+ is_t5_type_override = overrides .get ("type" ) is ModelType .T5Encoder
462+ is_bnb_format_override = overrides .get ("format" ) is ModelFormat .BnbQuantizedLlmInt8b
463+
464+ if is_t5_type_override and is_bnb_format_override :
465+ return MatchCertainty .OVERRIDE
466+
467+ if mod .path .is_file ():
468+ return MatchCertainty .NEVER
469+
470+ model_dir = mod .path / "text_encoder_2"
471+
472+ if not model_dir .exists ():
473+ return MatchCertainty .NEVER
474+
475+ try :
476+ config = cls .get_config (mod )
477+
478+ is_t5_encoder_model = get_class_name_from_config (config ) == "T5EncoderModel"
479+
480+ # Heuristic: look for the quantization in the name
481+ files = model_dir .glob ("*.safetensors" )
482+ filename_looks_like_bnb = any (x for x in files if "llm_int8" in x .as_posix ())
483+
484+ if is_t5_encoder_model and filename_looks_like_bnb :
485+ return MatchCertainty .EXACT
486+
487+ # Heuristic: Look for the presence of "SCB" in state dict keys (typically a suffix)
488+ has_scb_key = mod .has_keys_ending_with ("SCB" )
489+
490+ if is_t5_encoder_model and has_scb_key :
491+ return MatchCertainty .EXACT
492+ except Exception :
493+ pass
494+
495+ return MatchCertainty .NEVER
496+
419497
420498class LoRAOmiConfig (LoRAConfigBase , ModelConfigBase ):
421499 format : Literal [ModelFormat .OMI ] = ModelFormat .OMI
0 commit comments