tensorrt_llm/_torch/models/checkpoints/hf/qwen2vl_weight_mapper.py
@@ -10,13 +10,18 @@ class Qwen2VLHfWeightMapper(HfWeightMapper):
'language_model.' prefix removal from weight keys.
"""

def filter_weights(self, prefix: str, weights: dict) -> dict:
def preprocess_weights(self, weights: dict) -> dict:
"""
        Preprocess weights by rewriting the 'model.language_model.' prefix to
        'model.' and the 'model.visual.' prefix to 'visual.'.
"""
transformed_weights = {}
language_model_prefix = "model.language_model."
for key, value in weights.items():
if key.startswith(language_model_prefix):
new_key = "model." + key[len(language_model_prefix):]
if key.startswith("model.language_model."):
new_key = "model." + key[len("model.language_model."):]
transformed_weights[new_key] = value
elif key.startswith("model.visual."):
new_key = "visual." + key[len("model.visual."):]
transformed_weights[new_key] = value
else:
transformed_weights[key] = value
return super().filter_weights(prefix, transformed_weights)
return transformed_weights
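A standalone sketch of what the new preprocess_weights does, mirroring the method body on hypothetical checkpoint keys (the key names and dummy tensors below are made up for illustration):

import torch

weights = {
    "model.language_model.layers.0.self_attn.q_proj.weight": torch.zeros(1),
    "model.visual.blocks.0.attn.qkv.weight": torch.zeros(1),
    "lm_head.weight": torch.zeros(1),
}

remapped = {}
for key, value in weights.items():
    if key.startswith("model.language_model."):
        remapped["model." + key[len("model.language_model."):]] = value
    elif key.startswith("model.visual."):
        remapped["visual." + key[len("model.visual."):]] = value
    else:
        remapped[key] = value

print(sorted(remapped))
# ['lm_head.weight',
#  'model.layers.0.self_attn.q_proj.weight',
#  'visual.blocks.0.attn.qkv.weight']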
180 changes: 88 additions & 92 deletions tensorrt_llm/_torch/models/modeling_qwen2vl.py
@@ -1,5 +1,6 @@
import copy
import os
import re
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
@@ -20,6 +21,8 @@
PredefinedAttentionMask
from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
BaseWeightMapper
from tensorrt_llm._torch.models.checkpoints.hf.qwen2vl_weight_mapper import \
Qwen2VLHfWeightMapper
from tensorrt_llm._torch.modules.attention import Attention
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
@@ -42,7 +45,8 @@
from .modeling_auto import AutoModelForCausalLM
from .modeling_multimodal_utils import (find_input_mm_embeds, fuse_input_embeds,
get_multimodal_embeddings)
from .modeling_utils import (ModelConfig, register_auto_model,
from .modeling_utils import (ModelConfig, QuantConfig, _load_weights_impl,
filter_weights, register_auto_model,
register_vision_encoder)

DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'
@@ -96,6 +100,7 @@ def __init__(self,
tokenizer: AutoTokenizer,
trust_remote_code: bool = True):
self.model_config = model_config
self.vision_dtype = self.model_config.torch_dtype
self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
model_path)
self.use_fast = True
@@ -423,14 +428,15 @@ def __call__(
pixel_values = processed_inputs.get('pixel_values', None)
if pixel_values is not None:
multimodal_data["image"] = {
"pixel_values": pixel_values,
"pixel_values": pixel_values.to(self.vision_dtype),
"image_grid_thw": processed_inputs.get('image_grid_thw')
}

pixel_values_videos = processed_inputs.get('pixel_values_videos', None)
if pixel_values_videos is not None:
multimodal_data["video"] = {
"pixel_values_videos": pixel_values_videos,
"pixel_values_videos":
pixel_values_videos.to(self.vision_dtype),
"video_grid_thw": processed_inputs.get('video_grid_thw')
}

@@ -458,29 +464,59 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
model_class: Union[type[PreTrainedModel],
type[torch.nn.Module]]):
super().__init__()
config = model_config.pretrained_config.vision_config
config.torch_dtype = model_config.pretrained_config.torch_dtype
self.model_config = model_config
self.model_dtype = config.torch_dtype
self.model_dtype = self.model_config.pretrained_config.torch_dtype
self.config = self.model_config.pretrained_config.vision_config
self.config.num_attention_heads = self.config.num_heads

        # NOTE: Re-set QuantConfig to exclude the vision encoder weights from
        # quantization; only the KV-cache quant algo is kept.
self.model_config.quant_config = QuantConfig(
kv_cache_quant_algo=self.model_config.quant_config.
kv_cache_quant_algo)

if model_class in [
Qwen2VisionTransformerPretrainedModel,
Qwen2_5_VisionTransformerPretrainedModel
]:
            # NOTE: For Qwen2VL, we use flash_attention_2 as the attention
            # implementation to avoid OOM issues.
config._attn_implementation = 'flash_attention_2'
self.visual = model_class(config).to(self.model_dtype).eval()
self.config._attn_implementation = 'flash_attention_2'
self.visual = model_class(
model_config.pretrained_config.vision_config).to(
self.model_dtype).eval()
elif model_class == Qwen2_5_VisionModel:
self.visual = model_class(self.model_config).to(
self.model_dtype).eval()
self.visual = model_class(self.model_config).to(self.model_dtype)
else:
raise NotImplementedError(
f"Model class {model_class} not implemented")

self.post_config()

def post_config(self):
self.config = self.model_config.pretrained_config.vision_config
def load_weights(self, weights: Dict):
visual_weights = filter_weights("visual", weights)
converted_weights = dict()

qkv_pattern = re.compile(r'(.*?)attn\.qkv\.(.*)')
for name in visual_weights:
            # Split the fused qkv projection weights and biases of the vision
            # transformer into separate q/k/v tensors.
match = qkv_pattern.match(name)
if match:
prefix, suffix = match.groups()
q_name = f"{prefix}attn.q_proj.{suffix}"
k_name = f"{prefix}attn.k_proj.{suffix}"
v_name = f"{prefix}attn.v_proj.{suffix}"
dim_shape = visual_weights[name].shape[0] // 3
converted_weights[q_name] = visual_weights[name][:dim_shape]
converted_weights[k_name] = visual_weights[name][dim_shape:2 *
dim_shape]
converted_weights[v_name] = visual_weights[name][2 * dim_shape:]
else:
converted_weights[name] = visual_weights[name]
pattern_mapping = {
r'(.*?)attn.proj.(.*)': r'\1attn.o_proj.\2',
r'(.*?)mlp.fc1.(.*)': r'\1mlp.up_proj.\2',
r'(.*?)mlp.fc2.(.*)': r'\1mlp.down_proj.\2',
}
_load_weights_impl(self.visual,
converted_weights,
params_map=pattern_mapping)
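        # Shape sketch (hypothetical hidden size 1280): a fused "attn.qkv.weight"
        # of shape [3 * 1280, 1280] is sliced row-wise into equal q/k/v chunks,
        #     d = qkv.shape[0] // 3
        #     q, k, v = qkv[:d], qkv[d:2 * d], qkv[2 * d:]
        # and the same slicing applies to the 1-D "attn.qkv.bias" tensors. The
        # pattern_mapping above then renames attn.proj -> attn.o_proj,
        # mlp.fc1 -> mlp.up_proj and mlp.fc2 -> mlp.down_proj before
        # _load_weights_impl consumes the converted state dict.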

def _parse_and_batch_multimodal_data(
self, multimodal_params: List[MultimodalParams]
@@ -545,12 +581,10 @@ def forward(self, multimodal_params: List[MultimodalParams]):

embeds = []
if pixel_values is not None:
pixel_values = pixel_values.to(self.model_dtype)
embed = self.visual(pixel_values, grid_thw=image_grid_thw)
embeds.append(embed)

if pixel_values_videos is not None:
pixel_values_videos = pixel_values_videos.to(self.model_dtype)
embeds.append(
self.visual(pixel_values_videos, grid_thw=video_grid_thw))
return embeds
@@ -691,32 +725,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class Qwen2_5_VisionModel(torch.nn.Module):

def __init__(self, model_config: ModelConfig[PretrainedConfig]):
config = model_config.pretrained_config.vision_config
super().__init__()
self.model_config = model_config
self.config = self.model_config.pretrained_config.vision_config

self.spatial_merge_size = config.spatial_merge_size
self.patch_size = config.patch_size
self.fullatt_block_indexes = config.fullatt_block_indexes
self.window_size = config.window_size
self.spatial_merge_size = self.config.spatial_merge_size
self.patch_size = self.config.patch_size
self.fullatt_block_indexes = self.config.fullatt_block_indexes
self.window_size = self.config.window_size
self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

self.patch_embed = Qwen2_5_VisionPatchEmbed(
patch_size=config.patch_size,
temporal_patch_size=config.temporal_patch_size,
in_channels=config.in_channels,
embed_dim=config.hidden_size,
patch_size=self.config.patch_size,
temporal_patch_size=self.config.temporal_patch_size,
in_channels=self.config.in_channels,
embed_dim=self.config.hidden_size,
)

head_dim = config.hidden_size // config.num_heads
head_dim = self.config.hidden_size // self.config.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

self.blocks = torch.nn.ModuleList([
Qwen2_5_VLVisionBlock(model_config, layer_idx=layer_idx)
for layer_idx in range(config.depth)
for layer_idx in range(self.config.depth)
])
self.merger = Qwen2_5_VLPatchMerger(model_config, )
self.merger = Qwen2_5_VLPatchMerger(self.model_config, )
self.metadata_cls = get_attention_backend(
model_config.attn_backend).Metadata
self.model_config.attn_backend).Metadata

self.full_attn_metadata = self.metadata_cls(
max_num_requests=8192, # TODO: Make this dynamic
@@ -874,35 +909,31 @@ def __init__(
*args,
**kwargs,
) -> None:
model_config.pretrained_config.rope_scaling['type'] = 'mrope'
self.original_arch = model_config.pretrained_config.architectures[0]

        # NOTE: Setting disable_fuse_rope to True performs mrope fusion in the
        # model engine by pre-computing rotary_cos_sin there.
        disable_fuse_rope = kwargs.get('disable_fuse_rope', False)
        model_config.pretrained_config.text_config.disable_fuse_rope = disable_fuse_rope
        model_config.pretrained_config.disable_fuse_rope = disable_fuse_rope
model_config.pretrained_config.rope_scaling['type'] = 'mrope'
config = model_config.pretrained_config

assert model_config.attn_backend == 'TRTLLM', "Qwen2/2.5-VL only supports TRTLLM backend now"
super().__init__(config)
        if not disable_fuse_rope:
self.init_mrope_embedding(model_config)

self.model_config = model_config
if hasattr(self, "llm"):
return
self.config = model_config.pretrained_config

if not DISAGG:
self.mm_encoder = Qwen2VisionModelBase(
model_config, kwargs.get('vision_model_class', None)).eval()
if model_config.attn_backend != 'TRTLLM':
raise ValueError("Qwen2/2.5-VL only supports TRTLLM backend now")
        if not disable_fuse_rope:
self.init_mrope_embedding(model_config)

llm_model_config = copy.deepcopy(model_config)
llm_model_config.pretrained_config = config.text_config
llm_model_config.pretrained_config.architectures = ["Qwen2ForCausalLM"]

self.llm = AutoModelForCausalLM.from_config(llm_model_config)
self.model_dtype = getattr(config, "torch_dtype", torch.bfloat16)
logger.info(f"{self.dtype=} {self.model_dtype=}")
self.post_config()
self.is_loaded = True

if not DISAGG:
mm_encoder_config = copy.deepcopy(model_config)
self.mm_encoder = Qwen2VisionModelBase(
mm_encoder_config, kwargs.get('vision_model_class', None))

def init_mrope_embedding(self, model_config: ModelConfig[PretrainedConfig]):
config = model_config.pretrained_config
@@ -930,11 +961,6 @@ def load_weights(self, weights, weight_mapper: BaseWeightMapper):
def infer_max_seq_len(self) -> int:
return self.llm.infer_max_seq_len()

def post_config(self):
# use llm.config as config for pytorch model engine
self.config = self.llm.config
self.model_config.pretrained_config = self.llm.config

@nvtx_range("Qwen2.5-VL prepare_mrope_config")
def prepare_mrope_config(self, multimodal_params: List[MultimodalParams],
num_context_requests: int):
@@ -1075,22 +1101,8 @@ def load_weights(self, weights, weight_mapper: BaseWeightMapper):
self.llm.load_weights(weights, weight_mapper)


def getSMVersion():
prop = torch.cuda.get_device_properties(0)
sm_version = prop.major * 10 + prop.minor
return sm_version


get_sm_version = getSMVersion()
if get_sm_version >= 100:
    # NOTE: Qwen2.5-VL with SM 100 and above uses HF's implementation due to
    # the lack of TRT-LLM's Attention kernel.
QWEN2_5_VL_VISION_MODEL_CLASS = Qwen2_5_VisionTransformerPretrainedModel
else:
QWEN2_5_VL_VISION_MODEL_CLASS = Qwen2_5_VisionModel


@register_vision_encoder(Qwen2VisionModelBase,
vlm_base_model=QWEN2_5_VL_VISION_MODEL_CLASS)
vlm_base_model=Qwen2_5_VisionModel)
@register_auto_model("Qwen2_5_VLForConditionalGeneration")
@register_input_processor(
Qwen2VLInputProcessorBase,
@@ -1106,39 +1118,23 @@ class Qwen2_5_VLModel(Qwen2VLModelBase):

def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
**kwargs):
kwargs['vision_model_class'] = QWEN2_5_VL_VISION_MODEL_CLASS
kwargs['vision_model_class'] = Qwen2_5_VisionModel
kwargs[
            'disable_fuse_rope'] = False  # TODO: Make this a ModelConfig argument
super().__init__(model_config, *args, **kwargs)

@property
def multimodal_data_device_paths(self) -> List[str]:
if get_sm_version >= 100:
return [
"image.pixel_values", "video.pixel_values_videos",
"image.image_grid_thw", "video.video_grid_thw",
"multimodal_embedding"
]
else:
return [
"image.pixel_values", "video.pixel_values_videos",
"multimodal_embedding"
]
return [
"image.pixel_values", "video.pixel_values_videos",
"multimodal_embedding"
]

def load_weights(self, weights, weight_mapper: BaseWeightMapper):
if isinstance(weight_mapper, Qwen2VLHfWeightMapper):
weights = weight_mapper.preprocess_weights(weights)

if not DISAGG:
if get_sm_version >= 100:
weight_name_mapping = None
else:
# Process vision encoder weights
weight_name_mapping = {
"attn.proj.weight": "attn.o_proj.weight",
"attn.proj.bias": "attn.o_proj.bias",
"attn.qkv.weight": "attn.qkv_proj.weight",
"attn.qkv.bias": "attn.qkv_proj.bias"
}
vision_weights = process_weights(weights, "visual",
weight_name_mapping)
self.mm_encoder.load_state_dict(vision_weights, strict=True)
self.mm_encoder.load_weights(weights)

self.llm.load_weights(weights, weight_mapper)
self.llm.load_weights(weights)
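Taken together, the weight-loading path for Qwen2.5-VL after this change is roughly the following (a sketch of the call flow in this diff, not an exact trace):

# 1. Qwen2VLHfWeightMapper.preprocess_weights rewrites the HF checkpoint keys
#    ("model.language_model." -> "model.", "model.visual." -> "visual.").
# 2. Qwen2VisionModelBase.load_weights filters the "visual" weights, splits the
#    fused qkv tensors into q/k/v, and remaps attn.proj / mlp.fc1 / mlp.fc2 to
#    o_proj / up_proj / down_proj via _load_weights_impl.
# 3. self.llm.load_weights loads the language-model weights.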