tensorrt_llm/_torch/models/modeling_qwen2vl.py (37 changes: 25 additions & 12 deletions)
@@ -4,9 +4,12 @@
 
 import torch
 import torch.nn as nn
-from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig,
-                          PreTrainedModel, Qwen2_5_VLForConditionalGeneration,
-                          Qwen2VLForConditionalGeneration)
+from transformers import (AutoConfig, AutoProcessor, AutoTokenizer,
+                          PretrainedConfig, PreTrainedModel)
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import \
+    Qwen2_5_VisionTransformerPretrainedModel
+from transformers.models.qwen2_vl.modeling_qwen2_vl import \
+    Qwen2VisionTransformerPretrainedModel
 
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 
@@ -328,14 +331,22 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
         # TODO: Change the model class to TRT-LLM's Qwen2VisionModel
         # Currently, copying vision encoder on all devices.
         # NOTE: Using attn_implementation='flash_attention_2' to avoid the issue of vision model's GPU OOM.
-        model = model_class.from_pretrained(
-            model_path,
+        hf_model_config = AutoConfig.from_pretrained(model_path)
+        vision_model = model_class._from_config(
+            hf_model_config.vision_config,
             torch_dtype=pretrained_config.torch_dtype,
-            attn_implementation='flash_attention_2').eval()
+            attn_implementation='flash_attention_2')
         # TODO: Make vision model compatible with meta init mode and load_weights at the same place
-        self.visual = model.visual.to(self.device)
+        self.visual = vision_model.to(self.device)
         self.post_config()
 
+    def load_weights(self, weights):
+        filtered_weights = {
+            k.replace('visual.', ''): v
+            for k, v in weights.items() if k.startswith('visual.')
+        }
+        self.visual.load_state_dict(filtered_weights)
+
     def post_config(self):
         self.config = self.visual.config
 
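For context, a minimal standalone sketch of the loading path this hunk introduces: build only the HF vision tower from the checkpoint's vision sub-config, then load the 'visual.'-prefixed weights into it. The helper name `load_qwen2_5_vision_tower`, the `full_state_dict` argument, and the explicit bfloat16 dtype are illustrative assumptions; in the PR the dtype comes from pretrained_config.torch_dtype and the weights come from TRT-LLM's checkpoint loader.

import torch
from transformers import AutoConfig
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import \
    Qwen2_5_VisionTransformerPretrainedModel


def load_qwen2_5_vision_tower(ckpt_dir: str, full_state_dict: dict,
                              device: str = "cuda") -> torch.nn.Module:
    # Build only the vision tower from the checkpoint's vision sub-config,
    # instead of instantiating the full Qwen2_5_VLForConditionalGeneration
    # model (which would also allocate the language model).
    hf_config = AutoConfig.from_pretrained(ckpt_dir)
    vision = Qwen2_5_VisionTransformerPretrainedModel._from_config(
        hf_config.vision_config,
        torch_dtype=torch.bfloat16,  # assumed; the PR uses pretrained_config.torch_dtype
        attn_implementation='flash_attention_2')  # assumes flash-attn is installed
    # Tower weights carry a 'visual.' prefix in the full checkpoint; strip it,
    # mirroring the new Qwen2VisionModelBase.load_weights above.
    vision_weights = {
        k.replace('visual.', ''): v
        for k, v in full_state_dict.items() if k.startswith('visual.')
    }
    vision.load_state_dict(vision_weights)
    return vision.to(device)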
@@ -473,6 +484,7 @@ def init_rotary_cos_sin_ori(self):
 
     def load_weights(self, weights):
         self.llm.load_weights(weights)
+        self.mm_encoder.load_weights(weights)
         self.init_rotary_cos_sin_ori()
 
     def infer_max_seq_len(self) -> int:
@@ -626,7 +638,7 @@ def forward(
 
 
 @register_vision_encoder(Qwen2VisionModelBase,
-                         vlm_base_model=Qwen2VLForConditionalGeneration)
+                         vlm_base_model=Qwen2VisionTransformerPretrainedModel)
 @register_auto_model("Qwen2VLForConditionalGeneration")
 @register_input_processor(
     Qwen2VLInputProcessorBase,
@@ -645,11 +657,12 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
         super().__init__(model_config, *args, **kwargs)
         if not DISAGG:
             self.mm_encoder = Qwen2VisionModelBase(
-                model_config, Qwen2VLForConditionalGeneration)
+                model_config, Qwen2VisionTransformerPretrainedModel)
 
 
-@register_vision_encoder(Qwen2VisionModelBase,
-                         vlm_base_model=Qwen2_5_VLForConditionalGeneration)
+@register_vision_encoder(
+    Qwen2VisionModelBase,
+    vlm_base_model=Qwen2_5_VisionTransformerPretrainedModel)
 @register_auto_model("Qwen2_5_VLForConditionalGeneration")
 @register_input_processor(
     Qwen2VLInputProcessorBase,
@@ -666,4 +679,4 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
         super().__init__(model_config, *args, **kwargs)
         if not DISAGG:
             self.mm_encoder = Qwen2VisionModelBase(
-                model_config, Qwen2_5_VLForConditionalGeneration)
+                model_config, Qwen2_5_VisionTransformerPretrainedModel)