[https://nvbugs/5549829][fix] Qwen2.5-VL TP > 1 + Quantized weight load fix #8680
Conversation
/bot run
📝 Walkthrough
This PR refactors the Qwen2VL model's weight loading architecture by introducing an explicit `load_weights` path on the vision model base (see the sequence diagram below).
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Caller
    participant Model as Qwen2VLModel
    participant VisionBase as Qwen2VisionModelBase
    participant Utils as _load_weights_impl
    participant Config as ModelConfig

    Note over Caller,Config: New Weight Loading Flow
    Caller->>Model: load_weights(weights)
    Model->>VisionBase: load_weights(weights)
    VisionBase->>VisionBase: filter_weights(vision_weights)
    VisionBase->>VisionBase: convert QKV→q_proj/k_proj/v_proj
    VisionBase->>Utils: _load_weights_impl(converted_weights,<br/>pattern_mapping)
    Utils-->>VisionBase: mapped weights
    VisionBase->>Config: apply mapped weights

    Note over Caller,Config: Initialization Flow (Updated)
    Caller->>Model: __init__(model_config)
    Model->>VisionBase: __init__()
    VisionBase->>Config: get vision_config from<br/>model_config.pretrained_config
    Config-->>VisionBase: vision_config
    VisionBase->>VisionBase: Store vision_dtype from<br/>model_config.torch_dtype
```
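To make the diagram concrete, here is a minimal sketch of the loading flow it describes. This is illustrative only: the `visual.` prefix, the exact regex, and the `_load_weights_impl` signature are assumptions, not the PR's actual code.

```python
import re
from typing import Dict

import torch


def load_vision_weights(vision_base, weights: Dict[str, torch.Tensor]) -> None:
    """Sketch of the flow above: filter, split fused QKV, then delegate."""
    # 1. Keep only the vision tower's weights (prefix assumed to be "visual.").
    vision_weights = {
        key[len("visual."):]: value
        for key, value in weights.items() if key.startswith("visual.")
    }

    # 2. Split fused QKV tensors into q_proj / k_proj / v_proj thirds.
    converted = {}
    for name, tensor in vision_weights.items():
        match = re.match(r"(.*?)attn\.qkv\.(weight|bias)$", name)
        if match is None:
            converted[name] = tensor
            continue
        prefix, kind = match.groups()
        q, k, v = tensor.chunk(3, dim=0)
        converted[f"{prefix}attn.q_proj.{kind}"] = q
        converted[f"{prefix}attn.k_proj.{kind}"] = k
        converted[f"{prefix}attn.v_proj.{kind}"] = v

    # 3. Delegate to the shared loader; the pattern-mapping argument is assumed.
    vision_base._load_weights_impl(converted, {})
```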
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks: ❌ Failed checks (2 warnings) | ✅ Passed checks (1 passed)
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/models/modeling_qwen2vl.py (2)
365-367: Python 3.8 incompatibilities in type hints (PEP 585 generics, `type[...]`, and `any`).

Project targets Python 3.8+. Replace built-in generics and `type[...]` with typing equivalents; use `Any`, not `any`.

Apply:

```diff
-    def _preprocess(self, text: dict[str, any], mm_data: dict[str, any],
-                    mm_processor_kwargs: Dict[str, Any]):
+    def _preprocess(self, text: Dict[str, Any], mm_data: Dict[str, Any],
+                    mm_processor_kwargs: Dict[str, Any]):
@@
-            self,
-            input_ids: torch.IntTensor,
-            image_grid_thw: torch.LongTensor,
-            video_grid_thw: torch.LongTensor,
-            attention_mask: torch.Tensor,
-            second_per_grid_ts: torch.Tensor = None) -> dict[str, torch.Tensor]:
+            self,
+            input_ids: torch.IntTensor,
+            image_grid_thw: torch.LongTensor,
+            video_grid_thw: torch.LongTensor,
+            attention_mask: torch.Tensor,
+            second_per_grid_ts: torch.Tensor = None) -> Dict[str, torch.Tensor]:
@@
-    def __init__(self, model_config: ModelConfig[PretrainedConfig],
-                 model_class: Union[type[PreTrainedModel],
-                                    type[torch.nn.Module]]):
+    def __init__(self, model_config: ModelConfig[PretrainedConfig],
+                 model_class: Union[Type[PreTrainedModel], Type[torch.nn.Module]]):
@@
-        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]],
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
@@
-        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
```

Ensure:

```diff
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, Type
```

Also applies to: 401-401, 461-464, 611-617, 667-671
1094-1099: Strict state_dict load for HF vision may be brittle with quantized checkpoints.

Quantized/unified checkpoints often include extra buffers (e.g., scales). Consider `strict=False` or a filtered dict to avoid spurious failures.

Apply:

```diff
-        self.mm_encoder.load_state_dict(vision_encoder_weights, strict=True)
+        missing, unexpected = self.mm_encoder.load_state_dict(
+            vision_encoder_weights, strict=False)
+        if missing:
+            logger.debug(f"Missing mm vision weights: {missing}")
+        if unexpected:
+            logger.debug(f"Unexpected mm vision weights: {unexpected}")
```
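If strict loading is preferred, the "filtered dict" alternative mentioned above could look roughly like this; the scale suffixes are hypothetical, since the actual buffer names depend on the quantization format:

```python
# Drop quantization-only buffers (hypothetical suffixes) before a strict load.
QUANT_SUFFIXES = (".weights_scaling_factor", ".activation_scaling_factor")
filtered = {
    name: tensor
    for name, tensor in vision_encoder_weights.items()
    if not name.endswith(QUANT_SUFFIXES)
}
self.mm_encoder.load_state_dict(filtered, strict=True)
```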
🧹 Nitpick comments (4)
tensorrt_llm/_torch/models/modeling_qwen2vl.py (4)
54-90: Consolidated weight filtering helper is fine; prefer re.sub mapping for robustness.

The current suffix replace works but can misfire on overlapping suffixes. If/when you pass `weight_name_mapping`, consider a regex-based rename to mirror `_load_weights_impl` semantics.

Example:

```diff
-                new_key = key
-                for old_suffix, new_suffix in weight_name_mapping.items():
-                    if key.endswith(old_suffix):
-                        new_key = key.replace(old_suffix, new_suffix)
-                        break
+                new_key = key
+                for pattern, repl in weight_name_mapping.items():
+                    new_key = re.sub(pattern, repl, new_key)
```
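With that change, `weight_name_mapping` would carry regex patterns rather than plain suffixes; for example (hypothetical entries, not names from the PR):

```python
# Keys are regex patterns, values are replacement templates (hypothetical).
weight_name_mapping = {
    r"attn\.proj\.(weight|bias)$": r"attn.o_proj.\1",
    r"mlp\.fc1\.(weight|bias)$": r"mlp.up_proj.\1",
}
```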
481-486: Vision module created with dtype only; consider explicit device or document CPU path.

`.to(self.model_dtype)` leaves the module on CPU. Given the TRTLLM backend in this stack, clarify intent (CPU vision, GPU LLM) or move to CUDA when appropriate.

Example:

```diff
-        self.visual = model_class(
-            model_config.pretrained_config.vision_config).to(
-                self.model_dtype).eval()
+        self.visual = model_class(model_config.pretrained_config.vision_config)
+        self.visual = self.visual.to(dtype=self.model_dtype)  # and optionally .to(device)
+        self.visual.eval()
```
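To make the dtype-versus-device distinction concrete, a standalone PyTorch snippet (not project code):

```python
import torch

layer = torch.nn.Linear(4, 4)
layer = layer.to(torch.float16)          # casts parameters; device is unchanged
print(next(layer.parameters()).device)   # cpu

if torch.cuda.is_available():
    layer = layer.to("cuda")             # the GPU move must be requested explicitly
    print(next(layer.parameters()).device)  # cuda:0
```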
1125-1128: Device paths: grid_thw removed; confirm it is not needed on device.

`Qwen2_5_VisionModel.forward` builds GPU indices if hidden states are on GPU; CPU `grid_thw` indexing into GPU tensors would fail. Verify grid_thw stays on CPU only if vision remains on CPU.

If vision runs on GPU, include:

```diff
-        return [
-            "image.pixel_values", "video.pixel_values_videos",
-            "multimodal_embedding"
-        ]
+        return [
+            "image.pixel_values", "image.image_grid_thw",
+            "video.pixel_values_videos", "video.video_grid_thw",
+            "multimodal_embedding"
+        ]
```
490-517: Add input validation and tighten the regex for QKV weight mapping.

The current implementation assumes well-formed checkpoints and lacks defensive guards. Two improvements (see the sketch after this list):

1. Stricter regex: change `r'(.*?)attn\.qkv\.(.*)'` to `r'(.*?)attn\.qkv\.(weight|bias)$'` to explicitly match only weight/bias tensors and anchor to the end of the string.
2. Shape divisibility check: validate that `visual_weights[name].shape[0] % 3 == 0` before splitting, raising `ValueError` if not. This prevents silently incorrect splits (e.g., a leading dimension of 10 would split unevenly instead of failing). Similar validation is used elsewhere in the codebase (e.g., modeling_utils.py:112).

These are recommended defensive improvements to catch malformed checkpoints early.
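A sketch combining both suggestions; the `visual_weights` dict and the surrounding loop are assumed from the PR, not reproduced from it:

```python
import re

# Anchored pattern: only match .weight / .bias tensors under attn.qkv.
QKV_PATTERN = re.compile(r"(.*?)attn\.qkv\.(weight|bias)$")

for name in list(visual_weights.keys()):
    match = QKV_PATTERN.match(name)
    if match is None:
        continue
    tensor = visual_weights[name]
    # Fail fast on malformed checkpoints instead of splitting unevenly.
    if tensor.shape[0] % 3 != 0:
        raise ValueError(
            f"Fused QKV tensor {name!r} has leading dim {tensor.shape[0]}, "
            "which is not divisible by 3.")
    q, k, v = tensor.chunk(3, dim=0)
```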
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tensorrt_llm/_torch/models/modeling_qwen2vl.py (9 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Use only spaces, no tabs; indent with 4 spaces.
Files:
tensorrt_llm/_torch/models/modeling_qwen2vl.py
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.
Files:
tensorrt_llm/_torch/models/modeling_qwen2vl.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).
Files:
tensorrt_llm/_torch/models/modeling_qwen2vl.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/models/modeling_qwen2vl.py (2)
tensorrt_llm/_torch/model_config.py (1)
torch_dtype (217-222)
tensorrt_llm/models/modeling_utils.py (1)
QuantConfig (131-271)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (7)
tensorrt_llm/_torch/models/modeling_qwen2vl.py (7)
3-3: Import updates LGTM. Adding `re` and exposing `QuantConfig`/`_load_weights_impl`/`filter_weights` is appropriate for the new loading path.

Also applies to: 46-48
727-753: Qwen2_5_VisionModel wiring LGTM. Config sourcing and block construction are consistent; attention backend metadata class use is coherent.
923-925: mRoPE init conditional is fine; confirm the generation path handles fused vs. non-fused consistently. No change requested; just verify both paths are covered by tests.
930-934: Instantiation order LGTM. Deepcopy before constructing LLM/vision avoids shared state; good.
1102-1102: Registration to Qwen2_5_VisionModel LGTM. Matches the new encoder path.

Also applies to: 1118-1118
1132-1132: Delegated vision weight load LGTM. Using `mm_encoder.load_weights` ensures the QKV split + rename and avoids unwanted quant handling.
101-101: No critical device mismatch; PyTorch auto-handles tensor movement.

The code path works correctly due to PyTorch's automatic device movement: the vision encoder (`self.visual`) is explicitly moved to GPU at line 485 via `.to(self.model_dtype)`, so when it processes pixel_values (dtype-cast at lines 429 and 437), the output embeddings are automatically on the GPU. When these embeds reach `fuse_input_embeds` at line 1049, they are already on the correct device for concatenation with input_ids.

The review comment's cautionary note is reasonable for code clarity: explicitly verifying or documenting device alignment would improve maintainability, but the current implementation does not have a functional issue. The implicit device movement via PyTorch's model execution is safe and standard practice.
PR_Github #22614 [ run ] triggered by Bot. Commit:
PR_Github #22614 [ run ] completed with state
jaedeok-nvidia left a comment:
Thanks for fixing the bug. Overall this PR fixes the weight loading and modularizes the logic, which looks good to me. I've left several questions; would you please take a look?
Signed-off-by: yechank <[email protected]>
force-pushed from cb38d0b to d7d53a4 (Compare)
/bot run
PR_Github #22701 [ run ] triggered by Bot. Commit:
PR_Github #22701 [ run ] completed with state
Signed-off-by: yechank <[email protected]>
/bot run
PR_Github #22816 [ run ] triggered by Bot. Commit:
PR_Github #22816 [ run ] completed with state
…ad fix (NVIDIA#8680) Signed-off-by: yechank <[email protected]>
Summary by CodeRabbit
Release Notes