Commit b0c3540

yechank-nvidia authored and dominicshanshan committed
[https://nvbugs/5549829][fix] Qwen2.5-VL TP > 1 + Quantized weight load fix (NVIDIA#8680)
Signed-off-by: yechank <[email protected]>
1 parent 58fb9e1 commit b0c3540

2 files changed: +98 -97 lines changed

tensorrt_llm/_torch/models/checkpoints/hf/qwen2vl_weight_mapper.py

Lines changed: 10 additions & 5 deletions
@@ -10,13 +10,18 @@ class Qwen2VLHfWeightMapper(HfWeightMapper):
     'language_model.' prefix removal from weight keys.
     """

-    def filter_weights(self, prefix: str, weights: dict) -> dict:
+    def preprocess_weights(self, weights: dict) -> dict:
+        """
+        Preprocess weights to remove the 'model.language_model.' and 'model.visual.' prefixes.
+        """
         transformed_weights = {}
-        language_model_prefix = "model.language_model."
         for key, value in weights.items():
-            if key.startswith(language_model_prefix):
-                new_key = "model." + key[len(language_model_prefix):]
+            if key.startswith("model.language_model."):
+                new_key = "model." + key[len("model.language_model."):]
+                transformed_weights[new_key] = value
+            elif key.startswith("model.visual."):
+                new_key = "visual." + key[len("model.visual."):]
                 transformed_weights[new_key] = value
             else:
                 transformed_weights[key] = value
-        return super().filter_weights(prefix, transformed_weights)
+        return transformed_weights

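As a quick illustration of the new preprocess_weights remapping, here is a minimal, self-contained sketch applied to a few hypothetical checkpoint keys (the key names are made up for illustration, not taken from a real checkpoint):

# Illustrative only: mirrors the prefix remapping in
# Qwen2VLHfWeightMapper.preprocess_weights, applied to hypothetical keys.
def remap(key: str) -> str:
    if key.startswith("model.language_model."):
        return "model." + key[len("model.language_model."):]
    if key.startswith("model.visual."):
        return "visual." + key[len("model.visual."):]
    return key

for key in [
        "model.language_model.layers.0.self_attn.q_proj.weight",  # hypothetical
        "model.visual.blocks.0.attn.qkv.weight",                  # hypothetical
        "lm_head.weight",
]:
    print(key, "->", remap(key))
# model.language_model.layers.0.self_attn.q_proj.weight -> model.layers.0.self_attn.q_proj.weight
# model.visual.blocks.0.attn.qkv.weight -> visual.blocks.0.attn.qkv.weight
# lm_head.weight -> lm_head.weight

Keys under model.language_model. are folded back under model. for the LLM, keys under model.visual. are exposed under visual. for the vision encoder, and everything else passes through unchanged; the mapper now returns the remapped dict directly instead of calling super().filter_weights().
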
tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 88 additions & 92 deletions
@@ -1,5 +1,6 @@
 import copy
 import os
+import re
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
@@ -18,6 +19,8 @@
     PredefinedAttentionMask
 from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
     BaseWeightMapper
+from tensorrt_llm._torch.models.checkpoints.hf.qwen2vl_weight_mapper import \
+    Qwen2VLHfWeightMapper
 from tensorrt_llm._torch.modules.attention import Attention
 from tensorrt_llm._torch.modules.linear import Linear
 from tensorrt_llm._torch.modules.rms_norm import RMSNorm
@@ -39,7 +42,8 @@
 from .modeling_auto import AutoModelForCausalLM
 from .modeling_multimodal_utils import (find_input_mm_embeds, fuse_input_embeds,
                                         get_multimodal_embeddings)
-from .modeling_utils import (ModelConfig, register_auto_model,
+from .modeling_utils import (ModelConfig, QuantConfig, _load_weights_impl,
+                             filter_weights, register_auto_model,
                              register_vision_encoder)

 DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'
@@ -95,6 +99,7 @@ def __init__(self,

         super().__init__()
         self.model_config = model_config
+        self.vision_dtype = self.model_config.torch_dtype
         self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
             model_path)
         self.use_fast = True
@@ -347,14 +352,15 @@ def __call__(
         pixel_values = processed_inputs.get('pixel_values', None)
         if pixel_values is not None:
             multimodal_data["image"] = {
-                "pixel_values": pixel_values,
+                "pixel_values": pixel_values.to(self.vision_dtype),
                 "image_grid_thw": processed_inputs.get('image_grid_thw')
             }

         pixel_values_videos = processed_inputs.get('pixel_values_videos', None)
         if pixel_values_videos is not None:
             multimodal_data["video"] = {
-                "pixel_values_videos": pixel_values_videos,
+                "pixel_values_videos":
+                pixel_values_videos.to(self.vision_dtype),
                 "video_grid_thw": processed_inputs.get('video_grid_thw')
             }

@@ -382,29 +388,59 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
                  model_class: Union[type[PreTrainedModel],
                                     type[torch.nn.Module]]):
         super().__init__()
-        config = model_config.pretrained_config.vision_config
-        config.torch_dtype = model_config.pretrained_config.torch_dtype
         self.model_config = model_config
-        self.model_dtype = config.torch_dtype
+        self.model_dtype = self.model_config.pretrained_config.torch_dtype
+        self.config = self.model_config.pretrained_config.vision_config
+        self.config.num_attention_heads = self.config.num_heads
+
+        # NOTE: Re-setting QuantConfig to exclude vision encoder weights from quantization load.
+        self.model_config.quant_config = QuantConfig(
+            kv_cache_quant_algo=self.model_config.quant_config.
+            kv_cache_quant_algo)

         if model_class in [
                 Qwen2VisionTransformerPretrainedModel,
                 Qwen2_5_VisionTransformerPretrainedModel
         ]:
             # NOTE: For Qwen2VL, we use flash_attention_2 for attention implementation to avoid OOM issue.
-            config._attn_implementation = 'flash_attention_2'
-            self.visual = model_class(config).to(self.model_dtype).eval()
+            self.config._attn_implementation = 'flash_attention_2'
+            self.visual = model_class(
+                model_config.pretrained_config.vision_config).to(
+                    self.model_dtype).eval()
         elif model_class == Qwen2_5_VisionModel:
-            self.visual = model_class(self.model_config).to(
-                self.model_dtype).eval()
+            self.visual = model_class(self.model_config).to(self.model_dtype)
         else:
             raise NotImplementedError(
                 f"Model class {model_class} not implemented")

-        self.post_config()
-
-    def post_config(self):
-        self.config = self.model_config.pretrained_config.vision_config
+    def load_weights(self, weights: Dict):
+        visual_weights = filter_weights("visual", weights)
+        converted_weights = dict()
+
+        qkv_pattern = re.compile(r'(.*?)attn\.qkv\.(.*)')
+        for name in visual_weights:
+            # Handle with weights and bias for vision transformer's qkv projection.
+            match = qkv_pattern.match(name)
+            if match:
+                prefix, suffix = match.groups()
+                q_name = f"{prefix}attn.q_proj.{suffix}"
+                k_name = f"{prefix}attn.k_proj.{suffix}"
+                v_name = f"{prefix}attn.v_proj.{suffix}"
+                dim_shape = visual_weights[name].shape[0] // 3
+                converted_weights[q_name] = visual_weights[name][:dim_shape]
+                converted_weights[k_name] = visual_weights[name][dim_shape:2 *
+                                                                 dim_shape]
+                converted_weights[v_name] = visual_weights[name][2 * dim_shape:]
+            else:
+                converted_weights[name] = visual_weights[name]
+        pattern_mapping = {
+            r'(.*?)attn.proj.(.*)': r'\1attn.o_proj.\2',
+            r'(.*?)mlp.fc1.(.*)': r'\1mlp.up_proj.\2',
+            r'(.*?)mlp.fc2.(.*)': r'\1mlp.down_proj.\2',
+        }
+        _load_weights_impl(self.visual,
+                           converted_weights,
+                           params_map=pattern_mapping)

     def _parse_and_batch_multimodal_data(
             self, multimodal_params: List[MultimodalParams]
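To make the fused-QKV handling above concrete, here is a minimal, self-contained sketch of the same split on a toy tensor (the block name and sizes are made up; this is not the real loading path):

import re

import torch

# Hypothetical fused weight for one vision block: rows are [q; k; v] stacked.
visual_weights = {"blocks.0.attn.qkv.weight": torch.randn(3 * 8, 8)}

qkv_pattern = re.compile(r'(.*?)attn\.qkv\.(.*)')
converted = {}
for name, tensor in visual_weights.items():
    match = qkv_pattern.match(name)
    if match:
        prefix, suffix = match.groups()
        dim = tensor.shape[0] // 3  # rows per projection
        converted[f"{prefix}attn.q_proj.{suffix}"] = tensor[:dim]
        converted[f"{prefix}attn.k_proj.{suffix}"] = tensor[dim:2 * dim]
        converted[f"{prefix}attn.v_proj.{suffix}"] = tensor[2 * dim:]
    else:
        converted[name] = tensor

print(sorted(converted))
# ['blocks.0.attn.k_proj.weight', 'blocks.0.attn.q_proj.weight', 'blocks.0.attn.v_proj.weight']

Splitting the fused qkv tensor into separate q_proj/k_proj/v_proj entries, together with the o_proj/up_proj/down_proj renames, lets _load_weights_impl load the vision encoder through the same per-module, TP-aware path as the language model instead of a strict load_state_dict, which is presumably what the TP > 1 part of this fix relies on.
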
@@ -469,12 +505,10 @@ def forward(self, multimodal_params: List[MultimodalParams]):

         embeds = []
         if pixel_values is not None:
-            pixel_values = pixel_values.to(self.model_dtype)
             embed = self.visual(pixel_values, grid_thw=image_grid_thw)
             embeds.append(embed)

         if pixel_values_videos is not None:
-            pixel_values_videos = pixel_values_videos.to(self.model_dtype)
             embeds.append(
                 self.visual(pixel_values_videos, grid_thw=video_grid_thw))
         return embeds
@@ -615,32 +649,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class Qwen2_5_VisionModel(torch.nn.Module):

     def __init__(self, model_config: ModelConfig[PretrainedConfig]):
-        config = model_config.pretrained_config.vision_config
         super().__init__()
+        self.model_config = model_config
+        self.config = self.model_config.pretrained_config.vision_config

-        self.spatial_merge_size = config.spatial_merge_size
-        self.patch_size = config.patch_size
-        self.fullatt_block_indexes = config.fullatt_block_indexes
-        self.window_size = config.window_size
+        self.spatial_merge_size = self.config.spatial_merge_size
+        self.patch_size = self.config.patch_size
+        self.fullatt_block_indexes = self.config.fullatt_block_indexes
+        self.window_size = self.config.window_size
         self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

         self.patch_embed = Qwen2_5_VisionPatchEmbed(
-            patch_size=config.patch_size,
-            temporal_patch_size=config.temporal_patch_size,
-            in_channels=config.in_channels,
-            embed_dim=config.hidden_size,
+            patch_size=self.config.patch_size,
+            temporal_patch_size=self.config.temporal_patch_size,
+            in_channels=self.config.in_channels,
+            embed_dim=self.config.hidden_size,
         )

-        head_dim = config.hidden_size // config.num_heads
+        head_dim = self.config.hidden_size // self.config.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

         self.blocks = torch.nn.ModuleList([
             Qwen2_5_VLVisionBlock(model_config, layer_idx=layer_idx)
-            for layer_idx in range(config.depth)
+            for layer_idx in range(self.config.depth)
         ])
-        self.merger = Qwen2_5_VLPatchMerger(model_config, )
+        self.merger = Qwen2_5_VLPatchMerger(self.model_config, )
         self.metadata_cls = get_attention_backend(
-            model_config.attn_backend).Metadata
+            self.model_config.attn_backend).Metadata

         self.full_attn_metadata = self.metadata_cls(
             max_num_requests=8192,  # TODO: Make this dynamic
@@ -798,35 +833,31 @@ def __init__(
         *args,
         **kwargs,
     ) -> None:
-        model_config.pretrained_config.rope_scaling['type'] = 'mrope'
         self.original_arch = model_config.pretrained_config.architectures[0]
+
         # NOTE: Setting disable_fuse_rope to True to do mrope fusion in the model engine by pre-computing rotary_cos_sin in the model engine
         disabble_fuse_rope = kwargs.get('disable_fuse_rope', False)
-        model_config.pretrained_config.text_config.disable_fuse_rope = disabble_fuse_rope
+        model_config.pretrained_config.disable_fuse_rope = disabble_fuse_rope
+        model_config.pretrained_config.rope_scaling['type'] = 'mrope'
         config = model_config.pretrained_config
-
-        assert model_config.attn_backend == 'TRTLLM', "Qwen2/2.5-VL only supports TRTLLM backend now"
         super().__init__(config)
-        if not disabble_fuse_rope:
-            self.init_mrope_embedding(model_config)

         self.model_config = model_config
-        if hasattr(self, "llm"):
-            return
+        self.config = model_config.pretrained_config

-        if not DISAGG:
-            self.mm_encoder = Qwen2VisionModelBase(
-                model_config, kwargs.get('vision_model_class', None)).eval()
+        if model_config.attn_backend != 'TRTLLM':
+            raise ValueError("Qwen2/2.5-VL only supports TRTLLM backend now")
+        if not disabble_fuse_rope:
+            self.init_mrope_embedding(model_config)

         llm_model_config = copy.deepcopy(model_config)
-        llm_model_config.pretrained_config = config.text_config
         llm_model_config.pretrained_config.architectures = ["Qwen2ForCausalLM"]
-
         self.llm = AutoModelForCausalLM.from_config(llm_model_config)
-        self.model_dtype = getattr(config, "torch_dtype", torch.bfloat16)
-        logger.info(f"{self.dtype=} {self.model_dtype=}")
-        self.post_config()
-        self.is_loaded = True
+
+        if not DISAGG:
+            mm_encoder_config = copy.deepcopy(model_config)
+            self.mm_encoder = Qwen2VisionModelBase(
+                mm_encoder_config, kwargs.get('vision_model_class', None))

     def init_mrope_embedding(self, model_config: ModelConfig[PretrainedConfig]):
         config = model_config.pretrained_config
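Note that the vision encoder is now built from its own deep copy of the config. A small illustrative sketch of why that matters: Qwen2VisionModelBase resets quant_config (see the earlier hunk) to keep only the KV-cache algo, and with a shared config object that reset would also strip the quantization settings used by the LLM. The classes below are stand-ins for QuantConfig/ModelConfig, for illustration only:

import copy
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeQuantConfig:  # stand-in for QuantConfig, illustration only
    quant_algo: Optional[str] = None
    kv_cache_quant_algo: Optional[str] = None


@dataclass
class FakeModelConfig:  # stand-in for ModelConfig, illustration only
    quant_config: FakeQuantConfig


model_config = FakeModelConfig(FakeQuantConfig("FP8", "FP8"))

# The vision encoder gets a deep copy, so resetting its quant_config
# leaves the LLM-side config untouched.
mm_encoder_config = copy.deepcopy(model_config)
mm_encoder_config.quant_config = FakeQuantConfig(
    kv_cache_quant_algo=mm_encoder_config.quant_config.kv_cache_quant_algo)

print(model_config.quant_config.quant_algo)  # FP8  (LLM weights stay quantized)
print(mm_encoder_config.quant_config.quant_algo)  # None (vision encoder excluded)
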
@@ -854,11 +885,6 @@ def load_weights(self, weights, weight_mapper: BaseWeightMapper):
     def infer_max_seq_len(self) -> int:
         return self.llm.infer_max_seq_len()

-    def post_config(self):
-        # use llm.config as config for pytorch model engine
-        self.config = self.llm.config
-        self.model_config.pretrained_config = self.llm.config
-
     @nvtx_range("Qwen2.5-VL prepare_mrope_config")
     def prepare_mrope_config(self, multimodal_params: List[MultimodalParams],
                              num_context_requests: int):
@@ -998,22 +1024,8 @@ def load_weights(self, weights, weight_mapper: BaseWeightMapper):
         self.llm.load_weights(weights, weight_mapper)


-def getSMVersion():
-    prop = torch.cuda.get_device_properties(0)
-    sm_version = prop.major * 10 + prop.minor
-    return sm_version
-
-
-get_sm_version = getSMVersion()
-if get_sm_version >= 100:
-    # NOTE: Qwen2.5-VL with SM 100 and above uses HF's implementation due to lacking of TRT-LLM's Attention kernel.
-    QWEN2_5_VL_VISION_MODEL_CLASS = Qwen2_5_VisionTransformerPretrainedModel
-else:
-    QWEN2_5_VL_VISION_MODEL_CLASS = Qwen2_5_VisionModel
-
-
 @register_vision_encoder(Qwen2VisionModelBase,
-                         vlm_base_model=QWEN2_5_VL_VISION_MODEL_CLASS)
+                         vlm_base_model=Qwen2_5_VisionModel)
 @register_auto_model("Qwen2_5_VLForConditionalGeneration")
 @register_input_processor(
     Qwen2VLInputProcessorBase,
@@ -1029,39 +1041,23 @@ class Qwen2_5_VLModel(Qwen2VLModelBase):

     def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
                  **kwargs):
-        kwargs['vision_model_class'] = QWEN2_5_VL_VISION_MODEL_CLASS
+        kwargs['vision_model_class'] = Qwen2_5_VisionModel
         kwargs[
             'disable_fuse_rope'] = False  # TODO: Make this ModelConfig's argument
         super().__init__(model_config, *args, **kwargs)

     @property
     def multimodal_data_device_paths(self) -> List[str]:
-        if get_sm_version >= 100:
-            return [
-                "image.pixel_values", "video.pixel_values_videos",
-                "image.image_grid_thw", "video.video_grid_thw",
-                "multimodal_embedding"
-            ]
-        else:
-            return [
-                "image.pixel_values", "video.pixel_values_videos",
-                "multimodal_embedding"
-            ]
+        return [
+            "image.pixel_values", "video.pixel_values_videos",
+            "multimodal_embedding"
+        ]

     def load_weights(self, weights, weight_mapper: BaseWeightMapper):
+        if isinstance(weight_mapper, Qwen2VLHfWeightMapper):
+            weights = weight_mapper.preprocess_weights(weights)
+
         if not DISAGG:
-            if get_sm_version >= 100:
-                weight_name_mapping = None
-            else:
-                # Process vision encoder weights
-                weight_name_mapping = {
-                    "attn.proj.weight": "attn.o_proj.weight",
-                    "attn.proj.bias": "attn.o_proj.bias",
-                    "attn.qkv.weight": "attn.qkv_proj.weight",
-                    "attn.qkv.bias": "attn.qkv_proj.bias"
-                }
-            vision_weights = process_weights(weights, "visual",
-                                             weight_name_mapping)
-            self.mm_encoder.load_state_dict(vision_weights, strict=True)
+            self.mm_encoder.load_weights(weights)

-        self.llm.load_weights(weights, weight_mapper)
+        self.llm.load_weights(weights)
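The simplified load_weights above hands the full checkpoint to the weight mapper's preprocess_weights, then lets Qwen2VisionModelBase.load_weights pick up the visual.* weights while the LLM loads the rest. For completeness, a sketch of what the params_map regex renaming used there amounts to, assuming each entry is applied as a regex substitution on weight names (the weight names below are hypothetical):

import re

# Rename mapping for vision encoder weights; assumes the loader applies each
# entry as a regex substitution on checkpoint key names.
pattern_mapping = {
    r'(.*?)attn.proj.(.*)': r'\1attn.o_proj.\2',
    r'(.*?)mlp.fc1.(.*)': r'\1mlp.up_proj.\2',
    r'(.*?)mlp.fc2.(.*)': r'\1mlp.down_proj.\2',
}


def rename(key: str) -> str:
    for pattern, repl in pattern_mapping.items():
        new_key, count = re.subn(pattern, repl, key)
        if count:
            return new_key
    return key


for key in [
        "blocks.0.attn.proj.weight",  # hypothetical
        "blocks.0.mlp.fc1.bias",  # hypothetical
        "merger.ln_q.weight",  # hypothetical
]:
    print(key, "->", rename(key))
# blocks.0.attn.proj.weight -> blocks.0.attn.o_proj.weight
# blocks.0.mlp.fc1.bias -> blocks.0.mlp.up_proj.bias
# merger.ln_q.weight -> merger.ln_q.weight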
