@@ -7,7 +7,12 @@
-from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig,
-                          PreTrainedModel, Qwen2_5_VLForConditionalGeneration,
-                          Qwen2VLForConditionalGeneration)
+from transformers import (AutoConfig, AutoProcessor, AutoTokenizer,
+                          PretrainedConfig, PreTrainedModel,
+                          Qwen2_5_VLForConditionalGeneration,
+                          Qwen2VLForConditionalGeneration)
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import \
+    Qwen2_5_VisionTransformerPretrainedModel
 from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
+from transformers.models.qwen2_vl.modeling_qwen2_vl import \
+    Qwen2VisionTransformerPretrainedModel
 
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 
@@ -358,15 +362,22 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
         # TODO: Change the model class to TRT-LLM's Qwen2VisionModel
         # Currently, copying vision encoder on all devices.
         # NOTE: Using attn_implementation='flash_attention_2' to avoid the issue of vision model's GPU OOM.
-        model = model_class.from_pretrained(
-            model_path,
-            torch_dtype=pretrained_config.torch_dtype,
-            attn_implementation='flash_attention_2',
-            ignore_mismatched_sizes=True).eval()
+        hf_model_config = AutoConfig.from_pretrained(model_path)
+        vision_model = model_class(config=hf_model_config.vision_config,
+                                   torch_dtype=pretrained_config.torch_dtype,
+                                   attn_implementation='flash_attention_2')
         # TODO: Make vision model compatible with meta init mode and load_weights at the same place
-        self.visual = model.visual.to(self.device)
+        self.visual = vision_model.to(self.device)
         self.post_config()
 
+    def load_weights(self, weights):
+        # Keep only the vision-tower weights and strip their 'visual.' prefix.
+        filtered_weights = {
+            k.replace('visual.', ''): v
+            for k, v in weights.items() if k.startswith('visual.')
+        }
+        self.visual.load_state_dict(filtered_weights)
+
     def post_config(self):
         self.config = self.visual.config
 
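For context, a minimal standalone sketch of the pattern this hunk introduces: build only the vision tower from the checkpoint's vision_config, then load it by filtering a full-checkpoint state dict on the 'visual.' prefix. The checkpoint name is illustrative and the weights dict is simulated from the tower's own keys; neither is part of the diff.

# A minimal sketch, not TRT-LLM code: allocate the tower without the LLM,
# then load it from prefix-filtered checkpoint weights.
from transformers import AutoConfig
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import \
    Qwen2_5_VisionTransformerPretrainedModel

model_path = 'Qwen/Qwen2.5-VL-7B-Instruct'  # illustrative checkpoint name
hf_model_config = AutoConfig.from_pretrained(model_path)

# Only the vision transformer is allocated; the language model never exists.
visual = Qwen2_5_VisionTransformerPretrainedModel(hf_model_config.vision_config)

# Stand-in for a full-checkpoint state dict, whose vision keys carry the
# 'visual.' prefix; here we re-prefix the tower's own keys to simulate one.
weights = {f'visual.{k}': v for k, v in visual.state_dict().items()}
filtered = {k.replace('visual.', '', 1): v
            for k, v in weights.items() if k.startswith('visual.')}
visual.load_state_dict(filtered)  # keys match the standalone module exactly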
@@ -504,6 +514,7 @@ def init_rotary_cos_sin_ori(self):
 
     def load_weights(self, weights):
         self.llm.load_weights(weights)
+        self.mm_encoder.load_weights(weights)
         self.init_rotary_cos_sin_ori()
 
     def infer_max_seq_len(self) -> int:
@@ -676,7 +687,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
         super().__init__(model_config, *args, **kwargs)
         if not DISAGG:
             self.mm_encoder = Qwen2VisionModelBase(
-                model_config, Qwen2VLForConditionalGeneration)
+                model_config, Qwen2VisionTransformerPretrainedModel)
 
 
 @register_vision_encoder(Qwen2VisionModelBase,
@@ -697,4 +708,4 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
         super().__init__(model_config, *args, **kwargs)
         if not DISAGG:
             self.mm_encoder = Qwen2VisionModelBase(
-                model_config, Qwen2_5_VLForConditionalGeneration)
+                model_config, Qwen2_5_VisionTransformerPretrainedModel)
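The same swap viewed in isolation: the *ForConditionalGeneration classes that the old code passed in materialize the whole VLM before __init__ keeps only its .visual submodule, whereas the vision-transformer classes build just the tower. A rough comparison, with an illustrative checkpoint name and approximate parameter counts:

# Sketch only: compares what each class allocates, assuming a Qwen2-VL-7B
# style config; exact counts depend on the checkpoint variant.
from transformers import AutoConfig
from transformers.models.qwen2_vl.modeling_qwen2_vl import \
    Qwen2VisionTransformerPretrainedModel

cfg = AutoConfig.from_pretrained('Qwen/Qwen2-VL-7B-Instruct')  # illustrative
vit = Qwen2VisionTransformerPretrainedModel(cfg.vision_config)
print(f'{sum(p.numel() for p in vit.parameters()) / 1e9:.2f}B')  # tower only,
# roughly 0.7B, versus roughly 8B for the full Qwen2VLForConditionalGeneration,
# which the pre-diff code loaded on every device just to keep .visual.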