11import copy
2- from typing import List , Optional
2+ from time import sleep
3+ from typing import List , Optional , Literal , Dict
34
45from pydantic import BaseModel , Field
56
1314from invokeai .app .invocations .fields import FieldDescriptions , Input , InputField , OutputField , UIType
1415from invokeai .app .services .shared .invocation_context import InvocationContext
1516from invokeai .app .shared .models import FreeUConfig
16- from invokeai .backend .model_manager .config import AnyModelConfig , BaseModelType , ModelType , SubModelType
17+ from invokeai .app .services .model_records import ModelRecordChanges
18+ from invokeai .backend .model_manager .config import AnyModelConfig , BaseModelType , ModelType , SubModelType , ModelFormat
1719
1820
1921class ModelIdentifierField (BaseModel ):
@@ -62,7 +64,6 @@ class CLIPField(BaseModel):
6264
6365class TransformerField (BaseModel ):
6466 transformer : ModelIdentifierField = Field (description = "Info to load Transformer submodel" )
65- scheduler : ModelIdentifierField = Field (description = "Info to load scheduler submodel" )
6667
6768
6869class T5EncoderField (BaseModel ):
@@ -131,6 +132,30 @@ def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
131132
132133 return ModelIdentifierOutput (model = self .model )
133134
135+ T5_ENCODER_OPTIONS = Literal ["base" , "16b_quantized" , "8b_quantized" ]
136+ T5_ENCODER_MAP : Dict [str , Dict [str , str ]] = {
137+ "base" : {
138+ "text_encoder_repo" : "black-forest-labs/FLUX.1-schnell::text_encoder_2" ,
139+ "tokenizer_repo" : "black-forest-labs/FLUX.1-schnell::tokenizer_2" ,
140+ "text_encoder_name" : "FLUX.1-schnell_text_encoder_2" ,
141+ "tokenizer_name" : "FLUX.1-schnell_tokenizer_2" ,
142+ "format" : ModelFormat .T5Encoder ,
143+ },
144+ "8b_quantized" : {
145+ "text_encoder_repo" : "hf_repo1" ,
146+ "tokenizer_repo" : "hf_repo1" ,
147+ "text_encoder_name" : "hf_repo1" ,
148+ "tokenizer_name" : "hf_repo1" ,
149+ "format" : ModelFormat .T5Encoder8b ,
150+ },
151+ "4b_quantized" : {
152+ "text_encoder_repo" : "hf_repo2" ,
153+ "tokenizer_repo" : "hf_repo2" ,
154+ "text_encoder_name" : "hf_repo2" ,
155+ "tokenizer_name" : "hf_repo2" ,
156+ "format" : ModelFormat .T5Encoder8b ,
157+ },
158+ }
134159
135160@invocation_output ("flux_model_loader_output" )
136161class FluxModelLoaderOutput (BaseInvocationOutput ):
@@ -151,29 +176,55 @@ class FluxModelLoaderInvocation(BaseInvocation):
151176 ui_type = UIType .FluxMainModel ,
152177 input = Input .Direct ,
153178 )
179+
180+ t5_encoder : T5_ENCODER_OPTIONS = InputField (description = "The T5 Encoder model to use." )
154181
155182 def invoke (self , context : InvocationContext ) -> FluxModelLoaderOutput :
156183 model_key = self .model .key
157184
158- # TODO: not found exceptions
159185 if not context .models .exists (model_key ):
160186 raise Exception (f"Unknown model: { model_key } " )
161-
162- transformer = self .model .model_copy (update = {"submodel_type" : SubModelType .Transformer })
163- scheduler = self .model .model_copy (update = {"submodel_type" : SubModelType .Scheduler })
164- tokenizer = self .model .model_copy (update = {"submodel_type" : SubModelType .Tokenizer })
165- text_encoder = self .model .model_copy (update = {"submodel_type" : SubModelType .TextEncoder })
166- tokenizer2 = self .model .model_copy (update = {"submodel_type" : SubModelType .Tokenizer2 })
167- text_encoder2 = self .model .model_copy (update = {"submodel_type" : SubModelType .TextEncoder2 })
168- vae = self .model .model_copy (update = {"submodel_type" : SubModelType .VAE })
187+ transformer = self ._get_model (context , SubModelType .Transformer )
188+ tokenizer = self ._get_model (context , SubModelType .Tokenizer )
189+ tokenizer2 = self ._get_model (context , SubModelType .Tokenizer2 )
190+ clip_encoder = self ._get_model (context , SubModelType .TextEncoder )
191+ t5_encoder = self ._get_model (context , SubModelType .TextEncoder2 )
192+ vae = self ._install_model (context , SubModelType .VAE , "FLUX.1-schnell_ae" , "black-forest-labs/FLUX.1-schnell::ae.safetensors" , ModelFormat .Checkpoint , ModelType .VAE , BaseModelType .Flux )
169193
170194 return FluxModelLoaderOutput (
171- transformer = TransformerField (transformer = transformer , scheduler = scheduler ),
172- clip = CLIPField (tokenizer = tokenizer , text_encoder = text_encoder , loras = [], skipped_layers = 0 ),
173- t5Encoder = T5EncoderField (tokenizer = tokenizer2 , text_encoder = text_encoder2 ),
195+ transformer = TransformerField (transformer = transformer ),
196+ clip = CLIPField (tokenizer = tokenizer , text_encoder = clip_encoder , loras = [], skipped_layers = 0 ),
197+ t5Encoder = T5EncoderField (tokenizer = tokenizer2 , text_encoder = t5_encoder ),
174198 vae = VAEField (vae = vae ),
175199 )
176200
201+ def _get_model (self , context : InvocationContext , submodel :SubModelType ) -> ModelIdentifierField :
202+ match (submodel ):
203+ case SubModelType .Transformer :
204+ return self .model .model_copy (update = {"submodel_type" : SubModelType .Transformer })
205+ case submodel if submodel in [SubModelType .Tokenizer , SubModelType .TextEncoder ]:
206+ return self ._install_model (context , submodel , "clip-vit-large-patch14" , "openai/clip-vit-large-patch14" , ModelFormat .Diffusers , ModelType .CLIPEmbed , BaseModelType .Any )
207+ case SubModelType .TextEncoder2 :
208+ return self ._install_model (context , submodel , T5_ENCODER_MAP [self .t5_encoder ]["text_encoder_name" ], T5_ENCODER_MAP [self .t5_encoder ]["text_encoder_repo" ], ModelFormat (T5_ENCODER_MAP [self .t5_encoder ]["format" ]), ModelType .T5Encoder , BaseModelType .Any )
209+ case SubModelType .Tokenizer2 :
210+ return self ._install_model (context , submodel , T5_ENCODER_MAP [self .t5_encoder ]["tokenizer_name" ], T5_ENCODER_MAP [self .t5_encoder ]["tokenizer_repo" ], ModelFormat (T5_ENCODER_MAP [self .t5_encoder ]["format" ]), ModelType .T5Encoder , BaseModelType .Any )
211+ case _:
212+ raise Exception (f"{ submodel .value } is not a supported submodule for a flux model" )
213+
214+ def _install_model (self , context : InvocationContext , submodel :SubModelType , name : str , repo_id : str , format : ModelFormat , type : ModelType , base : BaseModelType ):
215+ if (models := context .models .search_by_attrs (name = name , base = base , type = type )):
216+ if len (models ) != 1 :
217+ raise Exception (f"Multiple models detected for selected model with name { name } " )
218+ return ModelIdentifierField .from_config (models [0 ]).model_copy (update = {"submodel_type" : submodel })
219+ else :
220+ model_path = context .models .download_and_cache_model (repo_id )
221+ config = ModelRecordChanges (name = name , base = base , type = type , format = format )
222+ model_install_job = context .models .import_local_model (model_path = model_path , config = config )
223+ while not model_install_job .in_terminal_state :
224+ sleep (0.01 )
225+ if not model_install_job .config_out :
226+ raise Exception (f"Failed to install { name } " )
227+ return ModelIdentifierField .from_config (model_install_job .config_out ).model_copy (update = {"submodel_type" : submodel })
177228
178229@invocation (
179230 "main_model_loader" ,
0 commit comments