Commit 4730bd7

fix
Signed-off-by: Pamela <[email protected]>
1 parent 6daf8f9 commit 4730bd7

1 file changed

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 19 additions & 8 deletions
@@ -7,7 +7,11 @@
 from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig,
                           PreTrainedModel, Qwen2_5_VLForConditionalGeneration,
                           Qwen2VLForConditionalGeneration)
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import \
+    Qwen2_5_VisionTransformerPretrainedModel
 from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
+from transformers.models.qwen2_vl.modeling_qwen2_vl import \
+    Qwen2VisionTransformerPretrainedModel
 
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 
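For reference, the pattern these new imports enable is to build only the Qwen2-VL vision tower from the checkpoint's vision_config, rather than materializing the whole *ForConditionalGeneration model. A minimal sketch outside TRT-LLM, assuming a Hugging Face transformers version with Qwen2-VL support; the checkpoint id below is illustrative and not part of this commit:

from transformers import AutoConfig
from transformers.models.qwen2_vl.modeling_qwen2_vl import \
    Qwen2VisionTransformerPretrainedModel

model_path = "Qwen/Qwen2-VL-7B-Instruct"  # illustrative checkpoint id
hf_config = AutoConfig.from_pretrained(model_path)
# Only the vision tower is instantiated (randomly initialized); its weights
# are loaded separately, as in the load_weights hunk below.
vision_tower = Qwen2VisionTransformerPretrainedModel(hf_config.vision_config)
print(type(vision_tower).__name__)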
@@ -358,15 +362,21 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
         # TODO: Change the model class to TRT-LLM's Qwen2VisionModel
         # Currently, copying vision encoder on all devices.
         # NOTE: Using attn_implementation='flash_attention_2' to avoid the issue of vision model's GPU OOM.
-        model = model_class.from_pretrained(
-            model_path,
-            torch_dtype=pretrained_config.torch_dtype,
-            attn_implementation='flash_attention_2',
-            ignore_mismatched_sizes=True).eval()
+        hf_model_config = AutoConfig.from_pretrained(model_path)
+        vision_model = model_class(config=hf_model_config.vision_config,
+                                   torch_dtype=pretrained_config.torch_dtype,
+                                   attn_implementation='flash_attention_2')
         # TODO: Make vision model compatible with meta init mode and load_weights at the same place
-        self.visual = model.visual.to(self.device)
+        self.visual = vision_model.to(self.device)
         self.post_config()
 
+    def load_weights(self, weights):
+        filtered_weights = {
+            k.replace('visual.', ''): v
+            for k, v in weights.items() if k.startswith('visual.')
+        }
+        self.visual.load_state_dict(filtered_weights)
+
     def post_config(self):
         self.config = self.visual.config
 
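The new load_weights above keeps only the checkpoint tensors whose keys start with 'visual.' and strips that prefix so the keys match a standalone vision transformer. A self-contained toy sketch of that filtering pattern, using stand-in modules rather than the real Qwen2-VL classes:

import torch.nn as nn

class ToyVisual(nn.Module):               # stand-in for the vision tower
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

class ToyVLModel(nn.Module):              # stand-in for the full VL checkpoint layout
    def __init__(self):
        super().__init__()
        self.visual = ToyVisual()         # keys prefixed with 'visual.'
        self.lm_head = nn.Linear(4, 8)    # language-model keys, dropped by the filter

full_ckpt = ToyVLModel().state_dict()
vision_only = {
    k.replace('visual.', '', 1): v
    for k, v in full_ckpt.items() if k.startswith('visual.')
}
ToyVisual().load_state_dict(vision_only)  # keys now line up: proj.weight, proj.bias
print(sorted(vision_only))                # ['proj.bias', 'proj.weight']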
@@ -504,6 +514,7 @@ def init_rotary_cos_sin_ori(self):
 
     def load_weights(self, weights):
         self.llm.load_weights(weights)
+        self.mm_encoder.load_weights(weights)
         self.init_rotary_cos_sin_ori()
 
     def infer_max_seq_len(self) -> int:
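For context, a toy sketch of the dispatch pattern this hunk introduces: the top-level model forwards the same flat weight dict to both halves, and each one selects the keys it owns (the encoder via the 'visual.' prefix filter shown earlier). The class names below are stand-ins, not the TRT-LLM classes:

class ToyLLM:
    def load_weights(self, weights):
        self.received = [k for k in weights if not k.startswith('visual.')]

class ToyEncoder:
    def load_weights(self, weights):
        self.received = [k for k in weights if k.startswith('visual.')]

class ToyTopLevel:
    def __init__(self):
        self.llm, self.mm_encoder = ToyLLM(), ToyEncoder()

    def load_weights(self, weights):
        # Same checkpoint dict goes to both sub-modules; each filters internally.
        self.llm.load_weights(weights)
        self.mm_encoder.load_weights(weights)

model = ToyTopLevel()
model.load_weights({'visual.proj.weight': 0, 'lm_head.weight': 1})
print(model.mm_encoder.received)          # ['visual.proj.weight']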
@@ -676,7 +687,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
         super().__init__(model_config, *args, **kwargs)
         if not DISAGG:
             self.mm_encoder = Qwen2VisionModelBase(
-                model_config, Qwen2VLForConditionalGeneration)
+                model_config, Qwen2VisionTransformerPretrainedModel)
 
 
 @register_vision_encoder(Qwen2VisionModelBase,
@@ -697,4 +708,4 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
         super().__init__(model_config, *args, **kwargs)
         if not DISAGG:
             self.mm_encoder = Qwen2VisionModelBase(
-                model_config, Qwen2_5_VLForConditionalGeneration)
+                model_config, Qwen2_5_VisionTransformerPretrainedModel)
