 import copy
 import os
+import re
 from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy as np
@@ -42,7 +43,8 @@
 from .modeling_auto import AutoModelForCausalLM
 from .modeling_multimodal_utils import (find_input_mm_embeds, fuse_input_embeds,
                                         get_multimodal_embeddings)
-from .modeling_utils import (ModelConfig, register_auto_model,
+from .modeling_utils import (ModelConfig, QuantConfig, _load_weights_impl,
+                             filter_weights, register_auto_model,
                              register_vision_encoder)

 DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'
@@ -96,6 +98,7 @@ def __init__(self,
                  tokenizer: AutoTokenizer,
                  trust_remote_code: bool = True):
         self.model_config = model_config
+        self.vision_dtype = self.model_config.torch_dtype
         self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
             model_path)
         self.use_fast = True
@@ -423,14 +426,15 @@ def __call__(
         pixel_values = processed_inputs.get('pixel_values', None)
         if pixel_values is not None:
             multimodal_data["image"] = {
-                "pixel_values": pixel_values,
+                "pixel_values": pixel_values.to(self.vision_dtype),
                 "image_grid_thw": processed_inputs.get('image_grid_thw')
             }

         pixel_values_videos = processed_inputs.get('pixel_values_videos', None)
         if pixel_values_videos is not None:
             multimodal_data["video"] = {
-                "pixel_values_videos": pixel_values_videos,
+                "pixel_values_videos":
+                pixel_values_videos.to(self.vision_dtype),
                 "video_grid_thw": processed_inputs.get('video_grid_thw')
             }

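Note: the cast to `self.vision_dtype` above, together with the later removal of the `.to(self.model_dtype)` calls in the encoder's `forward`, moves the pixel-tensor dtype conversion into the input processor. A minimal sketch of the resulting behavior, assuming a bfloat16 checkpoint and using a dummy tensor in place of real processor output:

```python
import torch

# Hypothetical stand-ins: `vision_dtype` mirrors self.vision_dtype above and the
# tensor replaces real HF processor output, which is float32 by default.
vision_dtype = torch.bfloat16
pixel_values = torch.rand(16, 1176)

multimodal_data = {"image": {"pixel_values": pixel_values.to(vision_dtype)}}

# The vision encoder can consume this tensor directly; no extra cast in forward().
assert multimodal_data["image"]["pixel_values"].dtype == vision_dtype
```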
@@ -458,29 +462,59 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
                  model_class: Union[type[PreTrainedModel],
                                     type[torch.nn.Module]]):
         super().__init__()
-        config = model_config.pretrained_config.vision_config
-        config.torch_dtype = model_config.pretrained_config.torch_dtype
         self.model_config = model_config
-        self.model_dtype = config.torch_dtype
+        self.model_dtype = self.model_config.pretrained_config.torch_dtype
+        self.config = self.model_config.pretrained_config.vision_config
+        self.config.num_attention_heads = self.config.num_heads
+
+        # NOTE: Re-set QuantConfig to exclude vision encoder weights from the quantized weight load.
+        self.model_config.quant_config = QuantConfig(
+            kv_cache_quant_algo=self.model_config.quant_config.
+            kv_cache_quant_algo)

         if model_class in [
                 Qwen2VisionTransformerPretrainedModel,
                 Qwen2_5_VisionTransformerPretrainedModel
         ]:
             # NOTE: For Qwen2VL, we use flash_attention_2 for attention implementation to avoid OOM issue.
-            config._attn_implementation = 'flash_attention_2'
-            self.visual = model_class(config).to(self.model_dtype).eval()
+            self.config._attn_implementation = 'flash_attention_2'
+            self.visual = model_class(
+                model_config.pretrained_config.vision_config).to(
+                    self.model_dtype).eval()
         elif model_class == Qwen2_5_VisionModel:
-            self.visual = model_class(self.model_config).to(
-                self.model_dtype).eval()
+            self.visual = model_class(self.model_config).to(self.model_dtype)
         else:
             raise NotImplementedError(
                 f"Model class {model_class} not implemented")

-        self.post_config()
-
-    def post_config(self):
-        self.config = self.model_config.pretrained_config.vision_config
+    def load_weights(self, weights: Dict):
+        visual_weights = filter_weights("visual", weights)
+        converted_weights = dict()
+
+        qkv_pattern = re.compile(r'(.*?)attn\.qkv\.(.*)')
+        for name in visual_weights:
+            # Split the fused qkv projection weight and bias into separate q/k/v tensors.
+            match = qkv_pattern.match(name)
+            if match:
+                prefix, suffix = match.groups()
+                q_name = f"{prefix}attn.q_proj.{suffix}"
+                k_name = f"{prefix}attn.k_proj.{suffix}"
+                v_name = f"{prefix}attn.v_proj.{suffix}"
+                dim_shape = visual_weights[name].shape[0] // 3
+                converted_weights[q_name] = visual_weights[name][:dim_shape]
+                converted_weights[k_name] = visual_weights[name][dim_shape:2 *
+                                                                 dim_shape]
+                converted_weights[v_name] = visual_weights[name][2 * dim_shape:]
+            else:
+                converted_weights[name] = visual_weights[name]
+        pattern_mapping = {
+            r'(.*?)attn.proj.(.*)': r'\1attn.o_proj.\2',
+            r'(.*?)mlp.fc1.(.*)': r'\1mlp.up_proj.\2',
+            r'(.*?)mlp.fc2.(.*)': r'\1mlp.down_proj.\2',
+        }
+        _load_weights_impl(self.visual,
+                           converted_weights,
+                           params_map=pattern_mapping)

     def _parse_and_batch_multimodal_data(
             self, multimodal_params: List[MultimodalParams]
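The new `load_weights` above splits each fused `attn.qkv` checkpoint tensor into separate `q_proj`/`k_proj`/`v_proj` entries before delegating to `_load_weights_impl`. A self-contained sketch of that split on a toy tensor (shapes and key names are illustrative, not taken from a real checkpoint):

```python
import re
import torch

hidden = 4  # illustrative hidden size; real Qwen2.5-VL vision blocks are much wider
visual_weights = {
    "blocks.0.attn.qkv.weight":
    torch.arange(3 * hidden * hidden, dtype=torch.float32).reshape(3 * hidden, hidden),
}

qkv_pattern = re.compile(r'(.*?)attn\.qkv\.(.*)')
converted = {}
for name, tensor in visual_weights.items():
    match = qkv_pattern.match(name)
    if match:
        prefix, suffix = match.groups()
        dim = tensor.shape[0] // 3  # fused layout stacks [q; k; v] along dim 0
        converted[f"{prefix}attn.q_proj.{suffix}"] = tensor[:dim]
        converted[f"{prefix}attn.k_proj.{suffix}"] = tensor[dim:2 * dim]
        converted[f"{prefix}attn.v_proj.{suffix}"] = tensor[2 * dim:]
    else:
        converted[name] = tensor

assert sorted(converted) == [
    "blocks.0.attn.k_proj.weight",
    "blocks.0.attn.q_proj.weight",
    "blocks.0.attn.v_proj.weight",
]
```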
@@ -545,12 +579,10 @@ def forward(self, multimodal_params: List[MultimodalParams]):

         embeds = []
         if pixel_values is not None:
-            pixel_values = pixel_values.to(self.model_dtype)
             embed = self.visual(pixel_values, grid_thw=image_grid_thw)
             embeds.append(embed)

         if pixel_values_videos is not None:
-            pixel_values_videos = pixel_values_videos.to(self.model_dtype)
             embeds.append(
                 self.visual(pixel_values_videos, grid_thw=video_grid_thw))
         return embeds
@@ -691,32 +723,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class Qwen2_5_VisionModel(torch.nn.Module):

     def __init__(self, model_config: ModelConfig[PretrainedConfig]):
-        config = model_config.pretrained_config.vision_config
         super().__init__()
+        self.model_config = model_config
+        self.config = self.model_config.pretrained_config.vision_config

-        self.spatial_merge_size = config.spatial_merge_size
-        self.patch_size = config.patch_size
-        self.fullatt_block_indexes = config.fullatt_block_indexes
-        self.window_size = config.window_size
+        self.spatial_merge_size = self.config.spatial_merge_size
+        self.patch_size = self.config.patch_size
+        self.fullatt_block_indexes = self.config.fullatt_block_indexes
+        self.window_size = self.config.window_size
         self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

         self.patch_embed = Qwen2_5_VisionPatchEmbed(
-            patch_size=config.patch_size,
-            temporal_patch_size=config.temporal_patch_size,
-            in_channels=config.in_channels,
-            embed_dim=config.hidden_size,
+            patch_size=self.config.patch_size,
+            temporal_patch_size=self.config.temporal_patch_size,
+            in_channels=self.config.in_channels,
+            embed_dim=self.config.hidden_size,
         )

-        head_dim = config.hidden_size // config.num_heads
+        head_dim = self.config.hidden_size // self.config.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

         self.blocks = torch.nn.ModuleList([
             Qwen2_5_VLVisionBlock(model_config, layer_idx=layer_idx)
-            for layer_idx in range(config.depth)
+            for layer_idx in range(self.config.depth)
         ])
-        self.merger = Qwen2_5_VLPatchMerger(model_config, )
+        self.merger = Qwen2_5_VLPatchMerger(self.model_config, )
         self.metadata_cls = get_attention_backend(
-            model_config.attn_backend).Metadata
+            self.model_config.attn_backend).Metadata

         self.full_attn_metadata = self.metadata_cls(
             max_num_requests=8192,  # TODO: Make this dynamic
@@ -874,35 +907,30 @@ def __init__(
         *args,
         **kwargs,
     ) -> None:
-        model_config.pretrained_config.rope_scaling['type'] = 'mrope'
         self.original_arch = model_config.pretrained_config.architectures[0]
+
         # NOTE: Setting disable_fuse_rope to True performs mrope fusion in the model engine by pre-computing rotary_cos_sin there
         disabble_fuse_rope = kwargs.get('disable_fuse_rope', False)
-        model_config.pretrained_config.text_config.disable_fuse_rope = disabble_fuse_rope
+        model_config.pretrained_config.disable_fuse_rope = disabble_fuse_rope
+        model_config.pretrained_config.rope_scaling['type'] = 'mrope'
         config = model_config.pretrained_config
-
-        assert model_config.attn_backend == 'TRTLLM', "Qwen2/2.5-VL only supports TRTLLM backend now"
         super().__init__(config)
-        if not disabble_fuse_rope:
-            self.init_mrope_embedding(model_config)

         self.model_config = model_config
-        if hasattr(self, "llm"):
-            return
+        self.config = model_config.pretrained_config

-        if not DISAGG:
-            self.mm_encoder = Qwen2VisionModelBase(
-                model_config, kwargs.get('vision_model_class', None)).eval()
+        assert model_config.attn_backend == 'TRTLLM', "Qwen2/2.5-VL only supports TRTLLM backend now"
+        if not disabble_fuse_rope:
+            self.init_mrope_embedding(model_config)

         llm_model_config = copy.deepcopy(model_config)
-        llm_model_config.pretrained_config = config.text_config
         llm_model_config.pretrained_config.architectures = ["Qwen2ForCausalLM"]
-
         self.llm = AutoModelForCausalLM.from_config(llm_model_config)
-        self.model_dtype = getattr(config, "torch_dtype", torch.bfloat16)
-        logger.info(f"{self.dtype=} {self.model_dtype=}")
-        self.post_config()
-        self.is_loaded = True
+
+        if not DISAGG:
+            mm_encoder_config = copy.deepcopy(model_config)
+            self.mm_encoder = Qwen2VisionModelBase(
+                mm_encoder_config, kwargs.get('vision_model_class', None))

     def init_mrope_embedding(self, model_config: ModelConfig[PretrainedConfig]):
         config = model_config.pretrained_config
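The constructor now sets `rope_scaling['type'] = 'mrope'` on the top-level pretrained config (the separate `text_config` indirection is gone) before the language model is built. A hedged illustration of what that touches; the `mrope_section` values below are representative of Qwen2.5-VL checkpoints but come from the model's config.json and may differ:

```python
# Illustrative only: a stand-in for model_config.pretrained_config.rope_scaling.
rope_scaling = {"mrope_section": [16, 24, 24]}

# Mirror of the constructor line above: reassert the multimodal rope variant so
# the text backbone is built with mrope-style rotary position embeddings.
rope_scaling['type'] = 'mrope'

assert rope_scaling['type'] == 'mrope'
```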
@@ -930,11 +958,6 @@ def load_weights(self, weights, weight_mapper: BaseWeightMapper):
     def infer_max_seq_len(self) -> int:
         return self.llm.infer_max_seq_len()

-    def post_config(self):
-        # use llm.config as config for pytorch model engine
-        self.config = self.llm.config
-        self.model_config.pretrained_config = self.llm.config
-
     @nvtx_range("Qwen2.5-VL prepare_mrope_config")
     def prepare_mrope_config(self, multimodal_params: List[MultimodalParams],
                              num_context_requests: int):
@@ -1075,22 +1098,8 @@ def load_weights(self, weights, weight_mapper: BaseWeightMapper):
         self.llm.load_weights(weights, weight_mapper)


-def getSMVersion():
-    prop = torch.cuda.get_device_properties(0)
-    sm_version = prop.major * 10 + prop.minor
-    return sm_version
-
-
-get_sm_version = getSMVersion()
-if get_sm_version >= 100:
-    # NOTE: Qwen2.5-VL with SM 100 and above uses HF's implementation due to lacking of TRT-LLM's Attention kernel.
-    QWEN2_5_VL_VISION_MODEL_CLASS = Qwen2_5_VisionTransformerPretrainedModel
-else:
-    QWEN2_5_VL_VISION_MODEL_CLASS = Qwen2_5_VisionModel
-
-
 @register_vision_encoder(Qwen2VisionModelBase,
-                         vlm_base_model=QWEN2_5_VL_VISION_MODEL_CLASS)
+                         vlm_base_model=Qwen2_5_VisionModel)
 @register_auto_model("Qwen2_5_VLForConditionalGeneration")
 @register_input_processor(
     Qwen2VLInputProcessorBase,
@@ -1106,39 +1115,20 @@ class Qwen2_5_VLModel(Qwen2VLModelBase):

     def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
                  **kwargs):
-        kwargs['vision_model_class'] = QWEN2_5_VL_VISION_MODEL_CLASS
+        kwargs['vision_model_class'] = Qwen2_5_VisionModel
         kwargs[
             'disable_fuse_rope'] = False  # TODO: Make this ModelConfig's argument
         super().__init__(model_config, *args, **kwargs)

     @property
     def multimodal_data_device_paths(self) -> List[str]:
-        if get_sm_version >= 100:
-            return [
-                "image.pixel_values", "video.pixel_values_videos",
-                "image.image_grid_thw", "video.video_grid_thw",
-                "multimodal_embedding"
-            ]
-        else:
-            return [
-                "image.pixel_values", "video.pixel_values_videos",
-                "multimodal_embedding"
-            ]
+        return [
+            "image.pixel_values", "video.pixel_values_videos",
+            "multimodal_embedding"
+        ]

     def load_weights(self, weights, weight_mapper: BaseWeightMapper):
         if not DISAGG:
-            if get_sm_version >= 100:
-                weight_name_mapping = None
-            else:
-                # Process vision encoder weights
-                weight_name_mapping = {
-                    "attn.proj.weight": "attn.o_proj.weight",
-                    "attn.proj.bias": "attn.o_proj.bias",
-                    "attn.qkv.weight": "attn.qkv_proj.weight",
-                    "attn.qkv.bias": "attn.qkv_proj.bias"
-                }
-            vision_weights = process_weights(weights, "visual",
-                                             weight_name_mapping)
-            self.mm_encoder.load_state_dict(vision_weights, strict=True)
+            self.mm_encoder.load_weights(weights)

         self.llm.load_weights(weights, weight_mapper)
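With the static `weight_name_mapping` gone, the vision-encoder renaming now lives in `Qwen2VisionModelBase.load_weights` via the regex `params_map` handed to `_load_weights_impl`. Assuming that mapping is applied as ordinary `re.sub` substitutions (which the backreference-style replacement strings suggest), the renaming behaves roughly like this sketch with hypothetical key names:

```python
import re

# Same patterns as pattern_mapping above.
pattern_mapping = {
    r'(.*?)attn.proj.(.*)': r'\1attn.o_proj.\2',
    r'(.*?)mlp.fc1.(.*)': r'\1mlp.up_proj.\2',
    r'(.*?)mlp.fc2.(.*)': r'\1mlp.down_proj.\2',
}


def rename(key: str) -> str:
    # Apply the first pattern that matches; leave all other keys untouched.
    for pattern, repl in pattern_mapping.items():
        if re.match(pattern, key):
            return re.sub(pattern, repl, key)
    return key


assert rename("blocks.0.attn.proj.weight") == "blocks.0.attn.o_proj.weight"
assert rename("blocks.0.mlp.fc1.bias") == "blocks.0.mlp.up_proj.bias"
assert rename("blocks.0.norm1.weight") == "blocks.0.norm1.weight"
```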