@@ -111,6 +111,7 @@ class ModelFormat(str, Enum):
111111 T5Encoder = "t5_encoder"
112112 T5Encoder8b = "t5_encoder_8b"
113113 T5Encoder4b = "t5_encoder_4b"
114+ BnbQuantizednf4b = "bnb_quantized_nf4b"
114115
115116
116117class SchedulerPredictionType (str , Enum ):
@@ -193,7 +194,7 @@ def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> N
193194class CheckpointConfigBase (ModelConfigBase ):
194195 """Model config for checkpoint-style models."""
195196
196- format : Literal [ModelFormat .Checkpoint ] = ModelFormat .Checkpoint
197+ format : Literal [ModelFormat .Checkpoint , ModelFormat . BnbQuantizednf4b ] = Field ( description = "Format of the provided checkpoint model" , default = ModelFormat .Checkpoint )
197198 config_path : str = Field (description = "path to the checkpoint model config file" )
198199 converted_at : Optional [float ] = Field (
199200 description = "When this model was last converted to diffusers" , default_factory = time .time
@@ -248,7 +249,6 @@ class VAECheckpointConfig(CheckpointConfigBase):
248249 """Model config for standalone VAE models."""
249250
250251 type : Literal [ModelType .VAE ] = ModelType .VAE
251- format : Literal [ModelFormat .Checkpoint ] = ModelFormat .Checkpoint
252252
253253 @staticmethod
254254 def get_tag () -> Tag :
@@ -287,7 +287,6 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase)
287287 """Model config for ControlNet models (diffusers version)."""
288288
289289 type : Literal [ModelType .ControlNet ] = ModelType .ControlNet
290- format : Literal [ModelFormat .Checkpoint ] = ModelFormat .Checkpoint
291290
292291 @staticmethod
293292 def get_tag () -> Tag :
@@ -336,6 +335,21 @@ def get_tag() -> Tag:
336335 return Tag (f"{ ModelType .Main .value } .{ ModelFormat .Checkpoint .value } " )
337336
338337
338+ class MainBnbQuantized4bCheckpointConfig (CheckpointConfigBase , MainConfigBase ):
339+ """Model config for main checkpoint models."""
340+
341+ prediction_type : SchedulerPredictionType = SchedulerPredictionType .Epsilon
342+ upcast_attention : bool = False
343+
344+ def __init__ (self , * args , ** kwargs ):
345+ super ().__init__ (* args , ** kwargs )
346+ self .format = ModelFormat .BnbQuantizednf4b
347+
348+ @staticmethod
349+ def get_tag () -> Tag :
350+ return Tag (f"{ ModelType .Main .value } .{ ModelFormat .BnbQuantizednf4b .value } " )
351+
352+
339353class MainDiffusersConfig (DiffusersConfigBase , MainConfigBase ):
340354 """Model config for main diffusers models."""
341355
@@ -438,6 +452,7 @@ def get_model_discriminator_value(v: Any) -> str:
438452 Union [
439453 Annotated [MainDiffusersConfig , MainDiffusersConfig .get_tag ()],
440454 Annotated [MainCheckpointConfig , MainCheckpointConfig .get_tag ()],
455+ Annotated [MainBnbQuantized4bCheckpointConfig , MainBnbQuantized4bCheckpointConfig .get_tag ()],
441456 Annotated [VAEDiffusersConfig , VAEDiffusersConfig .get_tag ()],
442457 Annotated [VAECheckpointConfig , VAECheckpointConfig .get_tag ()],
443458 Annotated [ControlNetDiffusersConfig , ControlNetDiffusersConfig .get_tag ()],
0 commit comments