39 changes: 39 additions & 0 deletions tensorrt_llm/_torch/models/modeling_qwen2vl.py
@@ -1054,6 +1054,45 @@ def load_weights(self, weights, weight_mapper: BaseWeightMapper):
        }
        vision_weights = process_weights(weights, "visual",
                                         weight_name_mapping)
        if get_sm_version() < 100:
            tp_rank = self.model_config.mapping.tp_rank
            tp_size = self.model_config.mapping.tp_size
            num_vision_heads = self.mm_encoder.config.num_heads

            # Need to shard the weights to support tp
            def shard_qkv(tensor, is_weight=True):
                # Slice this rank's contiguous block of heads out of the
                # fused [q; k; v] tensor, then re-fuse the three slices.
                hidden_dim = tensor.shape[0] // 3
                head_dim = hidden_dim // num_vision_heads
                heads_per_tp = num_vision_heads // tp_size
                start, end = tp_rank * heads_per_tp, (tp_rank + 1) * heads_per_tp

                if is_weight:
                    tensor = tensor.reshape(3, num_vision_heads, head_dim,
                                            hidden_dim)
                    sliced = tensor[:, start:end].reshape(3, -1, hidden_dim)
                else:
                    tensor = tensor.reshape(3, num_vision_heads, head_dim)
                    sliced = tensor[:, start:end].reshape(3, -1)

                return torch.cat([sliced[0], sliced[1], sliced[2]], dim=0)

            for key in vision_weights.keys():
                if "attn.qkv_proj" in key:
                    if "weight" in key:
                        # qkv_proj.weight shape: [3 * hidden_dim, hidden_dim]
                        vision_weights[key] = shard_qkv(vision_weights[key],
                                                        is_weight=True)
                    elif "bias" in key:
                        # qkv_proj.bias shape: [3 * hidden_dim]
                        vision_weights[key] = shard_qkv(vision_weights[key],
                                                        is_weight=False)

                if "attn.o_proj.weight" in key:
                    # o_proj.weight shape: [hidden_dim, hidden_dim]
                    vision_weights[key] = torch.chunk(vision_weights[key],
                                                      tp_size,
Comment on lines +1058 to +1094
Contributor

⚠️ Potential issue | 🟠 Major

Add explicit head-count divisibility checks

heads_per_tp = num_vision_heads // tp_size silently floors. If someone configures TP with a head count that isn't cleanly divisible, we'll reshape/slice to the wrong size and load_state_dict will fail (or worse, we'd silently drop heads). Please reject misconfigured TP/head counts up front and validate that each per-projection slice splits cleanly across heads before reshaping (a tiny repro of the flooring hazard follows the suggested diff).

             num_vision_heads = self.mm_encoder.config.num_heads

             # Need to shard the weights to support tp
             def shard_qkv(tensor, is_weight=True):
-                hidden_dim = tensor.shape[0] // 3
+                if tensor.shape[0] % 3 != 0:
+                    raise ValueError(
+                        f"Unexpected fused-qkv size {tensor.shape[0]} (not divisible by 3).")
+                hidden_dim = tensor.shape[0] // 3
+                if hidden_dim % num_vision_heads != 0:
+                    raise ValueError(
+                        f"Vision head split requires hidden_dim {hidden_dim} to be divisible by num_heads {num_vision_heads}.")
                 head_dim = hidden_dim // num_vision_heads
                 heads_per_tp = num_vision_heads // tp_size
                 start, end = tp_rank * heads_per_tp, (tp_rank + 1) * heads_per_tp
@@
-            for key in vision_weights.keys():
+            if num_vision_heads % tp_size != 0:
+                raise ValueError(
+                    f"Vision TP requires num_heads ({num_vision_heads}) to be divisible by tp_size ({tp_size}).")
+            for key in vision_weights.keys():
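For illustration, a tiny standalone repro of the flooring hazard the diff guards against (the head and TP counts below are made up):

    # Hypothetical sizes: 16 heads on 3-way TP.
    num_vision_heads, tp_size = 16, 3
    heads_per_tp = num_vision_heads // tp_size  # floors to 5
    covered = tp_size * heads_per_tp  # only 15 heads sliced across all ranks
    assert covered != num_vision_heads  # one head silently dropped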
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_qwen2vl.py around lines 1058 to 1094, add
explicit divisibility checks to reject misconfigured TP/head counts before
reshaping: validate that num_vision_heads % tp_size == 0 and raise a clear
ValueError if not; inside shard_qkv for weights assert tensor.shape[0] % 3 == 0,
compute hidden_dim = tensor.shape[0] // 3 then assert hidden_dim %
num_vision_heads == 0 before computing head_dim; for biases assert
tensor.shape[0] % 3 == 0 and (tensor.shape[0] // 3) % num_vision_heads == 0;
lastly, before torch.chunk(vision_weights[key], tp_size) check that
vision_weights[key].shape[0] % tp_size == 0 and raise a ValueError if not so
reshapes/chunks never silently floor or drop heads.
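A minimal consolidated sketch of those checks, assuming the key names from the diff above; this is illustrative, not the committed fix. Note the o_proj guard checks dim 1, since that is the axis torch.chunk splits:

    import torch

    def validate_vision_tp(vision_weights, num_vision_heads, tp_size):
        if num_vision_heads % tp_size != 0:
            raise ValueError(
                f"Vision TP requires num_heads ({num_vision_heads}) to be "
                f"divisible by tp_size ({tp_size}).")
        for key, tensor in vision_weights.items():
            if "attn.qkv_proj" in key:
                if tensor.shape[0] % 3 != 0:
                    raise ValueError(
                        f"{key}: fused-qkv dim {tensor.shape[0]} not divisible by 3.")
                hidden_dim = tensor.shape[0] // 3
                if hidden_dim % num_vision_heads != 0:
                    raise ValueError(
                        f"{key}: hidden_dim {hidden_dim} not divisible by "
                        f"num_heads {num_vision_heads}.")
            elif "attn.o_proj.weight" in key:
                # torch.chunk splits o_proj along dim=1, so guard that axis.
                if tensor.shape[1] % tp_size != 0:
                    raise ValueError(
                        f"{key}: dim-1 size {tensor.shape[1]} not divisible by "
                        f"tp_size ({tp_size}).")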

                                                      dim=1)[tp_rank]
        self.mm_encoder.load_state_dict(vision_weights, strict=True)

        self.llm.load_weights(weights, weight_mapper)
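As a standalone sanity check on the sharding math itself (all sizes made up), re-concatenating every rank's shard_qkv output per projection should reconstruct the original fused weight:

    import torch

    num_heads, head_dim, tp_size = 4, 8, 2
    hidden = num_heads * head_dim
    w = torch.randn(3 * hidden, hidden)

    def shard(tensor, rank):
        heads_per_tp = num_heads // tp_size
        start, end = rank * heads_per_tp, (rank + 1) * heads_per_tp
        t = tensor.reshape(3, num_heads, head_dim, hidden)
        sliced = t[:, start:end].reshape(3, -1, hidden)
        return torch.cat([sliced[0], sliced[1], sliced[2]], dim=0)

    shards = [shard(w, r) for r in range(tp_size)]
    per_proj = hidden // tp_size  # rows of q (or k, v) held by one rank
    rebuilt = torch.cat([
        torch.cat([s[i * per_proj:(i + 1) * per_proj] for s in shards], dim=0)
        for i in range(3)
    ], dim=0)
    assert torch.equal(rebuilt, w)  # sharding loses nothing and keeps order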