1
1
import copy
2
2
from time import sleep
3
- from typing import List , Optional , Literal , Dict
3
+ from typing import Dict , List , Literal , Optional
4
4
5
5
from pydantic import BaseModel , Field
6
6
12
12
invocation_output ,
13
13
)
14
14
from invokeai .app .invocations .fields import FieldDescriptions , Input , InputField , OutputField , UIType
15
+ from invokeai .app .services .model_records import ModelRecordChanges
15
16
from invokeai .app .services .shared .invocation_context import InvocationContext
16
17
from invokeai .app .shared .models import FreeUConfig
17
- from invokeai .app .services .model_records import ModelRecordChanges
18
- from invokeai .backend .model_manager .config import AnyModelConfig , BaseModelType , ModelType , SubModelType , ModelFormat
18
+ from invokeai .backend .model_manager .config import AnyModelConfig , BaseModelType , ModelFormat , ModelType , SubModelType
19
19
20
20
21
21
class ModelIdentifierField (BaseModel ):
@@ -132,31 +132,22 @@ def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
132
132
133
133
return ModelIdentifierOutput (model = self .model )
134
134
135
# Options the user may pick for the T5 text encoder used with FLUX models.
T5_ENCODER_OPTIONS = Literal["base", "8b_quantized"]

# Maps each T5 encoder option to the repo to download it from, the local model
# name it is registered under, and the model format recorded on install.
T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
    "base": {
        "repo": "invokeai/flux_dev::t5_xxl_encoder/base",
        "name": "t5_base_encoder",
        "format": ModelFormat.T5Encoder,
    },
    "8b_quantized": {
        "repo": "invokeai/flux_dev::t5_xxl_encoder/8b_quantized",
        "name": "t5_8b_quantized_encoder",
        "format": ModelFormat.T5Encoder,
    },
}
159
149
150
+
160
151
@invocation_output ("flux_model_loader_output" )
161
152
class FluxModelLoaderOutput (BaseInvocationOutput ):
162
153
"""Flux base model loader output"""
@@ -176,7 +167,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
176
167
ui_type = UIType .FluxMainModel ,
177
168
input = Input .Direct ,
178
169
)
179
-
170
+
180
171
t5_encoder : T5_ENCODER_OPTIONS = InputField (description = "The T5 Encoder model to use." )
181
172
182
173
def invoke (self , context : InvocationContext ) -> FluxModelLoaderOutput :
@@ -189,7 +180,15 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
189
180
tokenizer2 = self ._get_model (context , SubModelType .Tokenizer2 )
190
181
clip_encoder = self ._get_model (context , SubModelType .TextEncoder )
191
182
t5_encoder = self ._get_model (context , SubModelType .TextEncoder2 )
192
- vae = self ._install_model (context , SubModelType .VAE , "FLUX.1-schnell_ae" , "black-forest-labs/FLUX.1-schnell::ae.safetensors" , ModelFormat .Checkpoint , ModelType .VAE , BaseModelType .Flux )
183
+ vae = self ._install_model (
184
+ context ,
185
+ SubModelType .VAE ,
186
+ "FLUX.1-schnell_ae" ,
187
+ "black-forest-labs/FLUX.1-schnell::ae.safetensors" ,
188
+ ModelFormat .Checkpoint ,
189
+ ModelType .VAE ,
190
+ BaseModelType .Flux ,
191
+ )
193
192
194
193
return FluxModelLoaderOutput (
195
194
transformer = TransformerField (transformer = transformer ),
@@ -198,33 +197,59 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
198
197
vae = VAEField (vae = vae ),
199
198
)
200
199
201
- def _get_model (self , context : InvocationContext , submodel :SubModelType ) -> ModelIdentifierField :
202
- match ( submodel ) :
200
+ def _get_model (self , context : InvocationContext , submodel : SubModelType ) -> ModelIdentifierField :
201
+ match submodel :
203
202
case SubModelType .Transformer :
204
203
return self .model .model_copy (update = {"submodel_type" : SubModelType .Transformer })
205
204
case submodel if submodel in [SubModelType .Tokenizer , SubModelType .TextEncoder ]:
206
- return self ._install_model (context , submodel , "clip-vit-large-patch14" , "openai/clip-vit-large-patch14" , ModelFormat .Diffusers , ModelType .CLIPEmbed , BaseModelType .Any )
207
- case SubModelType .TextEncoder2 :
208
- return self ._install_model (context , submodel , T5_ENCODER_MAP [self .t5_encoder ]["text_encoder_name" ], T5_ENCODER_MAP [self .t5_encoder ]["text_encoder_repo" ], ModelFormat (T5_ENCODER_MAP [self .t5_encoder ]["format" ]), ModelType .T5Encoder , BaseModelType .Any )
209
- case SubModelType .Tokenizer2 :
210
- return self ._install_model (context , submodel , T5_ENCODER_MAP [self .t5_encoder ]["tokenizer_name" ], T5_ENCODER_MAP [self .t5_encoder ]["tokenizer_repo" ], ModelFormat (T5_ENCODER_MAP [self .t5_encoder ]["format" ]), ModelType .T5Encoder , BaseModelType .Any )
205
+ return self ._install_model (
206
+ context ,
207
+ submodel ,
208
+ "clip-vit-large-patch14" ,
209
+ "openai/clip-vit-large-patch14" ,
210
+ ModelFormat .Diffusers ,
211
+ ModelType .CLIPEmbed ,
212
+ BaseModelType .Any ,
213
+ )
214
+ case submodel if submodel in [SubModelType .Tokenizer2 , SubModelType .TextEncoder2 ]:
215
+ return self ._install_model (
216
+ context ,
217
+ submodel ,
218
+ T5_ENCODER_MAP [self .t5_encoder ]["name" ],
219
+ T5_ENCODER_MAP [self .t5_encoder ]["repo" ],
220
+ ModelFormat (T5_ENCODER_MAP [self .t5_encoder ]["format" ]),
221
+ ModelType .T5Encoder ,
222
+ BaseModelType .Any ,
223
+ )
211
224
case _:
212
- raise Exception (f"{ submodel .value } is not a supported submodule for a flux model" )
213
-
214
- def _install_model (self , context : InvocationContext , submodel :SubModelType , name : str , repo_id : str , format : ModelFormat , type : ModelType , base : BaseModelType ):
215
- if (models := context .models .search_by_attrs (name = name , base = base , type = type )):
225
+ raise Exception (f"{ submodel .value } is not a supported submodule for a flux model" )
226
+
227
+ def _install_model (
228
+ self ,
229
+ context : InvocationContext ,
230
+ submodel : SubModelType ,
231
+ name : str ,
232
+ repo_id : str ,
233
+ format : ModelFormat ,
234
+ type : ModelType ,
235
+ base : BaseModelType ,
236
+ ):
237
+ if models := context .models .search_by_attrs (name = name , base = base , type = type ):
216
238
if len (models ) != 1 :
217
239
raise Exception (f"Multiple models detected for selected model with name { name } " )
218
240
return ModelIdentifierField .from_config (models [0 ]).model_copy (update = {"submodel_type" : submodel })
219
241
else :
220
242
model_path = context .models .download_and_cache_model (repo_id )
221
- config = ModelRecordChanges (name = name , base = base , type = type , format = format )
243
+ config = ModelRecordChanges (name = name , base = base , type = type , format = format )
222
244
model_install_job = context .models .import_local_model (model_path = model_path , config = config )
223
245
while not model_install_job .in_terminal_state :
224
246
sleep (0.01 )
225
247
if not model_install_job .config_out :
226
248
raise Exception (f"Failed to install { name } " )
227
- return ModelIdentifierField .from_config (model_install_job .config_out ).model_copy (update = {"submodel_type" : submodel })
249
+ return ModelIdentifierField .from_config (model_install_job .config_out ).model_copy (
250
+ update = {"submodel_type" : submodel }
251
+ )
252
+
228
253
229
254
@invocation (
230
255
"main_model_loader" ,
0 commit comments