@@ -1,5 +1,6 @@
 import copy
 import os
+import re
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -18,6 +19,8 @@
     PredefinedAttentionMask
 from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
     BaseWeightMapper
+from tensorrt_llm._torch.models.checkpoints.hf.qwen2vl_weight_mapper import \
+    Qwen2VLHfWeightMapper
 from tensorrt_llm._torch.modules.attention import Attention
 from tensorrt_llm._torch.modules.linear import Linear
 from tensorrt_llm._torch.modules.rms_norm import RMSNorm
@@ -39,7 +42,8 @@
 from .modeling_auto import AutoModelForCausalLM
 from .modeling_multimodal_utils import (find_input_mm_embeds, fuse_input_embeds,
                                         get_multimodal_embeddings)
-from .modeling_utils import (ModelConfig, register_auto_model,
+from .modeling_utils import (ModelConfig, QuantConfig, _load_weights_impl,
+                             filter_weights, register_auto_model,
                              register_vision_encoder)
 
 DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'
@@ -95,6 +99,7 @@ def __init__(self,
 
         super().__init__()
         self.model_config = model_config
+        self.vision_dtype = self.model_config.torch_dtype
         self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
             model_path)
         self.use_fast = True
@@ -347,14 +352,15 @@ def __call__(
         pixel_values = processed_inputs.get('pixel_values', None)
         if pixel_values is not None:
             multimodal_data["image"] = {
-                "pixel_values": pixel_values,
+                "pixel_values": pixel_values.to(self.vision_dtype),
                 "image_grid_thw": processed_inputs.get('image_grid_thw')
             }
 
         pixel_values_videos = processed_inputs.get('pixel_values_videos', None)
         if pixel_values_videos is not None:
             multimodal_data["video"] = {
-                "pixel_values_videos": pixel_values_videos,
+                "pixel_values_videos":
+                pixel_values_videos.to(self.vision_dtype),
                 "video_grid_thw": processed_inputs.get('video_grid_thw')
             }
 
@@ -382,29 +388,59 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
                  model_class: Union[type[PreTrainedModel],
                                     type[torch.nn.Module]]):
         super().__init__()
-        config = model_config.pretrained_config.vision_config
-        config.torch_dtype = model_config.pretrained_config.torch_dtype
         self.model_config = model_config
-        self.model_dtype = config.torch_dtype
+        self.model_dtype = self.model_config.pretrained_config.torch_dtype
+        self.config = self.model_config.pretrained_config.vision_config
+        self.config.num_attention_heads = self.config.num_heads
+
+        # NOTE: Re-setting QuantConfig to exclude vision encoder weights from quantization load.
+        self.model_config.quant_config = QuantConfig(
+            kv_cache_quant_algo=self.model_config.quant_config.
+            kv_cache_quant_algo)
 
         if model_class in [
                 Qwen2VisionTransformerPretrainedModel,
                 Qwen2_5_VisionTransformerPretrainedModel
         ]:
             # NOTE: For Qwen2VL, we use flash_attention_2 for attention implementation to avoid OOM issue.
-            config._attn_implementation = 'flash_attention_2'
-            self.visual = model_class(config).to(self.model_dtype).eval()
+            self.config._attn_implementation = 'flash_attention_2'
+            self.visual = model_class(
+                model_config.pretrained_config.vision_config).to(
+                    self.model_dtype).eval()
         elif model_class == Qwen2_5_VisionModel:
-            self.visual = model_class(self.model_config).to(
-                self.model_dtype).eval()
+            self.visual = model_class(self.model_config).to(self.model_dtype)
         else:
             raise NotImplementedError(
                 f"Model class {model_class} not implemented")
 
-        self.post_config()
-
-    def post_config(self):
-        self.config = self.model_config.pretrained_config.vision_config
+    def load_weights(self, weights: Dict):
+        visual_weights = filter_weights("visual", weights)
+        converted_weights = dict()
+
+        qkv_pattern = re.compile(r'(.*?)attn\.qkv\.(.*)')
+        for name in visual_weights:
+            # Handle weights and bias for the vision transformer's fused qkv projection.
+            match = qkv_pattern.match(name)
+            if match:
+                prefix, suffix = match.groups()
+                q_name = f"{prefix}attn.q_proj.{suffix}"
+                k_name = f"{prefix}attn.k_proj.{suffix}"
+                v_name = f"{prefix}attn.v_proj.{suffix}"
+                dim_shape = visual_weights[name].shape[0] // 3
+                converted_weights[q_name] = visual_weights[name][:dim_shape]
+                converted_weights[k_name] = visual_weights[name][dim_shape:2 *
+                                                                 dim_shape]
+                converted_weights[v_name] = visual_weights[name][2 * dim_shape:]
+            else:
+                converted_weights[name] = visual_weights[name]
+        pattern_mapping = {
+            r'(.*?)attn.proj.(.*)': r'\1attn.o_proj.\2',
+            r'(.*?)mlp.fc1.(.*)': r'\1mlp.up_proj.\2',
+            r'(.*?)mlp.fc2.(.*)': r'\1mlp.down_proj.\2',
+        }
+        _load_weights_impl(self.visual,
+                           converted_weights,
+                           params_map=pattern_mapping)
 
     def _parse_and_batch_multimodal_data(
             self, multimodal_params: List[MultimodalParams]
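For reference, here is a minimal standalone sketch of the fused-qkv split that the new `Qwen2VisionModelBase.load_weights` performs. The toy state dict, block name, and hidden size are made up for illustration and are not part of the diff.

```python
import re

import torch

# Hypothetical toy checkpoint: one vision block whose fused qkv weight has
# shape [3 * hidden, hidden], mirroring the HF Qwen2-VL vision tower layout.
hidden = 4
visual_weights = {
    "blocks.0.attn.qkv.weight":
    torch.arange(3 * hidden * hidden, dtype=torch.float32).reshape(
        3 * hidden, hidden)
}

qkv_pattern = re.compile(r'(.*?)attn\.qkv\.(.*)')
converted = {}
for name, tensor in visual_weights.items():
    match = qkv_pattern.match(name)
    if match:
        prefix, suffix = match.groups()
        split = tensor.shape[0] // 3  # fused layout is [q; k; v] along dim 0
        converted[f"{prefix}attn.q_proj.{suffix}"] = tensor[:split]
        converted[f"{prefix}attn.k_proj.{suffix}"] = tensor[split:2 * split]
        converted[f"{prefix}attn.v_proj.{suffix}"] = tensor[2 * split:]
    else:
        converted[name] = tensor

print(sorted(converted))  # blocks.0.attn.{k,q,v}_proj.weight
```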
@@ -469,12 +505,10 @@ def forward(self, multimodal_params: List[MultimodalParams]):
 
         embeds = []
         if pixel_values is not None:
-            pixel_values = pixel_values.to(self.model_dtype)
             embed = self.visual(pixel_values, grid_thw=image_grid_thw)
             embeds.append(embed)
 
         if pixel_values_videos is not None:
-            pixel_values_videos = pixel_values_videos.to(self.model_dtype)
             embeds.append(
                 self.visual(pixel_values_videos, grid_thw=video_grid_thw))
         return embeds
@@ -615,32 +649,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class Qwen2_5_VisionModel(torch.nn.Module):
 
     def __init__(self, model_config: ModelConfig[PretrainedConfig]):
-        config = model_config.pretrained_config.vision_config
         super().__init__()
+        self.model_config = model_config
+        self.config = self.model_config.pretrained_config.vision_config
 
-        self.spatial_merge_size = config.spatial_merge_size
-        self.patch_size = config.patch_size
-        self.fullatt_block_indexes = config.fullatt_block_indexes
-        self.window_size = config.window_size
+        self.spatial_merge_size = self.config.spatial_merge_size
+        self.patch_size = self.config.patch_size
+        self.fullatt_block_indexes = self.config.fullatt_block_indexes
+        self.window_size = self.config.window_size
         self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
 
         self.patch_embed = Qwen2_5_VisionPatchEmbed(
-            patch_size=config.patch_size,
-            temporal_patch_size=config.temporal_patch_size,
-            in_channels=config.in_channels,
-            embed_dim=config.hidden_size,
+            patch_size=self.config.patch_size,
+            temporal_patch_size=self.config.temporal_patch_size,
+            in_channels=self.config.in_channels,
+            embed_dim=self.config.hidden_size,
         )
 
-        head_dim = config.hidden_size // config.num_heads
+        head_dim = self.config.hidden_size // self.config.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
 
         self.blocks = torch.nn.ModuleList([
             Qwen2_5_VLVisionBlock(model_config, layer_idx=layer_idx)
-            for layer_idx in range(config.depth)
+            for layer_idx in range(self.config.depth)
         ])
-        self.merger = Qwen2_5_VLPatchMerger(model_config, )
+        self.merger = Qwen2_5_VLPatchMerger(self.model_config, )
         self.metadata_cls = get_attention_backend(
-            model_config.attn_backend).Metadata
+            self.model_config.attn_backend).Metadata
 
         self.full_attn_metadata = self.metadata_cls(
             max_num_requests=8192,  # TODO: Make this dynamic
@@ -798,35 +833,31 @@ def __init__(
         *args,
         **kwargs,
     ) -> None:
-        model_config.pretrained_config.rope_scaling['type'] = 'mrope'
         self.original_arch = model_config.pretrained_config.architectures[0]
+
         # NOTE: Setting disable_fuse_rope to True does the mrope fusion in the model engine by pre-computing rotary_cos_sin there.
         disabble_fuse_rope = kwargs.get('disable_fuse_rope', False)
-        model_config.pretrained_config.text_config.disable_fuse_rope = disabble_fuse_rope
+        model_config.pretrained_config.disable_fuse_rope = disabble_fuse_rope
+        model_config.pretrained_config.rope_scaling['type'] = 'mrope'
         config = model_config.pretrained_config
-
-        assert model_config.attn_backend == 'TRTLLM', "Qwen2/2.5-VL only supports TRTLLM backend now"
         super().__init__(config)
-        if not disabble_fuse_rope:
-            self.init_mrope_embedding(model_config)
 
         self.model_config = model_config
-        if hasattr(self, "llm"):
-            return
+        self.config = model_config.pretrained_config
 
-        if not DISAGG:
-            self.mm_encoder = Qwen2VisionModelBase(
-                model_config, kwargs.get('vision_model_class', None)).eval()
+        if model_config.attn_backend != 'TRTLLM':
+            raise ValueError("Qwen2/2.5-VL only supports TRTLLM backend now")
+        if not disabble_fuse_rope:
+            self.init_mrope_embedding(model_config)
 
         llm_model_config = copy.deepcopy(model_config)
-        llm_model_config.pretrained_config = config.text_config
         llm_model_config.pretrained_config.architectures = ["Qwen2ForCausalLM"]
-
         self.llm = AutoModelForCausalLM.from_config(llm_model_config)
-        self.model_dtype = getattr(config, "torch_dtype", torch.bfloat16)
-        logger.info(f"{self.dtype=} {self.model_dtype=}")
-        self.post_config()
-        self.is_loaded = True
+
+        if not DISAGG:
+            mm_encoder_config = copy.deepcopy(model_config)
+            self.mm_encoder = Qwen2VisionModelBase(
+                mm_encoder_config, kwargs.get('vision_model_class', None))
 
     def init_mrope_embedding(self, model_config: ModelConfig[PretrainedConfig]):
         config = model_config.pretrained_config
@@ -854,11 +885,6 @@ def load_weights(self, weights, weight_mapper: BaseWeightMapper):
     def infer_max_seq_len(self) -> int:
         return self.llm.infer_max_seq_len()
 
-    def post_config(self):
-        # use llm.config as config for pytorch model engine
-        self.config = self.llm.config
-        self.model_config.pretrained_config = self.llm.config
-
     @nvtx_range("Qwen2.5-VL prepare_mrope_config")
     def prepare_mrope_config(self, multimodal_params: List[MultimodalParams],
                              num_context_requests: int):
@@ -998,22 +1024,8 @@ def load_weights(self, weights, weight_mapper: BaseWeightMapper):
         self.llm.load_weights(weights, weight_mapper)
 
 
-def getSMVersion():
-    prop = torch.cuda.get_device_properties(0)
-    sm_version = prop.major * 10 + prop.minor
-    return sm_version
-
-
-get_sm_version = getSMVersion()
-if get_sm_version >= 100:
-    # NOTE: Qwen2.5-VL with SM 100 and above uses HF's implementation due to lacking of TRT-LLM's Attention kernel.
-    QWEN2_5_VL_VISION_MODEL_CLASS = Qwen2_5_VisionTransformerPretrainedModel
-else:
-    QWEN2_5_VL_VISION_MODEL_CLASS = Qwen2_5_VisionModel
-
-
 @register_vision_encoder(Qwen2VisionModelBase,
-                         vlm_base_model=QWEN2_5_VL_VISION_MODEL_CLASS)
+                         vlm_base_model=Qwen2_5_VisionModel)
 @register_auto_model("Qwen2_5_VLForConditionalGeneration")
 @register_input_processor(
     Qwen2VLInputProcessorBase,
@@ -1029,39 +1041,23 @@ class Qwen2_5_VLModel(Qwen2VLModelBase):
 
     def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
                  **kwargs):
-        kwargs['vision_model_class'] = QWEN2_5_VL_VISION_MODEL_CLASS
+        kwargs['vision_model_class'] = Qwen2_5_VisionModel
         kwargs[
             'disable_fuse_rope'] = False  # TODO: Make this ModelConfig's argument
         super().__init__(model_config, *args, **kwargs)
 
     @property
     def multimodal_data_device_paths(self) -> List[str]:
-        if get_sm_version >= 100:
-            return [
-                "image.pixel_values", "video.pixel_values_videos",
-                "image.image_grid_thw", "video.video_grid_thw",
-                "multimodal_embedding"
-            ]
-        else:
-            return [
-                "image.pixel_values", "video.pixel_values_videos",
-                "multimodal_embedding"
-            ]
+        return [
+            "image.pixel_values", "video.pixel_values_videos",
+            "multimodal_embedding"
+        ]
 
     def load_weights(self, weights, weight_mapper: BaseWeightMapper):
+        if isinstance(weight_mapper, Qwen2VLHfWeightMapper):
+            weights = weight_mapper.preprocess_weights(weights)
+
         if not DISAGG:
-            if get_sm_version >= 100:
-                weight_name_mapping = None
-            else:
-                # Process vision encoder weights
-                weight_name_mapping = {
-                    "attn.proj.weight": "attn.o_proj.weight",
-                    "attn.proj.bias": "attn.o_proj.bias",
-                    "attn.qkv.weight": "attn.qkv_proj.weight",
-                    "attn.qkv.bias": "attn.qkv_proj.bias"
-                }
-            vision_weights = process_weights(weights, "visual",
-                                             weight_name_mapping)
-            self.mm_encoder.load_state_dict(vision_weights, strict=True)
+            self.mm_encoder.load_weights(weights)
 
-        self.llm.load_weights(weights, weight_mapper)
+        self.llm.load_weights(weights)
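As a rough illustration of what the `params_map` regexes passed to `_load_weights_impl` do to the remaining HF vision weight names, here is a small `re.subn` sketch. The `remap` helper and the sample key names are illustrative assumptions; the actual renaming happens inside TRT-LLM's `_load_weights_impl`, not in this function.

```python
import re

# Same pattern_mapping as added in Qwen2VisionModelBase.load_weights.
pattern_mapping = {
    r'(.*?)attn.proj.(.*)': r'\1attn.o_proj.\2',
    r'(.*?)mlp.fc1.(.*)': r'\1mlp.up_proj.\2',
    r'(.*?)mlp.fc2.(.*)': r'\1mlp.down_proj.\2',
}


def remap(name: str) -> str:
    # Apply the first matching rename pattern; otherwise keep the name as-is.
    for pattern, repl in pattern_mapping.items():
        new_name, count = re.subn(pattern, repl, name)
        if count:
            return new_name
    return name


for key in [
        "blocks.0.attn.proj.weight",  # hypothetical HF-style keys
        "blocks.0.mlp.fc1.bias",
        "merger.ln_q.weight",
]:
    print(key, "->", remap(key))
# blocks.0.attn.proj.weight -> blocks.0.attn.o_proj.weight
# blocks.0.mlp.fc1.bias -> blocks.0.mlp.up_proj.bias
# merger.ln_q.weight -> merger.ln_q.weight
```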