diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py
index 6d9493fafe7..4913e7f5e54 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py
@@ -4,9 +4,12 @@
 
 import torch
 import torch.nn as nn
-from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig,
-                          PreTrainedModel, Qwen2_5_VLForConditionalGeneration,
-                          Qwen2VLForConditionalGeneration)
+from transformers import (AutoConfig, AutoProcessor, AutoTokenizer,
+                          PretrainedConfig, PreTrainedModel)
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import \
+    Qwen2_5_VisionTransformerPretrainedModel
+from transformers.models.qwen2_vl.modeling_qwen2_vl import \
+    Qwen2VisionTransformerPretrainedModel
 
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 
@@ -328,14 +331,22 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
         # TODO: Change the model class to TRT-LLM's Qwen2VisionModel
         # Currently, copying vision encoder on all devices.
         # NOTE: Using attn_implementation='flash_attention_2' to avoid the issue of vision model's GPU OOM.
-        model = model_class.from_pretrained(
-            model_path,
+        hf_model_config = AutoConfig.from_pretrained(model_path)
+        vision_model = model_class._from_config(
+            hf_model_config.vision_config,
             torch_dtype=pretrained_config.torch_dtype,
-            attn_implementation='flash_attention_2').eval()
+            attn_implementation='flash_attention_2')
         # TODO: Make vision model compatible with meta init mode and load_weights at the same place
-        self.visual = model.visual.to(self.device)
+        self.visual = vision_model.to(self.device)
         self.post_config()
 
+    def load_weights(self, weights):
+        filtered_weights = {
+            k.replace('visual.', ''): v
+            for k, v in weights.items() if k.startswith('visual.')
+        }
+        self.visual.load_state_dict(filtered_weights)
+
     def post_config(self):
         self.config = self.visual.config
 
@@ -473,6 +484,7 @@ def init_rotary_cos_sin_ori(self):
 
     def load_weights(self, weights):
         self.llm.load_weights(weights)
+        self.mm_encoder.load_weights(weights)
         self.init_rotary_cos_sin_ori()
 
     def infer_max_seq_len(self) -> int:
@@ -626,7 +638,7 @@ def forward(
 
 
 @register_vision_encoder(Qwen2VisionModelBase,
-                         vlm_base_model=Qwen2VLForConditionalGeneration)
+                         vlm_base_model=Qwen2VisionTransformerPretrainedModel)
 @register_auto_model("Qwen2VLForConditionalGeneration")
 @register_input_processor(
     Qwen2VLInputProcessorBase,
@@ -645,11 +657,12 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
         super().__init__(model_config, *args, **kwargs)
         if not DISAGG:
             self.mm_encoder = Qwen2VisionModelBase(
-                model_config, Qwen2VLForConditionalGeneration)
+                model_config, Qwen2VisionTransformerPretrainedModel)
 
 
-@register_vision_encoder(Qwen2VisionModelBase,
-                         vlm_base_model=Qwen2_5_VLForConditionalGeneration)
+@register_vision_encoder(
+    Qwen2VisionModelBase,
+    vlm_base_model=Qwen2_5_VisionTransformerPretrainedModel)
 @register_auto_model("Qwen2_5_VLForConditionalGeneration")
 @register_input_processor(
     Qwen2VLInputProcessorBase,
@@ -666,4 +679,4 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
         super().__init__(model_config, *args, **kwargs)
         if not DISAGG:
             self.mm_encoder = Qwen2VisionModelBase(
-                model_config, Qwen2_5_VLForConditionalGeneration)
+                model_config, Qwen2_5_VisionTransformerPretrainedModel)
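
Reviewer note (illustration only, not part of the patch): the new Qwen2VisionModelBase.load_weights keeps only checkpoint entries under the "visual." prefix and strips that prefix before calling load_state_dict, so the vision transformer built via model_class._from_config(hf_model_config.vision_config) receives keys in its own namespace while the LLM weights are skipped. A minimal standalone sketch of that filtering, using made-up key names and dummy tensors rather than a real Qwen2-VL checkpoint:

    # Sketch only: mirrors the prefix filtering in the new load_weights.
    # The checkpoint dict is a hypothetical example, not a real state dict.
    import torch

    checkpoint = {
        'visual.patch_embed.proj.weight': torch.zeros(1),           # kept, prefix stripped
        'visual.blocks.0.attn.qkv.weight': torch.zeros(1),          # kept, prefix stripped
        'model.layers.0.self_attn.q_proj.weight': torch.zeros(1),   # LLM weight, dropped
    }

    filtered = {
        k.replace('visual.', ''): v
        for k, v in checkpoint.items() if k.startswith('visual.')
    }
    assert sorted(filtered) == ['blocks.0.attn.qkv.weight', 'patch_embed.proj.weight']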