@@ -661,47 +661,108 @@ struct SD3CLIPEmbedder : public Conditioner {
661661 std::shared_ptr<CLIPTextModelRunner> clip_l;
662662 std::shared_ptr<CLIPTextModelRunner> clip_g;
663663 std::shared_ptr<T5Runner> t5;
664+ bool use_clip_l = false ;
665+ bool use_clip_g = false ;
666+ bool use_t5 = false ;
664667
665668 SD3CLIPEmbedder (ggml_backend_t backend,
666669 std::map<std::string, enum ggml_type>& tensor_types,
667670 int clip_skip = -1 )
668671 : clip_g_tokenizer(0 ) {
669- clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, " text_encoders.clip_l.transformer.text_model" , OPENAI_CLIP_VIT_L_14, false );
670- clip_g = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, " text_encoders.clip_g.transformer.text_model" , OPEN_CLIP_VIT_BIGG_14, false );
671- t5 = std::make_shared<T5Runner>(backend, tensor_types, " text_encoders.t5xxl.transformer" );
672+ if (clip_skip <= 0 ) {
673+ clip_skip = 2 ;
674+ }
675+
676+ for (auto pair : tensor_types) {
677+ if (pair.first .find (" text_encoders.clip_l" ) != std::string::npos) {
678+ use_clip_l = true ;
679+ } else if (pair.first .find (" text_encoders.clip_g" ) != std::string::npos) {
680+ use_clip_g = true ;
681+ } else if (pair.first .find (" text_encoders.t5xxl" ) != std::string::npos) {
682+ use_t5 = true ;
683+ }
684+ }
685+ if (!use_clip_l && !use_clip_g && !use_t5) {
686+ LOG_WARN (" IMPORTANT NOTICE: No text encoders provided, cannot process prompts!" );
687+ return ;
688+ }
689+ if (use_clip_l) {
690+ clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, " text_encoders.clip_l.transformer.text_model" , OPENAI_CLIP_VIT_L_14, clip_skip, false );
691+ } else {
692+ LOG_WARN (" clip_l text encoder not found! Prompt adherence might be degraded." );
693+ }
694+ if (use_clip_g) {
695+ clip_g = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, " text_encoders.clip_g.transformer.text_model" , OPEN_CLIP_VIT_BIGG_14, clip_skip, false );
696+ } else {
697+ LOG_WARN (" clip_g text encoder not found! Prompt adherence might be degraded." );
698+ }
699+ if (use_t5) {
700+ t5 = std::make_shared<T5Runner>(backend, tensor_types, " text_encoders.t5xxl.transformer" );
701+ } else {
702+ LOG_WARN (" t5xxl text encoder not found! Prompt adherence might be degraded." );
703+ }
672704 set_clip_skip (clip_skip);
673705 }
674706
675707 void set_clip_skip (int clip_skip) {
676708 if (clip_skip <= 0 ) {
677709 clip_skip = 2 ;
678710 }
679- clip_l->set_clip_skip (clip_skip);
680- clip_g->set_clip_skip (clip_skip);
711+ if (use_clip_l) {
712+ clip_l->set_clip_skip (clip_skip);
713+ }
714+ if (use_clip_g) {
715+ clip_g->set_clip_skip (clip_skip);
716+ }
681717 }
682718
683719 void get_param_tensors (std::map<std::string, struct ggml_tensor *>& tensors) {
684- clip_l->get_param_tensors (tensors, " text_encoders.clip_l.transformer.text_model" );
685- clip_g->get_param_tensors (tensors, " text_encoders.clip_g.transformer.text_model" );
686- t5->get_param_tensors (tensors, " text_encoders.t5xxl.transformer" );
720+ if (use_clip_l) {
721+ clip_l->get_param_tensors (tensors, " text_encoders.clip_l.transformer.text_model" );
722+ }
723+ if (use_clip_g) {
724+ clip_g->get_param_tensors (tensors, " text_encoders.clip_g.transformer.text_model" );
725+ }
726+ if (use_t5) {
727+ t5->get_param_tensors (tensors, " text_encoders.t5xxl.transformer" );
728+ }
687729 }
688730
689731 void alloc_params_buffer () {
690- clip_l->alloc_params_buffer ();
691- clip_g->alloc_params_buffer ();
692- t5->alloc_params_buffer ();
732+ if (use_clip_l) {
733+ clip_l->alloc_params_buffer ();
734+ }
735+ if (use_clip_g) {
736+ clip_g->alloc_params_buffer ();
737+ }
738+ if (use_t5) {
739+ t5->alloc_params_buffer ();
740+ }
693741 }
694742
695743 void free_params_buffer () {
696- clip_l->free_params_buffer ();
697- clip_g->free_params_buffer ();
698- t5->free_params_buffer ();
744+ if (use_clip_l) {
745+ clip_l->free_params_buffer ();
746+ }
747+ if (use_clip_g) {
748+ clip_g->free_params_buffer ();
749+ }
750+ if (use_t5) {
751+ t5->free_params_buffer ();
752+ }
699753 }
700754
701755 size_t get_params_buffer_size () {
702- size_t buffer_size = clip_l->get_params_buffer_size ();
703- buffer_size += clip_g->get_params_buffer_size ();
704- buffer_size += t5->get_params_buffer_size ();
756+ size_t buffer_size = 0 ;
757+ if (use_clip_l) {
758+ buffer_size += clip_l->get_params_buffer_size ();
759+ }
760+ if (use_clip_g) {
761+ buffer_size += clip_g->get_params_buffer_size ();
762+ }
763+ if (use_t5) {
764+ buffer_size += t5->get_params_buffer_size ();
765+ }
705766 return buffer_size;
706767 }
707768
@@ -733,23 +794,32 @@ struct SD3CLIPEmbedder : public Conditioner {
733794 for (const auto & item : parsed_attention) {
734795 const std::string& curr_text = item.first ;
735796 float curr_weight = item.second ;
736-
737- std::vector<int > curr_tokens = clip_l_tokenizer.encode (curr_text, on_new_token_cb);
738- clip_l_tokens.insert (clip_l_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
739- clip_l_weights.insert (clip_l_weights.end (), curr_tokens.size (), curr_weight);
740-
741- curr_tokens = clip_g_tokenizer.encode (curr_text, on_new_token_cb);
742- clip_g_tokens.insert (clip_g_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
743- clip_g_weights.insert (clip_g_weights.end (), curr_tokens.size (), curr_weight);
744-
745- curr_tokens = t5_tokenizer.Encode (curr_text, true );
746- t5_tokens.insert (t5_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
747- t5_weights.insert (t5_weights.end (), curr_tokens.size (), curr_weight);
797+ if (use_clip_l) {
798+ std::vector<int > curr_tokens = clip_l_tokenizer.encode (curr_text, on_new_token_cb);
799+ clip_l_tokens.insert (clip_l_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
800+ clip_l_weights.insert (clip_l_weights.end (), curr_tokens.size (), curr_weight);
801+ }
802+ if (use_clip_g) {
803+ std::vector<int > curr_tokens = clip_g_tokenizer.encode (curr_text, on_new_token_cb);
804+ clip_g_tokens.insert (clip_g_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
805+ clip_g_weights.insert (clip_g_weights.end (), curr_tokens.size (), curr_weight);
806+ }
807+ if (use_t5) {
808+ std::vector<int > curr_tokens = t5_tokenizer.Encode (curr_text, true );
809+ t5_tokens.insert (t5_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
810+ t5_weights.insert (t5_weights.end (), curr_tokens.size (), curr_weight);
811+ }
748812 }
749813
750- clip_l_tokenizer.pad_tokens (clip_l_tokens, clip_l_weights, max_length, padding);
751- clip_g_tokenizer.pad_tokens (clip_g_tokens, clip_g_weights, max_length, padding);
752- t5_tokenizer.pad_tokens (t5_tokens, t5_weights, NULL , max_length, padding);
814+ if (use_clip_l) {
815+ clip_l_tokenizer.pad_tokens (clip_l_tokens, clip_l_weights, max_length, padding);
816+ }
817+ if (use_clip_g) {
818+ clip_g_tokenizer.pad_tokens (clip_g_tokens, clip_g_weights, max_length, padding);
819+ }
820+ if (use_t5) {
821+ t5_tokenizer.pad_tokens (t5_tokens, t5_weights, NULL , max_length, padding);
822+ }
753823
754824 // for (int i = 0; i < clip_l_tokens.size(); i++) {
755825 // std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
@@ -794,10 +864,10 @@ struct SD3CLIPEmbedder : public Conditioner {
794864 std::vector<float > hidden_states_vec;
795865
796866 size_t chunk_len = 77 ;
797- size_t chunk_count = clip_l_tokens.size () / chunk_len;
867+ size_t chunk_count = std::max ( std::max ( clip_l_tokens.size (), clip_g_tokens. size ()), t5_tokens. size () ) / chunk_len;
798868 for (int chunk_idx = 0 ; chunk_idx < chunk_count; chunk_idx++) {
799869 // clip_l
800- {
870+ if (use_clip_l) {
801871 std::vector<int > chunk_tokens (clip_l_tokens.begin () + chunk_idx * chunk_len,
802872 clip_l_tokens.begin () + (chunk_idx + 1 ) * chunk_len);
803873 std::vector<float > chunk_weights (clip_l_weights.begin () + chunk_idx * chunk_len,
@@ -842,10 +912,17 @@ struct SD3CLIPEmbedder : public Conditioner {
842912 &pooled_l,
843913 work_ctx);
844914 }
915+ } else {
916+ chunk_hidden_states_l = ggml_new_tensor_2d (work_ctx, GGML_TYPE_F32, 768 , chunk_len);
917+ ggml_set_f32 (chunk_hidden_states_l, 0 .f );
918+ if (chunk_idx == 0 ) {
919+ pooled_l = ggml_new_tensor_1d (work_ctx, GGML_TYPE_F32, 768 );
920+ ggml_set_f32 (pooled_l, 0 .f );
921+ }
845922 }
846923
847924 // clip_g
848- {
925+ if (use_clip_g) {
849926 std::vector<int > chunk_tokens (clip_g_tokens.begin () + chunk_idx * chunk_len,
850927 clip_g_tokens.begin () + (chunk_idx + 1 ) * chunk_len);
851928 std::vector<float > chunk_weights (clip_g_weights.begin () + chunk_idx * chunk_len,
@@ -891,10 +968,17 @@ struct SD3CLIPEmbedder : public Conditioner {
891968 &pooled_g,
892969 work_ctx);
893970 }
971+ } else {
972+ chunk_hidden_states_g = ggml_new_tensor_2d (work_ctx, GGML_TYPE_F32, 1280 , chunk_len);
973+ ggml_set_f32 (chunk_hidden_states_g, 0 .f );
974+ if (chunk_idx == 0 ) {
975+ pooled_g = ggml_new_tensor_1d (work_ctx, GGML_TYPE_F32, 1280 );
976+ ggml_set_f32 (pooled_g, 0 .f );
977+ }
894978 }
895979
896980 // t5
897- {
981+ if (use_t5) {
898982 std::vector<int > chunk_tokens (t5_tokens.begin () + chunk_idx * chunk_len,
899983 t5_tokens.begin () + (chunk_idx + 1 ) * chunk_len);
900984 std::vector<float > chunk_weights (t5_weights.begin () + chunk_idx * chunk_len,
@@ -922,6 +1006,8 @@ struct SD3CLIPEmbedder : public Conditioner {
9221006 float new_mean = ggml_tensor_mean (tensor);
9231007 ggml_tensor_scale (tensor, (original_mean / new_mean));
9241008 }
1009+ } else {
1010+ chunk_hidden_states_t5 = ggml_new_tensor_2d (work_ctx, GGML_TYPE_F32, 4096 , 0 );
9251011 }
9261012
9271013 auto chunk_hidden_states_lg_pad = ggml_new_tensor_3d (work_ctx,
@@ -964,11 +1050,19 @@ struct SD3CLIPEmbedder : public Conditioner {
9641050 ((float *)chunk_hidden_states->data ) + ggml_nelements (chunk_hidden_states));
9651051 }
9661052
967- hidden_states = vector_to_ggml_tensor (work_ctx, hidden_states_vec);
968- hidden_states = ggml_reshape_2d (work_ctx,
969- hidden_states,
970- chunk_hidden_states->ne [0 ],
971- ggml_nelements (hidden_states) / chunk_hidden_states->ne [0 ]);
1053+ if (hidden_states_vec.size () > 0 ) {
1054+ hidden_states = vector_to_ggml_tensor (work_ctx, hidden_states_vec);
1055+ hidden_states = ggml_reshape_2d (work_ctx,
1056+ hidden_states,
1057+ chunk_hidden_states->ne [0 ],
1058+ ggml_nelements (hidden_states) / chunk_hidden_states->ne [0 ]);
1059+ } else {
1060+ hidden_states = ggml_new_tensor_2d (work_ctx, GGML_TYPE_F32, 4096 , 0 );
1061+ }
1062+ if (pooled == NULL ) {
1063+ pooled = ggml_new_tensor_1d (work_ctx, GGML_TYPE_F32, 2048 );
1064+ ggml_set_f32 (pooled, 0 .f );
1065+ }
9721066 return SDCondition (hidden_states, pooled, NULL );
9731067 }
9741068
0 commit comments