
Commit 270683f

qwen2.5-vl tp > 1 + fp8 & fp4 weight load fix
Signed-off-by: yechank <[email protected]>
1 parent a6017f6 commit 270683f

File tree: 1 file changed, +81 -91 lines

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 81 additions & 91 deletions
@@ -1,5 +1,6 @@
 import copy
 import os
+import re
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -42,7 +43,8 @@
 from .modeling_auto import AutoModelForCausalLM
 from .modeling_multimodal_utils import (find_input_mm_embeds, fuse_input_embeds,
                                         get_multimodal_embeddings)
-from .modeling_utils import (ModelConfig, register_auto_model,
+from .modeling_utils import (ModelConfig, QuantConfig, _load_weights_impl,
+                             filter_weights, register_auto_model,
                              register_vision_encoder)
 
 DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'
@@ -96,6 +98,7 @@ def __init__(self,
                  tokenizer: AutoTokenizer,
                  trust_remote_code: bool = True):
         self.model_config = model_config
+        self.vision_dtype = self.model_config.torch_dtype
         self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
             model_path)
         self.use_fast = True
@@ -423,14 +426,15 @@ def __call__(
         pixel_values = processed_inputs.get('pixel_values', None)
         if pixel_values is not None:
             multimodal_data["image"] = {
-                "pixel_values": pixel_values,
+                "pixel_values": pixel_values.to(self.vision_dtype),
                 "image_grid_thw": processed_inputs.get('image_grid_thw')
             }
 
         pixel_values_videos = processed_inputs.get('pixel_values_videos', None)
         if pixel_values_videos is not None:
             multimodal_data["video"] = {
-                "pixel_values_videos": pixel_values_videos,
+                "pixel_values_videos":
+                pixel_values_videos.to(self.vision_dtype),
                 "video_grid_thw": processed_inputs.get('video_grid_thw')
             }
 
@@ -458,29 +462,59 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
                  model_class: Union[type[PreTrainedModel],
                                     type[torch.nn.Module]]):
         super().__init__()
-        config = model_config.pretrained_config.vision_config
-        config.torch_dtype = model_config.pretrained_config.torch_dtype
         self.model_config = model_config
-        self.model_dtype = config.torch_dtype
+        self.model_dtype = self.model_config.pretrained_config.torch_dtype
+        self.config = self.model_config.pretrained_config.vision_config
+        self.config.num_attention_heads = self.config.num_heads
+
+        # NOTE: Re-setting QuantConfig to exclude vision encoder weights from quantization load.
+        self.model_config.quant_config = QuantConfig(
+            kv_cache_quant_algo=self.model_config.quant_config.
+            kv_cache_quant_algo)
 
         if model_class in [
                 Qwen2VisionTransformerPretrainedModel,
                 Qwen2_5_VisionTransformerPretrainedModel
         ]:
             # NOTE: For Qwen2VL, we use flash_attention_2 for attention implementation to avoid OOM issue.
-            config._attn_implementation = 'flash_attention_2'
-            self.visual = model_class(config).to(self.model_dtype).eval()
+            self.config._attn_implementation = 'flash_attention_2'
+            self.visual = model_class(
+                model_config.pretrained_config.vision_config).to(
+                    self.model_dtype).eval()
         elif model_class == Qwen2_5_VisionModel:
-            self.visual = model_class(self.model_config).to(
-                self.model_dtype).eval()
+            self.visual = model_class(self.model_config).to(self.model_dtype)
         else:
             raise NotImplementedError(
                 f"Model class {model_class} not implemented")
 
-        self.post_config()
-
-    def post_config(self):
-        self.config = self.model_config.pretrained_config.vision_config
+    def load_weights(self, weights: Dict):
+        visual_weights = filter_weights("visual", weights)
+        converted_weights = dict()
+
+        qkv_pattern = re.compile(r'(.*?)attn\.qkv\.(.*)')
+        for name in visual_weights:
+            # Handle with weights and bias for vision transformer's qkv projection.
+            match = qkv_pattern.match(name)
+            if match:
+                prefix, suffix = match.groups()
+                q_name = f"{prefix}attn.q_proj.{suffix}"
+                k_name = f"{prefix}attn.k_proj.{suffix}"
+                v_name = f"{prefix}attn.v_proj.{suffix}"
+                dim_shape = visual_weights[name].shape[0] // 3
+                converted_weights[q_name] = visual_weights[name][:dim_shape]
+                converted_weights[k_name] = visual_weights[name][dim_shape:2 *
+                                                                 dim_shape]
+                converted_weights[v_name] = visual_weights[name][2 * dim_shape:]
+            else:
+                converted_weights[name] = visual_weights[name]
+        pattern_mapping = {
+            r'(.*?)attn.proj.(.*)': r'\1attn.o_proj.\2',
+            r'(.*?)mlp.fc1.(.*)': r'\1mlp.up_proj.\2',
+            r'(.*?)mlp.fc2.(.*)': r'\1mlp.down_proj.\2',
+        }
+        _load_weights_impl(self.visual,
+                           converted_weights,
+                           params_map=pattern_mapping)
 
     def _parse_and_batch_multimodal_data(
             self, multimodal_params: List[MultimodalParams]
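The new `load_weights` above replaces the old `post_config` + `load_state_dict` path: it filters the checkpoint for `visual.*` entries, splits each fused `attn.qkv` tensor into separate q/k/v projections along dim 0, and hands the result to `_load_weights_impl` with a regex rename map (attn.proj -> attn.o_proj, mlp.fc1/fc2 -> up/down_proj), which presumably lets the generic loader shard each projection for TP > 1 and pick up the per-projection FP8/FP4 scales the commit title refers to. A self-contained sketch of the split-and-rename step on a toy state dict (tensor shapes are made up, and the final rename loop only stands in for what `params_map` is used for):

import re

import torch

hidden = 8  # toy hidden size, not the real vision dimension
toy_visual_weights = {
    "blocks.0.attn.qkv.weight": torch.randn(3 * hidden, hidden),
    "blocks.0.attn.qkv.bias": torch.randn(3 * hidden),
    "blocks.0.attn.proj.weight": torch.randn(hidden, hidden),
}

converted = {}
qkv_pattern = re.compile(r'(.*?)attn\.qkv\.(.*)')
for name, tensor in toy_visual_weights.items():
    match = qkv_pattern.match(name)
    if match:
        # Fused checkpoint tensor stacks q, k, v along dim 0; split it in thirds.
        prefix, suffix = match.groups()
        split = tensor.shape[0] // 3
        converted[f"{prefix}attn.q_proj.{suffix}"] = tensor[:split]
        converted[f"{prefix}attn.k_proj.{suffix}"] = tensor[split:2 * split]
        converted[f"{prefix}attn.v_proj.{suffix}"] = tensor[2 * split:]
    else:
        converted[name] = tensor

# Toy rename pass mirroring the intent of the params_map above.
pattern_mapping = {r'(.*?)attn\.proj\.(.*)': r'\1attn.o_proj.\2'}
renamed = {}
for name, tensor in converted.items():
    for pattern, repl in pattern_mapping.items():
        name = re.sub(pattern, repl, name)
    renamed[name] = tensor

# Expect q_proj/k_proj/v_proj weight and bias plus attn.o_proj.weight.
print(sorted(renamed))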
@@ -545,12 +579,10 @@ def forward(self, multimodal_params: List[MultimodalParams]):
 
         embeds = []
         if pixel_values is not None:
-            pixel_values = pixel_values.to(self.model_dtype)
             embed = self.visual(pixel_values, grid_thw=image_grid_thw)
             embeds.append(embed)
 
         if pixel_values_videos is not None:
-            pixel_values_videos = pixel_values_videos.to(self.model_dtype)
             embeds.append(
                 self.visual(pixel_values_videos, grid_thw=video_grid_thw))
         return embeds
@@ -691,32 +723,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class Qwen2_5_VisionModel(torch.nn.Module):
 
     def __init__(self, model_config: ModelConfig[PretrainedConfig]):
-        config = model_config.pretrained_config.vision_config
         super().__init__()
+        self.model_config = model_config
+        self.config = self.model_config.pretrained_config.vision_config
 
-        self.spatial_merge_size = config.spatial_merge_size
-        self.patch_size = config.patch_size
-        self.fullatt_block_indexes = config.fullatt_block_indexes
-        self.window_size = config.window_size
+        self.spatial_merge_size = self.config.spatial_merge_size
+        self.patch_size = self.config.patch_size
+        self.fullatt_block_indexes = self.config.fullatt_block_indexes
+        self.window_size = self.config.window_size
         self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
 
         self.patch_embed = Qwen2_5_VisionPatchEmbed(
-            patch_size=config.patch_size,
-            temporal_patch_size=config.temporal_patch_size,
-            in_channels=config.in_channels,
-            embed_dim=config.hidden_size,
+            patch_size=self.config.patch_size,
+            temporal_patch_size=self.config.temporal_patch_size,
+            in_channels=self.config.in_channels,
+            embed_dim=self.config.hidden_size,
         )
 
-        head_dim = config.hidden_size // config.num_heads
+        head_dim = self.config.hidden_size // self.config.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
 
         self.blocks = torch.nn.ModuleList([
             Qwen2_5_VLVisionBlock(model_config, layer_idx=layer_idx)
-            for layer_idx in range(config.depth)
+            for layer_idx in range(self.config.depth)
         ])
-        self.merger = Qwen2_5_VLPatchMerger(model_config, )
+        self.merger = Qwen2_5_VLPatchMerger(self.model_config, )
         self.metadata_cls = get_attention_backend(
-            model_config.attn_backend).Metadata
+            self.model_config.attn_backend).Metadata
 
         self.full_attn_metadata = self.metadata_cls(
             max_num_requests=8192, # TODO: Make this dynamic
@@ -874,35 +907,30 @@ def __init__(
         *args,
         **kwargs,
     ) -> None:
-        model_config.pretrained_config.rope_scaling['type'] = 'mrope'
         self.original_arch = model_config.pretrained_config.architectures[0]
+
         # NOTE: Setting disable_fuse_rope to True to do mrope fusion in the model engine by pre-computing rotary_cos_sin in the model engine
         disabble_fuse_rope = kwargs.get('disable_fuse_rope', False)
-        model_config.pretrained_config.text_config.disable_fuse_rope = disabble_fuse_rope
+        model_config.pretrained_config.disable_fuse_rope = disabble_fuse_rope
+        model_config.pretrained_config.rope_scaling['type'] = 'mrope'
         config = model_config.pretrained_config
-
-        assert model_config.attn_backend == 'TRTLLM', "Qwen2/2.5-VL only supports TRTLLM backend now"
         super().__init__(config)
-        if not disabble_fuse_rope:
-            self.init_mrope_embedding(model_config)
 
         self.model_config = model_config
-        if hasattr(self, "llm"):
-            return
+        self.config = model_config.pretrained_config
 
-        if not DISAGG:
-            self.mm_encoder = Qwen2VisionModelBase(
-                model_config, kwargs.get('vision_model_class', None)).eval()
+        assert model_config.attn_backend == 'TRTLLM', "Qwen2/2.5-VL only supports TRTLLM backend now"
+        if not disabble_fuse_rope:
+            self.init_mrope_embedding(model_config)
 
         llm_model_config = copy.deepcopy(model_config)
-        llm_model_config.pretrained_config = config.text_config
         llm_model_config.pretrained_config.architectures = ["Qwen2ForCausalLM"]
-
         self.llm = AutoModelForCausalLM.from_config(llm_model_config)
-        self.model_dtype = getattr(config, "torch_dtype", torch.bfloat16)
-        logger.info(f"{self.dtype=} {self.model_dtype=}")
-        self.post_config()
-        self.is_loaded = True
+
+        if not DISAGG:
+            mm_encoder_config = copy.deepcopy(model_config)
+            self.mm_encoder = Qwen2VisionModelBase(
+                mm_encoder_config, kwargs.get('vision_model_class', None))
 
     def init_mrope_embedding(self, model_config: ModelConfig[PretrainedConfig]):
         config = model_config.pretrained_config
@@ -930,11 +958,6 @@ def load_weights(self, weights, weight_mapper: BaseWeightMapper):
     def infer_max_seq_len(self) -> int:
         return self.llm.infer_max_seq_len()
 
-    def post_config(self):
-        # use llm.config as config for pytorch model engine
-        self.config = self.llm.config
-        self.model_config.pretrained_config = self.llm.config
-
     @nvtx_range("Qwen2.5-VL prepare_mrope_config")
     def prepare_mrope_config(self, multimodal_params: List[MultimodalParams],
                              num_context_requests: int):
@@ -1075,22 +1098,8 @@ def load_weights(self, weights, weight_mapper: BaseWeightMapper):
         self.llm.load_weights(weights, weight_mapper)
 
 
-def getSMVersion():
-    prop = torch.cuda.get_device_properties(0)
-    sm_version = prop.major * 10 + prop.minor
-    return sm_version
-
-
-get_sm_version = getSMVersion()
-if get_sm_version >= 100:
-    # NOTE: Qwen2.5-VL with SM 100 and above uses HF's implementation due to lacking of TRT-LLM's Attention kernel.
-    QWEN2_5_VL_VISION_MODEL_CLASS = Qwen2_5_VisionTransformerPretrainedModel
-else:
-    QWEN2_5_VL_VISION_MODEL_CLASS = Qwen2_5_VisionModel
-
-
 @register_vision_encoder(Qwen2VisionModelBase,
-                         vlm_base_model=QWEN2_5_VL_VISION_MODEL_CLASS)
+                         vlm_base_model=Qwen2_5_VisionModel)
 @register_auto_model("Qwen2_5_VLForConditionalGeneration")
 @register_input_processor(
     Qwen2VLInputProcessorBase,
@@ -1106,39 +1115,20 @@ class Qwen2_5_VLModel(Qwen2VLModelBase):
 
     def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
                  **kwargs):
-        kwargs['vision_model_class'] = QWEN2_5_VL_VISION_MODEL_CLASS
+        kwargs['vision_model_class'] = Qwen2_5_VisionModel
        kwargs[
             'disable_fuse_rope'] = False # TODO: Make this ModelConfig's argument
         super().__init__(model_config, *args, **kwargs)
 
     @property
     def multimodal_data_device_paths(self) -> List[str]:
-        if get_sm_version >= 100:
-            return [
-                "image.pixel_values", "video.pixel_values_videos",
-                "image.image_grid_thw", "video.video_grid_thw",
-                "multimodal_embedding"
-            ]
-        else:
-            return [
-                "image.pixel_values", "video.pixel_values_videos",
-                "multimodal_embedding"
-            ]
+        return [
+            "image.pixel_values", "video.pixel_values_videos",
+            "multimodal_embedding"
+        ]
 
     def load_weights(self, weights, weight_mapper: BaseWeightMapper):
         if not DISAGG:
-            if get_sm_version >= 100:
-                weight_name_mapping = None
-            else:
-                # Process vision encoder weights
-                weight_name_mapping = {
-                    "attn.proj.weight": "attn.o_proj.weight",
-                    "attn.proj.bias": "attn.o_proj.bias",
-                    "attn.qkv.weight": "attn.qkv_proj.weight",
-                    "attn.qkv.bias": "attn.qkv_proj.bias"
-                }
-            vision_weights = process_weights(weights, "visual",
-                                             weight_name_mapping)
-            self.mm_encoder.load_state_dict(vision_weights, strict=True)
+            self.mm_encoder.load_weights(weights)
 
         self.llm.load_weights(weights, weight_mapper)
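With the encoder now owning its weight conversion, the model-level `load_weights` above reduces to two delegating calls. A sketch of the resulting flow (the `load_checkpoint` wrapper is illustrative, not part of the diff; `DISAGG` mirrors the module-level flag added earlier in the file):

import os

DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'


def load_checkpoint(model, weights, weight_mapper):
    if not DISAGG:
        # Vision weights: prefix-filtered, qkv split, and renamed inside
        # Qwen2VisionModelBase.load_weights, then loaded via _load_weights_impl.
        model.mm_encoder.load_weights(weights)
    # Language-model weights go through the LLM's own loader and weight mapper.
    model.llm.load_weights(weights, weight_mapper)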
