@@ -1103,39 +1103,83 @@ struct FluxCLIPEmbedder : public Conditioner {
     std::shared_ptr<T5Runner> t5;
     size_t chunk_len = 256;
 
+    bool use_clip_l = false;
+    bool use_t5 = false;
+
     FluxCLIPEmbedder(ggml_backend_t backend,
                      std::map<std::string, enum ggml_type>& tensor_types,
                      int clip_skip = -1) {
-        clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true);
-        t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
+
+        for (auto pair : tensor_types) {
+            if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
+                use_clip_l = true;
+            } else if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
+                use_t5 = true;
+            }
+        }
+
+        if (!use_clip_l && !use_t5) {
+            LOG_WARN("IMPORTANT NOTICE: No text encoders provided, cannot process prompts!");
+            return;
+        }
+
+        if (use_clip_l) {
+            clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true);
+        } else {
+            LOG_WARN("clip_l text encoder not found! Prompt adherence might be degraded.");
+        }
+        if (use_t5) {
+            t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
+        } else {
+            LOG_WARN("t5xxl text encoder not found! Prompt adherence might be degraded.");
+        }
         set_clip_skip(clip_skip);
     }
 
     void set_clip_skip(int clip_skip) {
         if (clip_skip <= 0) {
             clip_skip = 2;
         }
-        clip_l->set_clip_skip(clip_skip);
+        if (use_clip_l) {
+            clip_l->set_clip_skip(clip_skip);
+        }
     }
 
     void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
-        clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
-        t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
+        if (use_clip_l) {
+            clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
+        }
+        if (use_t5) {
+            t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
+        }
     }
 
     void alloc_params_buffer() {
-        clip_l->alloc_params_buffer();
-        t5->alloc_params_buffer();
+        if (use_clip_l) {
+            clip_l->alloc_params_buffer();
+        }
+        if (use_t5) {
+            t5->alloc_params_buffer();
+        }
     }
 
     void free_params_buffer() {
-        clip_l->free_params_buffer();
-        t5->free_params_buffer();
+        if (use_clip_l) {
+            clip_l->free_params_buffer();
+        }
+        if (use_t5) {
+            t5->free_params_buffer();
+        }
     }
 
     size_t get_params_buffer_size() {
-        size_t buffer_size = clip_l->get_params_buffer_size();
-        buffer_size += t5->get_params_buffer_size();
+        size_t buffer_size = 0;
+        if (use_clip_l) {
+            buffer_size += clip_l->get_params_buffer_size();
+        }
+        if (use_t5) {
+            buffer_size += t5->get_params_buffer_size();
+        }
         return buffer_size;
     }
 
@@ -1165,18 +1209,23 @@ struct FluxCLIPEmbedder : public Conditioner {
         for (const auto& item : parsed_attention) {
             const std::string& curr_text = item.first;
             float curr_weight = item.second;
-
-            std::vector<int> curr_tokens = clip_l_tokenizer.encode(curr_text, on_new_token_cb);
-            clip_l_tokens.insert(clip_l_tokens.end(), curr_tokens.begin(), curr_tokens.end());
-            clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
-
-            curr_tokens = t5_tokenizer.Encode(curr_text, true);
-            t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
-            t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
+            if (use_clip_l) {
+                std::vector<int> curr_tokens = clip_l_tokenizer.encode(curr_text, on_new_token_cb);
+                clip_l_tokens.insert(clip_l_tokens.end(), curr_tokens.begin(), curr_tokens.end());
+                clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
+            }
+            if (use_t5) {
+                std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
+                t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
+                t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
+            }
+        }
+        if (use_clip_l) {
+            clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding);
+        }
+        if (use_t5) {
+            t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
         }
-
-        clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding);
-        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
 
         // for (int i = 0; i < clip_l_tokens.size(); i++) {
         //     std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
@@ -1208,34 +1257,36 @@ struct FluxCLIPEmbedder : public Conditioner {
         struct ggml_tensor* pooled = NULL;  // [768,]
         std::vector<float> hidden_states_vec;
 
-        size_t chunk_count = t5_tokens.size() / chunk_len;
+        size_t chunk_count = std::max(clip_l_tokens.size() > 0 ? chunk_len : 0, t5_tokens.size()) / chunk_len;
         for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
             // clip_l
             if (chunk_idx == 0) {
-                size_t chunk_len_l = 77;
-                std::vector<int> chunk_tokens(clip_l_tokens.begin(),
-                                              clip_l_tokens.begin() + chunk_len_l);
-                std::vector<float> chunk_weights(clip_l_weights.begin(),
-                                                 clip_l_weights.begin() + chunk_len_l);
+                if (use_clip_l) {
+                    size_t chunk_len_l = 77;
+                    std::vector<int> chunk_tokens(clip_l_tokens.begin(),
+                                                  clip_l_tokens.begin() + chunk_len_l);
+                    std::vector<float> chunk_weights(clip_l_weights.begin(),
+                                                     clip_l_weights.begin() + chunk_len_l);
 
-                auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
-                size_t max_token_idx = 0;
+                    auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
+                    size_t max_token_idx = 0;
 
-                auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
-                max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
+                    auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
+                    max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
 
-                clip_l->compute(n_threads,
-                                input_ids,
-                                0,
-                                NULL,
-                                max_token_idx,
-                                true,
-                                &pooled,
-                                work_ctx);
+                    clip_l->compute(n_threads,
+                                    input_ids,
+                                    0,
+                                    NULL,
+                                    max_token_idx,
+                                    true,
+                                    &pooled,
+                                    work_ctx);
+                }
             }
 
             // t5
-            {
+            if (use_t5) {
                 std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
                                               t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
                 std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
@@ -1263,8 +1314,12 @@ struct FluxCLIPEmbedder : public Conditioner {
                     float new_mean = ggml_tensor_mean(tensor);
                     ggml_tensor_scale(tensor, (original_mean / new_mean));
                 }
+            } else {
+                chunk_hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, chunk_len);
+                ggml_set_f32(chunk_hidden_states, 0.f);
             }
 
+
             int64_t t1 = ggml_time_ms();
             LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
             if (force_zero_embeddings) {
@@ -1273,17 +1328,26 @@ struct FluxCLIPEmbedder : public Conditioner {
                     vec[i] = 0;
                 }
             }
-
+
             hidden_states_vec.insert(hidden_states_vec.end(),
-                                 (float*)chunk_hidden_states->data,
-                                 ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
+                                     (float*)chunk_hidden_states->data,
+                                     ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
+        }
+
+        if (hidden_states_vec.size() > 0) {
+            hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
+            hidden_states = ggml_reshape_2d(work_ctx,
+                                            hidden_states,
+                                            chunk_hidden_states->ne[0],
+                                            ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
+        } else {
+            hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256);
+            ggml_set_f32(hidden_states, 0.f);
+        }
+        if (pooled == NULL) {
+            pooled = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
+            ggml_set_f32(pooled, 0.f);
         }
-
-        hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
-        hidden_states = ggml_reshape_2d(work_ctx,
-                                        hidden_states,
-                                        chunk_hidden_states->ne[0],
-                                        ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
         return SDCondition(hidden_states, pooled, NULL);
     }
 