Conversation

@yechank-nvidia
Collaborator

@yechank-nvidia yechank-nvidia commented Oct 27, 2025

Summary by CodeRabbit

Release Notes

  • Refactor
    • Restructured vision model weight loading mechanism with enhanced mapping and conversion utilities
    • Consolidated model initialization paths with consistent configuration object references
    • Eliminated device-specific conditional logic for broader device compatibility
    • Streamlined multimodal data routing for vision components

@yechank-nvidia
Collaborator Author

/bot run

@coderabbitai
Contributor

coderabbitai bot commented Oct 27, 2025

📝 Walkthrough

This PR refactors the Qwen2VL model's weight loading architecture by introducing explicit load_weights methods in vision model classes and replacing post_config patterns. It consolidates multimodal data path logic, removes SM version conditional branches, and reorganizes initialization to use centralized configuration objects.

Changes

All changes are in tensorrt_llm/_torch/models/modeling_qwen2vl.py:

  • Imports and API expansion: Added the re module; expanded the imports from modeling_utils to include QuantConfig, _load_weights_impl, and filter_weights.
  • Vision weight loading: Introduced load_weights methods on Qwen2VisionModelBase and Qwen2VLModelBase. Qwen2VisionModelBase.load_weights filters the vision weights, splits fused QKV tensors into separate q_proj/k_proj/v_proj entries, and remaps projection names via _load_weights_impl with a custom mapping (see the sketch after this list). Qwen2VLModelBase.load_weights is a pass-through placeholder.
  • Initialization and configuration: Updated Qwen2VLInputProcessorBase to store vision_dtype from model_config.torch_dtype. Modified Qwen2VisionModelBase, Qwen2_5_VisionModel, and related classes to consistently pull configuration from self.model_config and self.config rather than transient objects. Removed post_config methods across classes.
  • Multimodal data paths: Simplified Qwen2VLModel.multimodal_data_device_paths and Qwen2_5_VLModel.multimodal_data_device_paths to return fixed key sets without SM version conditionals. Vision pixel_values and pixel_values_videos now route through .to(self.vision_dtype).
  • Conditional logic removal: Removed the SM-version-gated branches for vision model class selection and related conditional paths, defaulting to a single consistent implementation path.
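A minimal sketch of that vision load path, assuming an illustrative "visual." prefix and a placeholder rename mapping (the real filter_weights/_load_weights_impl signatures and the PR's actual mapping table may differ):

import re

import torch


def load_vision_weights_sketch(module, weights, prefix="visual."):
    # Keep only the vision-prefixed weights (mirrors the filter_weights step).
    vision_weights = {
        name[len(prefix):]: tensor
        for name, tensor in weights.items() if name.startswith(prefix)
    }

    # Split fused QKV tensors into separate q/k/v projection entries.
    converted = {}
    for name, tensor in vision_weights.items():
        match = re.match(r"(.*?)attn\.qkv\.(weight|bias)$", name)
        if match:
            head, tail = match.groups()
            q, k, v = torch.chunk(tensor, 3, dim=0)
            converted[f"{head}attn.q_proj.{tail}"] = q
            converted[f"{head}attn.k_proj.{tail}"] = k
            converted[f"{head}attn.v_proj.{tail}"] = v
        else:
            converted[name] = tensor

    # Regex-based rename standing in for _load_weights_impl's custom mapping;
    # this pattern is a placeholder, not the PR's actual table.
    pattern_mapping = {r"attn\.proj\.": "attn.o_proj."}
    renamed = {}
    for name, tensor in converted.items():
        for pattern, repl in pattern_mapping.items():
            name = re.sub(pattern, repl, name)
        renamed[name] = tensor
    module.load_state_dict(renamed, strict=False)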

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant Model as Qwen2VLModel
    participant VisionBase as Qwen2VisionModelBase
    participant Utils as _load_weights_impl
    participant Config as ModelConfig

    Note over Caller,Config: New Weight Loading Flow
    Caller->>Model: load_weights(weights)
    Model->>VisionBase: load_weights(weights)
    VisionBase->>VisionBase: filter_weights(vision_weights)
    VisionBase->>VisionBase: convert QKV→q_proj/k_proj/v_proj
    VisionBase->>Utils: _load_weights_impl(converted_weights,<br/>pattern_mapping)
    Utils-->>VisionBase: mapped weights
    VisionBase->>Config: apply mapped weights

    Note over Caller,Config: Initialization Flow (Updated)
    Caller->>Model: __init__(model_config)
    Model->>VisionBase: __init__()
    VisionBase->>Config: get vision_config from<br/>model_config.pretrained_config
    Config-->>VisionBase: vision_config
    VisionBase->>VisionBase: Store vision_dtype from<br/>model_config.torch_dtype

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Weight loading mechanism: New load_weights implementations with QKV conversion and pattern-based weight mapping across multiple vision classes require careful verification of correctness.
  • Configuration management: Changes to how configuration objects are passed and referenced throughout the class hierarchy (self.model_config vs. self.config) need validation for consistency.
  • Initialization flow: Updated paths for storing and accessing vision_dtype and vision_config require tracing through multiple class layers to ensure proper alignment.
  • Conditional removal: Verification that removal of SM version branches doesn't inadvertently break functionality on specific hardware configurations.
  • Multimodal data paths: Simplification to fixed key sets needs validation that all required keys are present across use cases.

Pre-merge checks

❌ Failed checks (2 warnings)

  • Description Check: ⚠️ Warning. The pull request description is entirely absent. According to the repository's description template, the PR should include a Description section explaining the issue and solution, a Test Coverage section listing relevant tests, and a PR Checklist. With no description content at all, it is impossible to understand the rationale behind the changes, the problem being solved, or how the changes were validated. Resolution: provide a complete PR description following the template, including (1) a clear explanation of the issue being fixed (e.g., why weight loading fails with TP > 1 and quantized weights), (2) a summary of the solution approach (e.g., the weight mapping and conversion strategy implemented), and (3) a list of relevant tests that validate the fix; then review the PR Checklist and complete the final checkbox once all items are addressed.
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve coverage.

✅ Passed checks (1 passed)

  • Title Check: ✅ Passed. The PR title "[https://nvbugs/5549829][fix] Qwen2.5-VL TP > 1 + Quantized weight load fix" follows the required template format, with a valid NVBugs ID, the "fix" type indicator, and a specific summary. The title aligns with the changeset's main objective: refactoring the weight loading mechanism in the Qwen2VL model, including new load_weights methods, weight conversion and mapping logic, and removal of post_config methods. It accurately captures the weight-loading fixes for Qwen2.5-VL under TP > 1 with quantized weights.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/models/modeling_qwen2vl.py (2)

365-367: Python 3.8 incompatibilities in type hints (PEP 585 generics, type[...], and any).

Project targets Python 3.8+. Replace built-in generics and type[...] with typing equivalents; use Any not any.

Apply:

-    def _preprocess(self, text: dict[str, any], mm_data: dict[str, any],
-                    mm_processor_kwargs: Dict[str, Any]):
+    def _preprocess(self, text: Dict[str, Any], mm_data: Dict[str, Any],
+                    mm_processor_kwargs: Dict[str, Any]):
@@
-            self,
-            input_ids: torch.IntTensor,
-            image_grid_thw: torch.LongTensor,
-            video_grid_thw: torch.LongTensor,
-            attention_mask: torch.Tensor,
-            second_per_grid_ts: torch.Tensor = None) -> dict[str, torch.Tensor]:
+            self,
+            input_ids: torch.IntTensor,
+            image_grid_thw: torch.LongTensor,
+            video_grid_thw: torch.LongTensor,
+            attention_mask: torch.Tensor,
+            second_per_grid_ts: torch.Tensor = None) -> Dict[str, torch.Tensor]:
@@
-    def __init__(self, model_config: ModelConfig[PretrainedConfig],
-                 model_class: Union[type[PreTrainedModel],
-                                    type[torch.nn.Module]]):
+    def __init__(self, model_config: ModelConfig[PretrainedConfig],
+                 model_class: Union[Type[PreTrainedModel], Type[torch.nn.Module]]):
@@
-        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]],
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
@@
-        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,

Ensure:

-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, Type

Also applies to: 401-401, 461-464, 611-617, 667-671


1094-1099: Strict state_dict load for HF vision may be brittle with quantized checkpoints.

Quantized/unified checkpoints often include extra buffers (e.g., scales). Consider strict=False or a filtered dict to avoid spurious failures.

Apply:

-            self.mm_encoder.load_state_dict(vision_encoder_weights, strict=True)
+            missing, unexpected = self.mm_encoder.load_state_dict(vision_encoder_weights, strict=False)
+            if missing:
+                logger.debug(f"Missing mm vision weights: {missing}")
+            if unexpected:
+                logger.debug(f"Unexpected mm vision weights: {unexpected}")
🧹 Nitpick comments (4)
tensorrt_llm/_torch/models/modeling_qwen2vl.py (4)

54-90: Consolidated weight filtering helper is fine; prefer re.sub mapping for robustness.

Current suffix replace works but can misfire on overlapping suffixes. If/when you pass weight_name_mapping, consider regex-based rename to mirror _load_weights_impl semantics.

Example:

-            new_key = key
-            for old_suffix, new_suffix in weight_name_mapping.items():
-                if key.endswith(old_suffix):
-                    new_key = key.replace(old_suffix, new_suffix)
-                    break
+            new_key = key
+            for pattern, repl in weight_name_mapping.items():
+                new_key = re.sub(pattern, repl, new_key)

481-486: Vision module created with dtype only; consider explicit device or document CPU path.

.to(self.model_dtype) leaves the module on CPU. Given TRTLLM backend in this stack, clarify intent (CPU vision, GPU LLM) or move to CUDA when appropriate.

Example:

-            self.visual = model_class(
-                model_config.pretrained_config.vision_config).to(
-                    self.model_dtype).eval()
+            self.visual = model_class(model_config.pretrained_config.vision_config)
+            self.visual = self.visual.to(dtype=self.model_dtype)  # and optionally .to(device)
+            self.visual.eval()

1125-1128: Device paths: grid_thw removed—confirm not needed on device.

Qwen2_5_VisionModel.forward builds GPU indices if hidden states are on GPU; CPU grid_thw indexing into GPU tensors would fail. Verify grid_thw stays on CPU only if vision remains on CPU.

If vision runs on GPU, include:

-        return [
-            "image.pixel_values", "video.pixel_values_videos",
-            "multimodal_embedding"
-        ]
+        return [
+            "image.pixel_values", "image.image_grid_thw",
+            "video.pixel_values_videos", "video.video_grid_thw",
+            "multimodal_embedding"
+        ]

490-517: Add input validation and tighten regex for QKV weight mapping.

The current implementation assumes well-formed checkpoints and lacks defensive guards. Two improvements:

  1. Stricter regex: Change r'(.*?)attn\.qkv\.(.*)' to r'(.*?)attn\.qkv\.(weight|bias)$' to explicitly match only weight/bias tensors and anchor to string end.

  2. Shape divisibility check: Validate that visual_weights[name].shape[0] % 3 == 0 before splitting, raising ValueError if not. This prevents silent incorrect splits (e.g., shape of 10 would split as 3→3→4 instead of failing). Similar validation is used elsewhere in the codebase (e.g., modeling_utils.py:112).

These are recommended defensive improvements to catch malformed checkpoints early.
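A small illustration of both guards together (the naming follows the review's wording; this is a hypothetical sketch, not the PR's code):

import re

# Anchored pattern: match only weight/bias tensors under attn.qkv.
QKV_KEY = re.compile(r"(.*?)attn\.qkv\.(weight|bias)$")


def split_fused_qkv(name, tensor):
    """Split one fused QKV entry; `tensor` is assumed to be a torch.Tensor."""
    match = QKV_KEY.match(name)
    if match is None:
        return None  # Not a fused QKV entry; caller keeps it unchanged.
    if tensor.shape[0] % 3 != 0:
        # Fail fast on malformed checkpoints instead of splitting silently.
        raise ValueError(
            f"{name}: dim 0 is {tensor.shape[0]}, not divisible by 3")
    head, tail = match.groups()
    q, k, v = tensor.chunk(3, dim=0)
    return {
        f"{head}attn.q_proj.{tail}": q,
        f"{head}attn.k_proj.{tail}": k,
        f"{head}attn.v_proj.{tail}": v,
    }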

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8090c96 and 39d522f.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/models/modeling_qwen2vl.py (9 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/models/modeling_qwen2vl.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/models/modeling_qwen2vl.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/models/modeling_qwen2vl.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/models/modeling_qwen2vl.py (2)
tensorrt_llm/_torch/model_config.py (1)
  • torch_dtype (217-222)
tensorrt_llm/models/modeling_utils.py (1)
  • QuantConfig (131-271)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (7)
tensorrt_llm/_torch/models/modeling_qwen2vl.py (7)

3-3: Import updates LGTM.

Adding re and exposing QuantConfig/_load_weights_impl/filter_weights is appropriate for the new loading path.

Also applies to: 46-48


727-753: Qwen2_5_VisionModel wiring LGTM.

Config sourcing and block construction are consistent; attention backend metadata class use is coherent.


923-925: mRoPE init conditional is fine; confirm generation path handles fused vs non-fused consistently.

No change requested; just verify both paths covered by tests.


930-934: Instantiation order LGTM.

Deepcopy before constructing LLM/vision avoids shared state; good.


1102-1102: Registration to Qwen2_5_VisionModel LGTM.

Matches new encoder path.

Also applies to: 1118-1118


1132-1132: Delegated vision weight load LGTM.

Using mm_encoder.load_weights ensures qkv split + rename and avoids unwanted quant handling.


101-101: No critical device mismatch; PyTorch auto-handles tensor movement.

The code path works correctly due to PyTorch's automatic device movement: the vision encoder (self.visual) is explicitly moved to GPU at line 485 via .to(self.model_dtype), so when it processes pixel_values (dtype-cast at lines 429, 437), the output embeddings are automatically on the GPU. When these embeds reach fuse_input_embeds at line 1049, they're already on the correct device for concatenation with input_ids.

The review comment's cautionary note is reasonable for code clarity—explicitly verifying or documenting device alignment would improve code maintainability—but the current implementation doesn't have a functional issue. The implicit device movement via PyTorch's model execution is safe and standard practice.

@tensorrt-cicd
Collaborator

PR_Github #22614 [ run ] triggered by Bot. Commit: 39d522f

@tensorrt-cicd
Collaborator

PR_Github #22614 [ run ] completed with state SUCCESS. Commit: 39d522f
/LLM/main/L0_MergeRequest_PR pipeline #17047 completed with status: 'FAILURE'

Collaborator

@jaedeok-nvidia jaedeok-nvidia left a comment

Thanks for fixing the bug. Overall this PR fixes the weight loading and modularizes the logic, which looks good to me. I've left several questions; would you please take a look?

@yechank-nvidia
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #22701 [ run ] triggered by Bot. Commit: d7d53a4

@tensorrt-cicd
Collaborator

PR_Github #22701 [ run ] completed with state FAILURE. Commit: d7d53a4
/LLM/main/L0_MergeRequest_PR pipeline #17117 completed with status: 'FAILURE'

@yechank-nvidia added the Multimodal label (for issues & PRs regarding Multimodal related objects) on Oct 28, 2025
Signed-off-by: yechank <[email protected]>
@yechank-nvidia
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #22816 [ run ] triggered by Bot. Commit: 8beebf9

@tensorrt-cicd
Collaborator

PR_Github #22816 [ run ] completed with state SUCCESS. Commit: 8beebf9
/LLM/main/L0_MergeRequest_PR pipeline #17210 completed with status: 'SUCCESS'

@yechank-nvidia yechank-nvidia merged commit bc26f4c into NVIDIA:main Oct 29, 2025
5 checks passed
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 1, 2025
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 3, 2025