11import copy
22from time import sleep
3- from typing import List , Optional , Literal , Dict
3+ from typing import Dict , List , Literal , Optional
44
55from pydantic import BaseModel , Field
66
1212 invocation_output ,
1313)
1414from invokeai .app .invocations .fields import FieldDescriptions , Input , InputField , OutputField , UIType
15+ from invokeai .app .services .model_records import ModelRecordChanges
1516from invokeai .app .services .shared .invocation_context import InvocationContext
1617from invokeai .app .shared .models import FreeUConfig
17- from invokeai .app .services .model_records import ModelRecordChanges
18- from invokeai .backend .model_manager .config import AnyModelConfig , BaseModelType , ModelType , SubModelType , ModelFormat
18+ from invokeai .backend .model_manager .config import AnyModelConfig , BaseModelType , ModelFormat , ModelType , SubModelType
1919
2020
2121class ModelIdentifierField (BaseModel ):
@@ -132,31 +132,22 @@ def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
132132
133133 return ModelIdentifierOutput (model = self .model )
134134
135- T5_ENCODER_OPTIONS = Literal ["base" , "16b_quantized" , "8b_quantized" ]
135+
136+ T5_ENCODER_OPTIONS = Literal ["base" , "8b_quantized" ]
136137T5_ENCODER_MAP : Dict [str , Dict [str , str ]] = {
137138 "base" : {
138- "text_encoder_repo" : "black-forest-labs/FLUX.1-schnell::text_encoder_2" ,
139- "tokenizer_repo" : "black-forest-labs/FLUX.1-schnell::tokenizer_2" ,
140- "text_encoder_name" : "FLUX.1-schnell_text_encoder_2" ,
141- "tokenizer_name" : "FLUX.1-schnell_tokenizer_2" ,
139+ "repo" : "invokeai/flux_dev::t5_xxl_encoder/base" ,
140+ "name" : "t5_base_encoder" ,
142141 "format" : ModelFormat .T5Encoder ,
143142 },
144143 "8b_quantized" : {
145- "text_encoder_repo" : "hf_repo1" ,
146- "tokenizer_repo" : "hf_repo1" ,
147- "text_encoder_name" : "hf_repo1" ,
148- "tokenizer_name" : "hf_repo1" ,
149- "format" : ModelFormat .T5Encoder8b ,
150- },
151- "4b_quantized" : {
152- "text_encoder_repo" : "hf_repo2" ,
153- "tokenizer_repo" : "hf_repo2" ,
154- "text_encoder_name" : "hf_repo2" ,
155- "tokenizer_name" : "hf_repo2" ,
156- "format" : ModelFormat .T5Encoder8b ,
144+ "repo" : "invokeai/flux_dev::t5_xxl_encoder/8b_quantized" ,
145+ "name" : "t5_8b_quantized_encoder" ,
146+ "format" : ModelFormat .T5Encoder ,
157147 },
158148}
159149
150+
160151@invocation_output ("flux_model_loader_output" )
161152class FluxModelLoaderOutput (BaseInvocationOutput ):
162153 """Flux base model loader output"""
@@ -176,7 +167,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
176167 ui_type = UIType .FluxMainModel ,
177168 input = Input .Direct ,
178169 )
179-
170+
180171 t5_encoder : T5_ENCODER_OPTIONS = InputField (description = "The T5 Encoder model to use." )
181172
182173 def invoke (self , context : InvocationContext ) -> FluxModelLoaderOutput :
@@ -189,7 +180,15 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
189180 tokenizer2 = self ._get_model (context , SubModelType .Tokenizer2 )
190181 clip_encoder = self ._get_model (context , SubModelType .TextEncoder )
191182 t5_encoder = self ._get_model (context , SubModelType .TextEncoder2 )
192- vae = self ._install_model (context , SubModelType .VAE , "FLUX.1-schnell_ae" , "black-forest-labs/FLUX.1-schnell::ae.safetensors" , ModelFormat .Checkpoint , ModelType .VAE , BaseModelType .Flux )
183+ vae = self ._install_model (
184+ context ,
185+ SubModelType .VAE ,
186+ "FLUX.1-schnell_ae" ,
187+ "black-forest-labs/FLUX.1-schnell::ae.safetensors" ,
188+ ModelFormat .Checkpoint ,
189+ ModelType .VAE ,
190+ BaseModelType .Flux ,
191+ )
193192
194193 return FluxModelLoaderOutput (
195194 transformer = TransformerField (transformer = transformer ),
@@ -198,33 +197,59 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
198197 vae = VAEField (vae = vae ),
199198 )
200199
201- def _get_model (self , context : InvocationContext , submodel :SubModelType ) -> ModelIdentifierField :
202- match ( submodel ) :
200+ def _get_model (self , context : InvocationContext , submodel : SubModelType ) -> ModelIdentifierField :
201+ match submodel :
203202 case SubModelType .Transformer :
204203 return self .model .model_copy (update = {"submodel_type" : SubModelType .Transformer })
205204 case submodel if submodel in [SubModelType .Tokenizer , SubModelType .TextEncoder ]:
206- return self ._install_model (context , submodel , "clip-vit-large-patch14" , "openai/clip-vit-large-patch14" , ModelFormat .Diffusers , ModelType .CLIPEmbed , BaseModelType .Any )
207- case SubModelType .TextEncoder2 :
208- return self ._install_model (context , submodel , T5_ENCODER_MAP [self .t5_encoder ]["text_encoder_name" ], T5_ENCODER_MAP [self .t5_encoder ]["text_encoder_repo" ], ModelFormat (T5_ENCODER_MAP [self .t5_encoder ]["format" ]), ModelType .T5Encoder , BaseModelType .Any )
209- case SubModelType .Tokenizer2 :
210- return self ._install_model (context , submodel , T5_ENCODER_MAP [self .t5_encoder ]["tokenizer_name" ], T5_ENCODER_MAP [self .t5_encoder ]["tokenizer_repo" ], ModelFormat (T5_ENCODER_MAP [self .t5_encoder ]["format" ]), ModelType .T5Encoder , BaseModelType .Any )
205+ return self ._install_model (
206+ context ,
207+ submodel ,
208+ "clip-vit-large-patch14" ,
209+ "openai/clip-vit-large-patch14" ,
210+ ModelFormat .Diffusers ,
211+ ModelType .CLIPEmbed ,
212+ BaseModelType .Any ,
213+ )
214+ case submodel if submodel in [SubModelType .Tokenizer2 , SubModelType .TextEncoder2 ]:
215+ return self ._install_model (
216+ context ,
217+ submodel ,
218+ T5_ENCODER_MAP [self .t5_encoder ]["name" ],
219+ T5_ENCODER_MAP [self .t5_encoder ]["repo" ],
220+ ModelFormat (T5_ENCODER_MAP [self .t5_encoder ]["format" ]),
221+ ModelType .T5Encoder ,
222+ BaseModelType .Any ,
223+ )
211224 case _:
212- raise Exception (f"{ submodel .value } is not a supported submodule for a flux model" )
213-
214- def _install_model (self , context : InvocationContext , submodel :SubModelType , name : str , repo_id : str , format : ModelFormat , type : ModelType , base : BaseModelType ):
215- if (models := context .models .search_by_attrs (name = name , base = base , type = type )):
225+ raise Exception (f"{ submodel .value } is not a supported submodule for a flux model" )
226+
227+ def _install_model (
228+ self ,
229+ context : InvocationContext ,
230+ submodel : SubModelType ,
231+ name : str ,
232+ repo_id : str ,
233+ format : ModelFormat ,
234+ type : ModelType ,
235+ base : BaseModelType ,
236+ ):
237+ if models := context .models .search_by_attrs (name = name , base = base , type = type ):
216238 if len (models ) != 1 :
217239 raise Exception (f"Multiple models detected for selected model with name { name } " )
218240 return ModelIdentifierField .from_config (models [0 ]).model_copy (update = {"submodel_type" : submodel })
219241 else :
220242 model_path = context .models .download_and_cache_model (repo_id )
221- config = ModelRecordChanges (name = name , base = base , type = type , format = format )
243+ config = ModelRecordChanges (name = name , base = base , type = type , format = format )
222244 model_install_job = context .models .import_local_model (model_path = model_path , config = config )
223245 while not model_install_job .in_terminal_state :
224246 sleep (0.01 )
225247 if not model_install_job .config_out :
226248 raise Exception (f"Failed to install { name } " )
227- return ModelIdentifierField .from_config (model_install_job .config_out ).model_copy (update = {"submodel_type" : submodel })
249+ return ModelIdentifierField .from_config (model_install_job .config_out ).model_copy (
250+ update = {"submodel_type" : submodel }
251+ )
252+
228253
229254@invocation (
230255 "main_model_loader" ,
0 commit comments