[OMNIML-2182]: Add example for multinode calibration using FSDP2 #432
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here.
Walkthrough
Adds FSDP2-aware utilities and contexts for coordinated weight updates and dtype patching, integrates Accelerator into export checkpointing, refactors qtensor compression for per-submodule handling, adds multi-node FSDP2 PTQ examples and config, introduces FSDP2 export GPU tests, updates test utilities, and fixes an NVFP4 device-casting bug.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant User
participant Accelerator
participant FSDP as FSDPModule
participant Utils as fsdp2_aware_weight_update
participant Exporter as unified_export_hf
User->>Accelerator: init (Accelerate + FSDP2)
Accelerator->>FSDP: wrap model (FSDPModule)
User->>Exporter: _export_hf_checkpoint(model, ..., accelerator=Accelerator)
Exporter->>Utils: enter fsdp2_aware_weight_update(model, modules)
Note right of Utils #D0F0C0: unshard relevant params, patch mp dtypes, map FSDP params
Utils->>FSDP: perform per-submodule fusion / fake-quant / weight update
Utils->>FSDP: optional reshard after forward or explicit reshard=False for exports
Utils-->>Exporter: exit context (restore MP dtypes, reshard if needed)
Exporter->>Accelerator: accelerator.get_state_dict(model)
Accelerator-->>Exporter: unified state_dict
Exporter-->>User: write HF checkpoint / artifacts
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #432 +/- ##
==========================================
- Coverage 73.38% 73.31% -0.08%
==========================================
Files 180 180
Lines 17986 18028 +42
==========================================
+ Hits 13199 13217 +18
- Misses 4787 4811 +24
☔ View full report in Codecov by Sentry.
Actionable comments posted: 3
🧹 Nitpick comments (3)
examples/llm_ptq/README.md (1)
244-245: Hyphenate the compound modifier. Please change “user specific requirements” to “user-specific requirements” for grammatical correctness.
tests/gpu/torch/export/test_export.py (1)
383-384: Remove leftover debug print. This print statement spams GPU test logs and should be dropped now that the test is stable.
- print(f"DEBUG LOG: Scale: {scale}, Expected: {expected_amax[0] / maxbound}")
examples/llm_ptq/multinode-ptq.py (1)
245-256: Clarify the comment about model usage. The comments on lines 253-255 are confusing and potentially contradictory. They state "we should forward pass using the unwrapped model" but then call model(**batch), which is the FSDP-wrapped model. Consider revising to something like:
# For FSDP2, use the outer FSDP-wrapped model rather than unwrapped_model
# mtq.quantize unwraps the model before passing it to forward_loop as unwrapped_model,
# but we need the FSDP-wrapped version to properly handle DTensor
model(**batch)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- examples/llm_ptq/README.md (1 hunks)
- examples/llm_ptq/fsdp2.yaml (1 hunks)
- examples/llm_ptq/multinode-ptq.py (1 hunks)
- modelopt/torch/export/unified_export_hf.py (8 hunks)
- modelopt/torch/quantization/qtensor/base_qtensor.py (3 hunks)
- modelopt/torch/quantization/utils.py (3 hunks)
- tests/_test_utils/torch_export/export_utils.py (2 hunks)
- tests/gpu/torch/export/test_export.py (3 hunks)
- tests/gpu/torch/export/test_fsdp2_export.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (7)
tests/_test_utils/torch_export/export_utils.py (2)
modelopt/torch/export/model_config.py (1)
bias(153-163)modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
bias(303-307)
modelopt/torch/quantization/utils.py (2)
modelopt/torch/quantization/qtensor/base_qtensor.py (2)
QFSDPParam(139-159)QTensorWrapper(87-136)modelopt/torch/utils/network.py (1)
get_unwrapped_name(599-612)
modelopt/torch/export/unified_export_hf.py (3)
modelopt/torch/quantization/conversion.py (1)
set_quantizer_by_cfg_context(300-322)modelopt/torch/quantization/utils.py (2)
fsdp2_aware_weight_update(591-669)quantizer_attr_names(232-243)modelopt/torch/export/quant_utils.py (2)
preprocess_linear_fusion(940-1010)fuse_prequant_layernorm(926-937)
tests/gpu/torch/export/test_fsdp2_export.py (4)
tests/_test_utils/torch_export/export_utils.py (2)
SmallQKVModel(57-83)ToyModel(20-33)modelopt/torch/export/layer_utils.py (1)
is_quantlinear(346-348)modelopt/torch/export/unified_export_hf.py (2)
_export_quantized_weight(205-343)requantize_resmooth_fused_llm_layers(88-202)modelopt/torch/quantization/utils.py (2)
fsdp2_aware_weight_update(591-669)patch_fsdp_mp_dtypes(482-515)
tests/gpu/torch/export/test_export.py (3)
tests/_test_utils/torch_export/export_utils.py (1)
SmallLinearModelwithCustomWeight(36-54)modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
TensorQuantizer(65-1111)is_enabled(395-397)maxbound(193-199)modelopt/torch/export/quant_utils.py (1)
get_scaling_factor(212-229)
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
modelopt/torch/quantization/utils.py (2)
enable_fake_quant(574-587)fsdp2_aware_weight_update(591-669)
examples/llm_ptq/multinode-ptq.py (8)
examples/llm_ptq/example_utils.py (1)
apply_kv_cache_quant(245-256)modelopt/torch/export/convert_hf_config.py (1)
convert_hf_quant_config_format(21-117)modelopt/torch/export/unified_export_hf.py (1)
_export_hf_checkpoint(346-511)modelopt/torch/quantization/config.py (1)
need_calibration(1164-1190)modelopt/torch/quantization/utils.py (2)
patch_fsdp_mp_dtypes(482-515)_init_mp_dtypes(485-506)modelopt/torch/utils/dataset_utils.py (2)
get_dataset_dataloader(157-232)get_supported_datasets(235-249)modelopt/torch/opt/plugins/huggingface.py (1)
enable_huggingface_checkpointing(127-162)modelopt/torch/quantization/model_quant.py (1)
print_quant_summary(463-470)
🪛 LanguageTool
examples/llm_ptq/README.md
[grammar] ~244-~244: Use a hyphen to join words.
Context: ... provided and can be customized for user specific requirements. On each node run...
(QB_NEW_EN_HYPHEN)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: linux
🔇 Additional comments (17)
modelopt/torch/quantization/qtensor/base_qtensor.py (2)
22-22: LGTM! The import changes correctly reflect the refactor that moves helper utilities to modelopt/torch/quantization/utils.py. The removal of MixedPrecisionPolicy and contextmanager is appropriate since they are now used in the utilities module.
235-252: LGTM! The refactor correctly adopts the new per-submodule processing flow using fsdp2_aware_weight_update. This approach is cleaner than the previous FSDPParam mapping reconstruction logic and properly handles weight updates in FSDP2 contexts.
examples/llm_ptq/multinode-ptq.py (4)
56-112: LGTM! The argument parsing is straightforward and correctly handles comma-separated lists for datasets and calibration sizes. The default handling for None datasets is properly addressed in create_calibration_dataloader.
115-145: LGTM! The model loading and preparation logic is correct. The dummy optimizer with lr=0.0 is a necessary workaround for FSDP2's requirement to prepare an optimizer alongside the model.
148-176: LGTM! The calibration dataloader creation is correct. Keeping data on CPU (device=None) and handling device transfer in the calibration loop is the right approach for FSDP2 workflows.
179-223: LGTM! The quantization configuration logic is well-structured and correctly handles AWQ configuration, KV cache quantization, and model-specific adjustments. The use of copy.deepcopy prevents unintended mutations of the base configuration.
tests/gpu/torch/export/test_fsdp2_export.py (4)
35-69: LGTM! The test correctly verifies that weight updates made within the fsdp2_aware_weight_update context are properly preserved after unsharding and that the forward pass produces expected results.
72-100: LGTM! The test correctly verifies that weight compression (including dtype changes to FP8) is properly handled by the fsdp2_aware_weight_update context and preserved after unsharding.
119-214: LGTM! The tests correctly verify that FSDP2 and non-FSDP2 quantization paths produce equivalent results for layer fusion and weight export. The comparison approach using _compare_parameters_and_buffers is appropriate and thorough.
217-282: LGTM! The parametrized tests provide comprehensive coverage of various quantization configurations for FSDP2 workflows. The use of spawn_multiprocess_job correctly simulates distributed execution. Note: Line 244 has a commented configuration (mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG). If this is intentionally disabled due to a known issue, consider adding a comment explaining why.
modelopt/torch/quantization/utils.py (7)
18-36: LGTM! The new imports are appropriate for FSDP2 support. The use of __future__ annotations and a TYPE_CHECKING guard for Generator follows best practices for type hinting.
364-367: LGTM! The _get_module_name helper is a clean utility for locating a module's name within a root model. The None return for not-found cases is appropriate.
482-515: LGTM with note on global patching. The function correctly patches FSDP2 to handle mixed dtypes during quantization. The implementation is copied from the latest PyTorch FSDP version for forward compatibility.
Important: This is a global monkey-patch of PyTorch internals. Callers must ensure the original function is restored after use to avoid side effects on other code paths. As noted in the review of examples/llm_ptq/multinode-ptq.py, consider implementing this as a context manager for safer usage (a minimal sketch of that pattern follows these comments).
518-533: LGTM! The function correctly derives full parameter names by matching parameter identities and stripping the parameter name suffix. The use of id() for matching and rsplit(".", 1)[0] for prefix extraction is appropriate.
536-549: LGTM! The function correctly builds a mapping from full parameter names to their corresponding FSDPParam objects using the get_prefixed_param_names helper.
552-587: LGTM! Both context managers are well-implemented workarounds for FSDP2 limitations:
- no_requires_grad temporarily patches Parameter.__new__ to avoid errors with integer tensors during FSDP2 parameter creation
- enable_fake_quant prevents weight compression during unshard operations
Both correctly save and restore original state.
590-669: LGTM! The fsdp2_aware_weight_update context manager is the core FSDP2 weight update utility and is well-implemented. It correctly:
- Locates the enclosing FSDP module for the target modules
- Unshards weights if needed (with fake_quant protection)
- Yields for weight updates
- Creates new FSDPParam/QFSDPParam with updated dtype policies
- Updates the FSDPParamGroup and reshards
The TODO at line 668 notes a potential performance optimization to conditionally reshard only when necessary. This is a reasonable optimization for future work.
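As referenced in the patch_fsdp_mp_dtypes note above, here is a minimal sketch of the restore-on-exit pattern that a context-manager variant of the patch could use. This is not the PR's implementation: the object and attribute being swapped are passed in explicitly because the exact FSDP2 internal target is an assumption here.

from contextlib import contextmanager


@contextmanager
def temporarily_patched(owner, attr_name, replacement):
    """Swap `owner.attr_name` for `replacement` and always restore the original on exit."""
    original = getattr(owner, attr_name)
    setattr(owner, attr_name, replacement)
    try:
        yield
    finally:
        setattr(owner, attr_name, original)

Calibration and export could then run inside, for example, with temporarily_patched(FSDPParam, "_init_mp_dtypes", patched_init_mp_dtypes): so that other code paths never observe the patched behavior; FSDPParam and patched_init_mp_dtypes stand for the objects this PR's patch_fsdp_mp_dtypes already manipulates.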
Actionable comments posted: 1
🧹 Nitpick comments (5)
tests/gpu/torch/export/test_fsdp2_export.py (3)
96-104: Consider verifying shape and values after compression. The test only checks dtype but not the compressed shape or values. For a comprehensive compression test, verify that param.data.shape == (2, 2) and optionally check that the compressed values are as expected. Apply this diff to add shape verification:
     torch.distributed.barrier()
     model.linears.unshard()
     # Check if weights are as expected after unshard
     for param in model.parameters():
         assert param.data.dtype == torch.float8_e4m3fn
+        assert param.data.shape == (2, 2), f"Expected shape (2, 2), got {param.data.shape}"
210-213: Remove unnecessary context manager for non-FSDP model. fsdp2_aware_weight_update is not needed for non_fsdp_model since it's not an FSDPModule. The context manager will no-op but adds unnecessary overhead. Apply this diff:
     for name, sub_module in non_fsdp_model.named_modules():
         if is_quantlinear(sub_module):
-            with fsdp2_aware_weight_update(non_fsdp_model, sub_module):
-                _export_quantized_weight(sub_module, torch.float16)
+            _export_quantized_weight(sub_module, torch.float16)
240-262: Clarify the commented-out quantization config.
mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG is commented out on Lines 249 and 274. Is this config known to fail, or is it pending implementation? Consider adding a comment explaining why it's disabled or opening an issue to track enabling it. Do you want me to open a new issue to track enabling this quantization config, or should we add a comment explaining why it's currently disabled?
modelopt/torch/quantization/utils.py (2)
363-367: Consider caching for large models. Building dict(root_model.named_modules()) on every call could be expensive for large models. If this function is called frequently in a loop, consider caching the mapping or passing it as a parameter (see the sketch after these nitpicks).
481-520: Consider documenting the torch version this patch targets. The docstring mentions copying from "the latest version of torch FSDP" but doesn't specify which version. Consider adding the specific PyTorch version to help future maintainers understand when this patch can be removed.
Example:
-    """Patch FSDP2 to handle mixed dtypes properly during quantization."""
+    """Patch FSDP2 to handle mixed dtypes properly during quantization.
+
+    This patch is based on PyTorch version X.Y.Z and can be removed once
+    the minimum supported PyTorch version includes this fix.
+    """
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- examples/llm_ptq/multinode-ptq.py (1 hunks)
- modelopt/torch/quantization/utils.py (3 hunks)
- tests/gpu/torch/export/test_fsdp2_export.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/llm_ptq/multinode-ptq.py
🧰 Additional context used
🧬 Code graph analysis (2)
tests/gpu/torch/export/test_fsdp2_export.py (4)
tests/_test_utils/torch_export/export_utils.py (2)
SmallQKVModel(57-83)ToyModel(20-33)modelopt/torch/export/layer_utils.py (1)
is_quantlinear(346-348)modelopt/torch/export/unified_export_hf.py (2)
_export_quantized_weight(205-343)requantize_resmooth_fused_llm_layers(88-202)modelopt/torch/quantization/utils.py (2)
fsdp2_aware_weight_update(596-676)patch_fsdp_mp_dtypes(482-520)
modelopt/torch/quantization/utils.py (3)
modelopt/torch/utils/network.py (1)
get_unwrapped_name(599-612)modelopt/torch/trace/symbols.py (1)
named_modules(444-447)modelopt/torch/quantization/qtensor/base_qtensor.py (2)
QFSDPParam(139-159)QTensorWrapper(87-136)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (6)
tests/gpu/torch/export/test_fsdp2_export.py (2)
33-37: LGTM! The autouse fixture ensures consistent FSDP mixed precision dtype handling across all tests in this module.
40-75: LGTM! The test correctly verifies that weight updates within the fsdp2_aware_weight_update context are properly synchronized across FSDP shards.
modelopt/torch/quantization/utils.py (4)
18-35: LGTM! The new imports support the FSDP2-aware utilities being added, including proper type hints with TYPE_CHECKING to avoid runtime overhead.
578-592: LGTM! The context manager correctly preserves and restores the _fake_quant state for all weight quantizers in the module hierarchy.
595-629: LGTM! The context manager entry correctly handles FSDP2-specific setup, including conditional unsharding and FSDPParam mapping creation. The validation ensures all modules belong to the same FSDP group.
630-676: Verify the selective init_dtype_attrs call. Line 663 calls init_dtype_attrs only for non-QFSDPParam instances. Is this intentional? If QFSDPParam.__init__ already calls this method (as suggested by line 488 in the relevant code snippets), this is correct. However, if it doesn't, or if the behavior differs, this could lead to inconsistencies. Please confirm whether QFSDPParam.__init__ already handles dtype attribute initialization, or if this selective call is the correct behavior.
Additionally, the TODO on Line 675 mentions a performance optimization for conditional resharding. Consider addressing this if the unconditional reshard() becomes a bottleneck during export.
@contextmanager
def no_requires_grad():
    """Context manager to temporarily set requires_grad to False.

    This is used to allow us to call init_sharded_parameter() on the compressed weights. Currently FSDP2 creates
    a new parameter with default requires_grad and then update the requires_grad attribute as needed. This
    triggers an error when torch.nn.Parameter is called on compressed weights as requires_grad cannot be set to True
    for integer tensors.
    """
    original_new = torch.nn.Parameter.__new__

    def patched_new(cls, data=None, requires_grad=True):
        return original_new(cls, data, requires_grad=False)

    torch.nn.Parameter.__new__ = patched_new
    try:
        yield
    finally:
        torch.nn.Parameter.__new__ = original_new
Document thread-safety concerns for this global patch.
Patching torch.nn.Parameter.__new__ globally is invasive and could cause issues if other code creates Parameters concurrently. The patched version also ignores the requires_grad argument, which might surprise callers.
Consider:
- Adding a warning in the docstring about thread-safety and concurrent usage
- Documenting why this approach is necessary (i.e., FSDP2 integer tensor constraints)
- Investigating if there's a less invasive alternative
Example docstring enhancement:
@contextmanager
def no_requires_grad():
- """Context manager to temporarily set requires_grad to False.
+ """Context manager to globally patch Parameter creation to set requires_grad=False.
- This is used to allow us to call init_sharded_parameter() on the compressed weights. Currently FSDP2 creates
- a new parameter with default requires_grad and then update the requires_grad attribute as needed. This
- triggers an error when torch.nn.Parameter is called on compressed weights as requires_grad cannot be set to True
- for integer tensors.
+ **WARNING**: This patches `torch.nn.Parameter.__new__` globally and is NOT thread-safe.
+ Do not use in multi-threaded contexts or when other code might be creating Parameters.
+
+ This workaround allows calling init_sharded_parameter() on compressed integer weights.
+ FSDP2 creates Parameters with requires_grad=True by default, which errors for integer dtypes.
📝 Committable suggestion
@contextmanager
def no_requires_grad():
    """Context manager to globally patch Parameter creation to set requires_grad=False.

    **WARNING**: This patches `torch.nn.Parameter.__new__` globally and is NOT thread-safe.
    Do not use in multi-threaded contexts or when other code might be creating Parameters.

    This workaround allows calling init_sharded_parameter() on compressed integer weights.
    FSDP2 creates Parameters with requires_grad=True by default, which errors for integer dtypes.
    """
    original_new = torch.nn.Parameter.__new__

    def patched_new(cls, data=None, requires_grad=True):
        return original_new(cls, data, requires_grad=False)

    torch.nn.Parameter.__new__ = patched_new
    try:
        yield
    finally:
        torch.nn.Parameter.__new__ = original_new
🤖 Prompt for AI Agents
In modelopt/torch/quantization/utils.py around lines 557 to 575, the
no_requires_grad context manager globally patches torch.nn.Parameter.__new__
which is invasive and not thread-safe; update the docstring to (1) clearly warn
that this patch is global and may break concurrent code or other libraries
creating Parameters, (2) explain the specific reason it's required here (FSDP2
creates Parameters with default requires_grad and integer tensors cannot have
requires_grad=True), and (3) note that callers should avoid concurrent Parameter
creation while this context runs and that the patch ignores the requires_grad
argument on purpose. Also add a short TODO suggesting investigation of less
invasive alternatives (e.g., patching only local factories, using a wrapper
function or FSDP API) and ensure the docstring mentions that the original
__new__ is restored in the finally block.
            _export_quantized_weight(sub_module, dtype, weight_name)

    quantized_state_dict = model.state_dict()
    if accelerator is not None:
can we infer this flag by looking at the dist ranks etc?
I think we can. Is it to avoid passing accelerator? In which case we would still need accelerator to gather all the local state dicts and get the full state dict.
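A hedged sketch of how that check could be inferred rather than passed in: detect an initialized multi-rank process group and an FSDP2-wrapped model. The FSDPModule import path mirrors the one used in this PR's tests and may differ across torch versions; as noted above, something like accelerator.get_state_dict(model) is still needed to assemble the full state dict.

import torch.distributed as dist
from torch.distributed._composable.fsdp import FSDPModule  # import path used elsewhere in this PR


def needs_distributed_gather(model) -> bool:
    """Heuristic: a gather is needed when running multi-rank with an FSDP2-wrapped model."""
    multi_rank = dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1
    return multi_rank and isinstance(model, FSDPModule)

Under that assumption, the flag would mainly disappear from the signature, while the accelerator (or an equivalent full-state-dict gather) remains a dependency of the export path.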
Actionable comments posted: 1
♻️ Duplicate comments (1)
modelopt/torch/export/unified_export_hf.py (1)
29-29: Avoid hard import of accelerate; make Accelerator optional and lazy. Importing Accelerator at module scope forces a hard dependency on accelerate and breaks environments that don’t use FSDP2. Also, annotating with Accelerator requires accelerate at import time. Make the import lazy and relax the type hint to Any, and guard usage.
Apply this diff:
@@
-import torch.nn as nn
-from accelerate import Accelerator
+import torch.nn as nn
@@
-from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names
+from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names
@@
-def _export_hf_checkpoint(
-    model: nn.Module,
-    dtype: torch.dtype | None = None,
-    accelerator: Accelerator | None = None,
-) -> tuple[dict[str, Any], dict[str, Any]]:
+def _export_hf_checkpoint(
+    model: nn.Module,
+    dtype: torch.dtype | None = None,
+    accelerator: Any | None = None,
+) -> tuple[dict[str, Any], dict[str, Any]]:
@@
-    if accelerator is not None:
-        # Gather state_dict from all ranks
-        quantized_state_dict = accelerator.get_state_dict(model)
+    if accelerator is not None:
+        # Ensure accelerate is available when requested
+        try:
+            from accelerate import Accelerator as _Accelerator  # type: ignore
+        except Exception as e:  # pragma: no cover
+            raise ImportError(
+                "accelerate must be installed to export from an FSDP2-wrapped model."
+            ) from e
+        # Optionally sync before gathering to ensure all ranks finished modification
+        if hasattr(accelerator, "wait_for_everyone"):
+            accelerator.wait_for_everyone()
+        # Gather state_dict from all ranks
+        quantized_state_dict = accelerator.get_state_dict(model)
     else:
         quantized_state_dict = model.state_dict()
Also applies to: 347-350, 496-501
🧹 Nitpick comments (4)
modelopt/torch/export/unified_export_hf.py (1)
118-120: Reduce repeated unshard/reshard by batching fsdp2_aware_weight_update scopes. Each with fsdp2_aware_weight_update may unshard/reshard the same root module repeatedly. Batch module updates per root FSDPModule to cut overhead (especially large models).
- Collect modules per root FSDP group, then wrap once per group.
- For MoE loops, accumulate per-expert modules and process in a single context where possible.
Also applies to: 166-168, 177-179, 199-201, 474-476, 493-495
examples/llm_ptq/multinode-ptq.py (3)
243-256: Parameter is unused; clarify intent and avoid confusion. calibrate ignores the unwrapped model parameter but comments are contradictory. Rename the arg to underscore and align the comment.
-def calibrate(unwrapped_model):
-    """Calibration loop that uses the FSDP-wrapped model."""
+def calibrate(_):
+    """Calibration loop that intentionally drives the FSDP-wrapped model for DTensor correctness."""
@@
-        # Use outer model (FSDP-wrapped), not the parameter
-        # Important: We should forward pass using the unwrapped model
-        # mtq.quantize will unwrap the model & pass to the forward_loop
+        # Use the outer FSDP-wrapped model; mtq.quantize passes an unwrapped view,
+        # but we drive the wrapped one for DTensor correctness.
         model(**batch)
125-127: Fix returns docstring to match actual return value. Docstring says two items but function returns three (model, model_type, original_architectures).
-    Returns:
-        Tuple of (prepared_model, model_type)
+    Returns:
+        Tuple of (prepared_model, model_type, original_architectures)
275-283: Synchronize ranks before file I/O. Add a barrier so rank 0 writes only after all ranks finish gathering/modification.
-    post_state_dict, hf_quant_config = _export_hf_checkpoint(
-        model, torch.bfloat16, accelerator=accelerator
-    )
+    post_state_dict, hf_quant_config = _export_hf_checkpoint(
+        model, torch.bfloat16, accelerator=accelerator
+    )
+    if hasattr(accelerator, "wait_for_everyone"):
+        accelerator.wait_for_everyone()
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- examples/llm_ptq/multinode-ptq.py (1 hunks)
- modelopt/torch/export/unified_export_hf.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/torch/export/unified_export_hf.py (3)
modelopt/torch/quantization/conversion.py (1)
set_quantizer_by_cfg_context(300-322)modelopt/torch/quantization/utils.py (2)
fsdp2_aware_weight_update(596-676)quantizer_attr_names(231-242)modelopt/torch/export/quant_utils.py (2)
preprocess_linear_fusion(940-1010)fuse_prequant_layernorm(926-937)
examples/llm_ptq/multinode-ptq.py (8)
examples/llm_ptq/example_utils.py (1)
apply_kv_cache_quant(245-256)modelopt/torch/export/convert_hf_config.py (1)
convert_hf_quant_config_format(21-117)modelopt/torch/export/unified_export_hf.py (1)
_export_hf_checkpoint(346-510)modelopt/torch/quantization/config.py (1)
need_calibration(1164-1190)modelopt/torch/quantization/utils.py (1)
patch_fsdp_mp_dtypes(482-520)modelopt/torch/utils/dataset_utils.py (2)
get_dataset_dataloader(157-232)get_supported_datasets(235-249)modelopt/torch/opt/plugins/huggingface.py (1)
enable_huggingface_checkpointing(127-162)modelopt/torch/quantization/model_quant.py (1)
print_quant_summary(463-470)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
can we rename to multinode_ptq.py instead of multinode-ptq.py
examples/llm_ptq/multinode-ptq.py
Outdated
        DataLoader for calibration
    """
    if dataset_names is None:
        dataset_names = ["cnn_dailymail"]
now we have nemotron_v2 post training dataset, please rebase this PR and follow the same impl as hf_ptq.py
examples/llm_ptq/multinode-ptq.py
Outdated
    kv_cfg = getattr(mtq, KV_QUANT_CFG_CHOICES[kv_cache_qformat])["quant_cfg"]
    quant_cfg = apply_kv_cache_quant(quant_cfg, kv_cfg)

    # Model-specific adjustments
let's just remove these for simplicity
modelopt/torch/quantization/utils.py
Outdated
@contextmanager
def fsdp2_aware_weight_update(root_model, modules_to_update):
    """Context manager to update the FSDPParam list if an update is made to a submodule of an FSDPModule."""
could you document briefly what this section of code does under the hood?
Thanks for the implementation
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/export/unified_export_hf.py (1)
474-511: Reshard the last FSDP module before exiting the loop.
fsdp_module_to_reshard gets updated for every FSDPModule, and the previous module is resharded immediately. The last one, however, is never resharded, so we exit the loop with that module still unsharded. That leaves the full weights resident on each rank for the remainder of the export path, blowing up memory and defeating the communication savings the optimization was supposed to preserve. Please reshard the final module after the loop executes.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- examples/llm_ptq/README.md (1 hunks)
- modelopt/torch/export/unified_export_hf.py (10 hunks)
- modelopt/torch/quantization/utils.py (3 hunks)
- tests/gpu/torch/export/test_fsdp2_export.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/torch/export/unified_export_hf.py (3)
modelopt/torch/quantization/utils.py (2)
fsdp2_aware_weight_update(596-676)quantizer_attr_names(231-242)modelopt/torch/export/quant_utils.py (3)
preprocess_linear_fusion(940-1010)fuse_prequant_layernorm(926-937)get_quantization_format(421-522)modelopt/torch/export/layer_utils.py (1)
is_quantlinear(346-348)
modelopt/torch/quantization/utils.py (4)
modelopt/torch/utils/network.py (1)
get_unwrapped_name(599-612)modelopt/torch/utils/logging.py (1)
print_rank_0(92-95)modelopt/torch/trace/symbols.py (1)
named_modules(444-447)modelopt/torch/quantization/qtensor/base_qtensor.py (2)
QFSDPParam(139-159)QTensorWrapper(87-136)
tests/gpu/torch/export/test_fsdp2_export.py (5)
tests/_test_utils/torch_export/export_utils.py (2)
SmallQKVModel(57-83)ToyModel(20-33)modelopt/torch/export/layer_utils.py (1)
is_quantlinear(346-348)modelopt/torch/export/unified_export_hf.py (2)
_export_quantized_weight(210-348)requantize_resmooth_fused_llm_layers(93-207)modelopt/torch/quantization/utils.py (2)
fsdp2_aware_weight_update(596-676)patch_fsdp_mp_dtypes(482-520)modelopt/torch/quantization/qtensor/base_qtensor.py (3)
to(114-122)dim(110-112)quantize(67-76)
🪛 LanguageTool
examples/llm_ptq/README.md
[grammar] ~244-~244: Use a hyphen to join words.
Context: ... provided and can be customized for user specific requirements. On each node run...
(QB_NEW_EN_HYPHEN)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
examples/llm_ptq/README.md
Outdated
The exported checkpoint can be deployed using TensorRT-LLM/ vLLM/ SGLang. For more details refer to the [deployment section](#deployment) of this document.

> *Performance Note: FSDP2 is designed for training workloads and may result in longer calibration and export times. For faster calibration, maximize the batch size based on available GPU memory.*
Could you share how long does the calibration take with FSDP?
Just updated the results in the PR description!
QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = {
    "int8_wo": mtq.INT8_WEIGHT_ONLY_CFG,
    "fp8": mtq.FP8_DEFAULT_CFG,
    "int4_awq": mtq.INT4_AWQ_CFG,
    "nvfp4": mtq.NVFP4_DEFAULT_CFG,
    "nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG,
    "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
}

KV_QUANT_CFG_CHOICES = {
    "none": "none",
    "fp8": "FP8_KV_CFG",
    "nvfp4": "NVFP4_KV_CFG",
    "nvfp4_affine": "NVFP4_AFFINE_KV_CFG",
}
Is there a reason why some qformat, like FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, can not be supported compared to the hf_ptq.py?
Some of my unit tests were failing with this config. I will look into this issue in a follow-up PR.
examples/llm_ptq/multinode-ptq.py
Outdated
    weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
    if isinstance(weight_quantizer, list):
        weight_quantizer = weight_quantizer[0]

    if awq_block_size:
        weight_quantizer["block_sizes"][-1] = awq_block_size

    # Coarser search for certain models to avoid overflow
    if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
        quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}
Does hf_ptq.py also have something similar? Maybe we can make this a utility to avoid duplication.
# Every time we encounter a new FSDPModule, we need to reshard the previous one
if fsdp_module_to_reshard is not None:
    fsdp_module_to_reshard.reshard()

fsdp_module_to_reshard = sub_module
Could you share why do we need to reshard the previous layer?
We need to reshard to make sure we don't hit OOM. Every time we update a weight or run a forward pass, the full layer is materialized on the GPU. From my understanding, if we don't reshard we could hit OOM.
This is a small trick I used to avoid too many reshard calls, thereby speeding up export. Essentially we unshard a decoder layer -> export all linear layers -> then reshard once we are fully done with the decoder layer.
I see, thanks for the detailed explaination. Could you also add a few lines of comment in the code to document this?
+1
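To make the trick discussed in this thread concrete, here is a hedged sketch of the unshard-once, export-everything, reshard-later pattern, including the final reshard after the loop that a later review comment asks for. export_linear and the weight_quantizer check are placeholders for the PR's _export_quantized_weight and is_quantlinear; FSDPModule's unshard()/reshard() are used the same way as in the PR, and the import path may vary by torch version.

from torch.distributed._composable.fsdp import FSDPModule  # import path used elsewhere in this PR


def export_with_lazy_reshard(model, export_linear):
    """Reshard each FSDP-wrapped block only after all of its submodules have been exported."""
    current_fsdp_module = None
    for _, sub_module in model.named_modules():
        if sub_module is model:
            continue  # skip the root wrapper; only track per-layer FSDP modules
        if isinstance(sub_module, FSDPModule):
            # Entering a new decoder layer: free the previous one to avoid OOM.
            if current_fsdp_module is not None:
                current_fsdp_module.reshard()
            current_fsdp_module = sub_module
            current_fsdp_module.unshard()
        elif hasattr(sub_module, "weight_quantizer"):  # stand-in for is_quantlinear(sub_module)
            export_linear(sub_module)
    # Do not leave the last layer unsharded once the loop ends.
    if current_fsdp_module is not None:
        current_fsdp_module.reshard()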
Force-pushed from 417f17a to 1378f69
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
modelopt/torch/export/unified_export_hf.py (1)
530-585: Add accelerator parameter to export_hf_checkpoint to support distributed exports. The review comment correctly identifies that _export_hf_checkpoint accepts an accelerator parameter (documented in its docstring at line 366) via **kwargs, but the public export_hf_checkpoint API doesn't expose or pass it through. This forces distributed export scenarios to bypass the public API and call the private helper directly, as seen in examples/llm_ptq/multinode_ptq.py. The suggested diff properly threads the parameter through to enable proper rank-gathering in distributed setups.
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
235-253: Batch eligible leaf modules before resharding to avoid repeated expensive unshard/reshard cycles and assertion failures. The current code calls fsdp2_aware_weight_update for every named submodule in a loop, triggering an unshard/reshard cycle on each iteration. This is inefficient and risks assertion failures on non-leaf modules not present in the FSDP param mapping.
The function signature expects a list of modules (modules_to_update) designed for batch processing with a single reshard cycle. Filter for eligible modules (those with enabled, non-fake quantizers) upfront and pass them as a batch:
-    for _, submodule in fsdp_module.named_modules():
-        with fsdp2_aware_weight_update(fsdp_module, submodule):
-            _compress_and_update_module_weight(submodule)
+    # Collect eligible leaf modules once to avoid repeated reshard/unshard
+    eligible = []
+    for _, submodule in fsdp_module.named_modules():
+        if (
+            hasattr(submodule, "weight")
+            and (submodule.weight is not None and not submodule.weight.is_meta)
+            and hasattr(submodule, "weight_quantizer")
+            and submodule.weight_quantizer.is_enabled
+            and not getattr(submodule.weight_quantizer, "_fake_quant", False)
+            and submodule.weight.element_size() > 1
+        ):
+            eligible.append(submodule)
+    if not eligible:
+        return
+    # Unsharded already above; batch update mapping once and reshard once at exit
+    with fsdp2_aware_weight_update(fsdp_module, eligible, reshard=True):
+        for submodule in eligible:
+            _compress_and_update_module_weight(submodule)
Verify multinode FSDP2 export runs without "Module … not found in fsdp_param_mapping" assertions.
♻️ Duplicate comments (6)
examples/llm_ptq/fsdp2.yaml (1)
1-3: Consumer script documented; config looks sound. The top comment addresses the prior ask; the FSDP2 settings are sane for FULL_STATE_DICT export. Consider adding a brief note that fsdp_transformer_layer_cls_to_wrap is typically overridden per model via CLI.
tests/gpu/torch/export/test_export.py (1)
383-383: Remove debug print from test (already flagged earlier). The DEBUG LOG line adds noise to CI.
Apply:
- print(f"DEBUG LOG: Scale: {scale}, Expected: {expected_amax[0] / maxbound}")tests/gpu/torch/export/test_fsdp2_export.py (3)
37-37: Remove redundant import. The fully_shard import is already present at line 24. Per past review comments, imports should be at the top of the file. Apply this diff:
-    from torch.distributed._composable.fsdp import fully_shard
-
75-75: Remove redundant import. The fully_shard import is already present at line 24. Apply this diff:
-    from torch.distributed._composable.fsdp import fully_shard
-
163-165: Remove redundant imports. Both copy (line 163) and fully_shard (line 165) are already imported at the top of the file (lines 17 and 24 respectively). Apply this diff:
-    import copy
-
-    from torch.distributed._composable.fsdp import fully_shard
-
modelopt/torch/quantization/utils.py (1)
564-582: Enhance docstring with thread-safety warnings and usage constraints. As noted in a past review comment, this global monkey-patch is invasive and not thread-safe. The docstring should clearly warn about concurrent usage risks and explain why this workaround is necessary (FSDP2's default Parameter creation with requires_grad=True errors on integer dtypes). Consider enhancing the docstring:
 @contextmanager
 def no_requires_grad():
-    """Context manager to temporarily set requires_grad to False.
+    """Context manager to globally patch Parameter creation to set requires_grad=False.
-    This is used to allow us to call init_sharded_parameter() on the compressed weights. Currently FSDP2 creates
-    a new parameter with default requires_grad and then update the requires_grad attribute as needed. This
-    triggers an error when torch.nn.Parameter is called on compressed weights as requires_grad cannot be set to True
-    for integer tensors.
+    **WARNING**: This patches `torch.nn.Parameter.__new__` globally and is NOT thread-safe.
+    Do not use in multi-threaded contexts or when other code might be creating Parameters.
+
+    This workaround allows calling init_sharded_parameter() on compressed integer weights.
+    FSDP2 creates Parameters with requires_grad=True by default, which errors for integer dtypes.
+    The patch ignores the requires_grad argument and always sets it to False.
+
+    The original __new__ is restored in the finally block.
     """
🧹 Nitpick comments (12)
modelopt/torch/quantization/qtensor/nvfp4_tensor.py (1)
214-217: Consider adding defensive device cast for completeness. While the current fix at line 85 handles most cases, line 216 could benefit from a similar defensive cast to ensure weights_scaling_factor_2 is on the same device as input. This would guard against edge cases where both scaling factors are provided as parameters on different devices.
     # Scale weights
-    scaled_weight = input / (
-        (weights_scaling_factor.to(torch.float32) * weights_scaling_factor_2).unsqueeze(-1)
-    )
+    scaled_weight = input / (
+        (weights_scaling_factor.to(torch.float32) * weights_scaling_factor_2.to(input.device)).unsqueeze(-1)
+    )
examples/llm_ptq/multinode_ptq.py (5)
112-116: Normalize calib_size length for all dataset cases. Currently the length adjustment runs only when args.dataset is None. Users passing multiple datasets will hit an assertion in get_dataset_dataloader if lengths differ. Normalize unconditionally:
-    # Set default dataset if not provided
-    if args.dataset is None:
-        args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
-        warnings.warn(
-            "No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2."
-        )
-        # Adjust calib_size to match dataset length by extending or truncating as needed
-        args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[
-            : len(args.dataset)
-        ]
+    # Defaults
+    if args.dataset is None:
+        args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
+        warnings.warn("No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2.")
+    # Adjust calib_size to match dataset length either way
+    args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[: len(args.dataset)]
Also applies to: 289-299
182-214: Calibration loop: clarify param usage and add inference_mode
- The unwrapped_model parameter is unused, and the comment is contradictory. Use a throwaway name and add torch.inference_mode() to avoid accidental grads.
-def create_fsdp2_calibration_loop(
+def create_fsdp2_calibration_loop(
     model: nn.Module,
     dataloader: torch.utils.data.DataLoader,
     accelerator: Accelerator,
 ):
@@
-    def calibrate(unwrapped_model):
-        """Calibration loop that uses the FSDP-wrapped model."""
-        for batch in tqdm(dataloader, desc="Calibrating"):
+    def calibrate(_):
+        """Calibration loop that runs the FSDP-wrapped model."""
+        with torch.inference_mode():
+            for batch in tqdm(dataloader, desc="Calibrating"):
...
-            # Use outer model (FSDP-wrapped), not the parameter
-            # Important: We should forward pass using the unwrapped model
-            # mtq.quantize will unwrap the model & pass to the forward_loop
+            # Use the outer FSDP-wrapped model; mtq may pass an unwrapped handle here.
             model(**batch)
233-246: Export dtype: prefer model's native dtype unless overridden. For robustness across models, let _export_hf_checkpoint default to model.config.torch_dtype:
-    post_state_dict, hf_quant_config = _export_hf_checkpoint(
-        model, torch.bfloat16, accelerator=accelerator
-    )
+    post_state_dict, hf_quant_config = _export_hf_checkpoint(
+        model, accelerator=accelerator
+    )
237-259: Optional: add a barrier after main-process write. To avoid any chance of other ranks reading a partially written file tree during subsequent steps, insert:
     if accelerator.is_main_process:
         ...
         with open(original_config, "w") as file:
             json.dump(config_data, file, indent=4)
+    accelerator.wait_for_everyone()
349-351: NIT: remove stale log. "Unpatching FSDP2 MP dtypes" is handled by the surrounding context manager; this print is misleading. Safe to drop.
examples/llm_ptq/README.md (1)
244-246: Hyphenation. Use "user-specific" (hyphenated).
Apply:
-... customized for user specific requirements.
+... customized for user-specific requirements.
examples/llm_ptq/hf_ptq.py (1)
451-451: LGTM: single source of truth for quant_cfg. Using build_quant_cfg(...) here mirrors the multi-node flow; reduces drift. Consider adopting the same helper in multinode_ptq.py (already used) and ensuring both scripts share the exact qformat set.
examples/llm_ptq/example_utils.py (1)
64-73: KV-cache application assumes quant_cfg value and attribute presence. getattr(mtq, kv_quant_cfg_choices[...]) must exist and be a dict with "quant_cfg". Add a defensive check to raise a clear error if missing, or default to no-op.
tests/_test_utils/torch_export/export_utils.py (1)
57-84: SmallQKVModel 'device' attr isn't used to place parameters/tensors. self.device is stored but modules aren't moved. Either remove the attribute or move buffers/params to it in init.
modelopt/torch/export/unified_export_hf.py (1)
171-174: fsdp2_aware_weight_update usage looks correct; consider grouping modules to reduce reshard churn. The contexts are applied per group (MoE experts, fused linears, layernorm, experts' gate/up/down) — good. For large layers you can optionally pass lists (you already do in some paths) and specify reshard=False until the group is complete to minimize comms.
Also applies to: 182-184, 204-206, 490-511
tests/gpu/torch/export/test_fsdp2_export.py (1)
205-208: Remove unnecessary context manager usage on non-FSDP model. The fsdp2_aware_weight_update context manager is a no-op when used on non-FSDP models (it only activates if isinstance(root_model, FSDPModule) per line 622 in modelopt/torch/quantization/utils.py). This usage is confusing and should be removed. Apply this diff:
     for name, sub_module in non_fsdp_model.named_modules():
         if is_quantlinear(sub_module):
-            with fsdp2_aware_weight_update(non_fsdp_model, sub_module):
-                _export_quantized_weight(sub_module, torch.float16)
+            _export_quantized_weight(sub_module, torch.float16)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
- CHANGELOG.rst (1 hunks)
- examples/llm_ptq/README.md (1 hunks)
- examples/llm_ptq/example_utils.py (2 hunks)
- examples/llm_ptq/fsdp2.yaml (1 hunks)
- examples/llm_ptq/hf_ptq.py (2 hunks)
- examples/llm_ptq/multinode_ptq.py (1 hunks)
- modelopt/torch/export/unified_export_hf.py (10 hunks)
- modelopt/torch/quantization/qtensor/base_qtensor.py (3 hunks)
- modelopt/torch/quantization/qtensor/nvfp4_tensor.py (1 hunks)
- modelopt/torch/quantization/utils.py (3 hunks)
- tests/_test_utils/torch_export/export_utils.py (2 hunks)
- tests/gpu/torch/export/test_export.py (3 hunks)
- tests/gpu/torch/export/test_fsdp2_export.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (10)
tests/_test_utils/torch_export/export_utils.py (4)
modelopt/torch/export/model_config.py (1)
bias(153-163)modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
bias(303-307)forward(847-946)modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (1)
forward(141-163)modelopt/torch/quantization/backends/nvfp4_gemm.py (1)
forward(135-150)
examples/llm_ptq/hf_ptq.py (1)
examples/llm_ptq/example_utils.py (1)
build_quant_cfg(42-85)
tests/gpu/torch/export/test_export.py (3)
tests/_test_utils/torch_export/export_utils.py (1)
SmallLinearModelwithCustomWeight(36-54)modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
TensorQuantizer(65-1111)is_enabled(395-397)maxbound(193-199)modelopt/torch/export/quant_utils.py (1)
get_scaling_factor(212-229)
examples/llm_ptq/example_utils.py (1)
modelopt/torch/export/model_config.py (1)
awq_block_size(289-294)
modelopt/torch/quantization/utils.py (2)
modelopt/torch/utils/network.py (1)
get_unwrapped_name(599-612)modelopt/torch/quantization/qtensor/base_qtensor.py (2)
QFSDPParam(139-159)QTensorWrapper(87-136)
modelopt/torch/quantization/qtensor/nvfp4_tensor.py (2)
modelopt/torch/export/model_config.py (1)
weights_scaling_factor_2(237-266)modelopt/torch/quantization/qtensor/base_qtensor.py (1)
to(114-122)
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
modelopt/torch/quantization/utils.py (2)
enable_fake_quant(586-599)fsdp2_aware_weight_update(603-699)
modelopt/torch/export/unified_export_hf.py (3)
modelopt/torch/quantization/utils.py (2)
fsdp2_aware_weight_update(603-699)quantizer_attr_names(231-242)modelopt/torch/export/quant_utils.py (3)
preprocess_linear_fusion(949-1019)fuse_prequant_layernorm(935-946)get_quantization_format(430-531)modelopt/torch/export/layer_utils.py (1)
is_quantlinear(346-348)
tests/gpu/torch/export/test_fsdp2_export.py (4)
tests/_test_utils/torch_export/export_utils.py (2)
SmallQKVModel(57-83)ToyModel(20-33)modelopt/torch/export/layer_utils.py (1)
is_quantlinear(346-348)modelopt/torch/export/unified_export_hf.py (2)
_export_quantized_weight(210-348)requantize_resmooth_fused_llm_layers(93-207)modelopt/torch/quantization/utils.py (2)
fsdp2_aware_weight_update(603-699)patch_fsdp_mp_dtypes(485-527)
examples/llm_ptq/multinode_ptq.py (6)
examples/llm_ptq/example_utils.py (2)
build_quant_cfg(42-85)get_tokenizer(95-112)modelopt/torch/export/unified_export_hf.py (1)
_export_hf_checkpoint(351-527)modelopt/torch/quantization/config.py (1)
need_calibration(1164-1190)modelopt/torch/quantization/utils.py (1)
patch_fsdp_mp_dtypes(485-527)modelopt/torch/utils/dataset_utils.py (1)
get_dataset_dataloader(171-246)modelopt/torch/opt/plugins/huggingface.py (1)
enable_huggingface_checkpointing(127-162)
🪛 LanguageTool
examples/llm_ptq/README.md
[grammar] ~244-~244: Use a hyphen to join words.
Context: ... provided and can be customized for user specific requirements. On each node run...
(QB_NEW_EN_HYPHEN)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (13)
modelopt/torch/quantization/qtensor/nvfp4_tensor.py (1)
84-86: LGTM! Device consistency fix for multi-node scenarios. The explicit device cast ensures weights_scaling_factor_2 matches the device of per_block_amax, preventing device mismatch errors in FSDP2 distributed training. This is critical for the multi-node quantization workflow.
CHANGELOG.rst (1)
14-14: Link/anchor sanity check. Looks good. Please double-check the README anchor remains stable before release branching.
examples/llm_ptq/hf_ptq.py (1)
25-33: Good refactor: centralizing quant_cfg construction. Importing and using build_quant_cfg reduces duplication and keeps behavior consistent. Please confirm example_utils.build_quant_cfg remains backward compatible with all qformats previously supported by this script.
modelopt/torch/export/unified_export_hf.py (1)
30-36: Good: accelerate import is now optional. Lazy import avoids hard dependency for non-FSDP2 users. LGTM.
Confirm that environments without accelerate still run export paths and tests unaffected.
tests/gpu/torch/export/test_export.py (1)
309-311: LGTM on model switch in tests. Using SmallLinearModelwithCustomWeight matches the renamed test utility and keeps semantics intact.
Also applies to: 378-380
tests/gpu/torch/export/test_fsdp2_export.py (3)
93-101: Verify shape change in addition to dtype. The weight is changed from shape (6, 6) to (2, 2) at lines 93-95, but the assertion at line 101 only verifies the dtype change. Consider also asserting the shape to ensure the compression properly updated the weight dimensions.
     for param in model.parameters():
         assert param.data.dtype == torch.float8_e4m3fn
+        assert param.data.shape == (2, 2)
104-118: LGTM! The helper function correctly compares parameters and buffers between two models using bfloat16 conversion, which is appropriate for quantized model comparisons.
235-247: Document or remove commented test case. The test case mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG is commented out at line 244 (and again at line 267). Please either remove this line if the test case is not needed, or add a TODO/NOTE comment explaining why it's disabled. Example:
     mtq.W4A8_AWQ_BETA_CFG,
-    # mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
+    # TODO: Enable once FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG is supported in multi-node setup
+    # mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
     mtq.W4A8_MXFP4_FP8_CFG,
modelopt/torch/quantization/utils.py (5)
18-34: LGTM! The addition of future annotations, TYPE_CHECKING, and proper type hints follows Python best practices for type annotations.
366-369: LGTM! The helper function efficiently retrieves the module name using a generator expression with next().
530-561: LGTM! Both get_prefixed_param_names and create_fsdp_param_mapping are well-implemented with clear documentation and efficient use of generator expressions.
585-599: LGTM! The enable_fake_quant context manager correctly stores and restores the _fake_quant attribute. The use of a list with pop() is safe because module.modules() returns modules in a consistent order.
602-699: LGTM! The fsdp2_aware_weight_update context manager provides comprehensive orchestration for FSDP2-aware weight updates, including:
- Proper unshard/reshard lifecycle management
- FSDPParam mapping updates with correct dtype policies
- Intelligent selection between QFSDPParam and FSDPParam based on weight type
- Cleanup of old hook handles to prevent memory leaks
The implementation correctly handles both FSDP and non-FSDP models by checking isinstance(root_model, FSDPModule) before applying FSDP-specific logic.
Actionable comments posted: 0
♻️ Duplicate comments (1)
examples/llm_ptq/example_utils.py (1)
42-85: Fix auto-quant path, avoid shared-state mutation, and harden AWQ/KV guards.
- Auto path: when args.auto_quantize_bits=True, the function returns an empty {} and skips KV-cache quant + model-specific toggles. Implement explicit else branch that at minimum applies KV-cache quant and per-model disables.
- Shared mutation: for non‑AWQ formats you assign quant_cfg = quant_cfg_choices[...], then mutate it (KV, phi4mm toggles). Always deep‑copy, not only for AWQ.
- Robustness: guard args.awq_block_size (may be absent) and verify weight_quantizer["block_sizes"] has key -1 before assignment; otherwise KeyError. Also validate kv_cache_qformat mapping and mtq attribute existence before getattr.
- Note: keeping Q/K/V AWQ block sizes aligned avoids downstream assert in modelopt/torch/export/model_config.py Lines 288‑293. [Relevant snippet referenced]
As echoed in a previous review on this function; this comment consolidates the unresolved parts.
Apply:
@@ -def build_quant_cfg(args, model_type, quant_cfg_choices, kv_quant_cfg_choices): - quant_cfg = {} - if not hasattr(args, "auto_quantize_bits") or not args.auto_quantize_bits: +def build_quant_cfg(args, model_type, quant_cfg_choices, kv_quant_cfg_choices): + quant_cfg = {} + auto = getattr(args, "auto_quantize_bits", False) + if not auto: assert args.qformat in quant_cfg_choices, ( f"Unsupported quantization format: {args.qformat} with {args.kv_cache_qformat} KV cache" ) - - quant_cfg = quant_cfg_choices[args.qformat] + # Always work on a private copy to avoid mutating shared choices + quant_cfg = copy.deepcopy(quant_cfg_choices[args.qformat]) @@ - if "awq" in args.qformat: - quant_cfg = copy.deepcopy(quant_cfg_choices[args.qformat]) + if "awq" in args.qformat: weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] if isinstance(weight_quantizer, list): - weight_quantizer = weight_quantizer[0] + weight_quantizer = weight_quantizer[0] # If awq_block_size argument is provided, update weight_quantizer - if args.awq_block_size: - weight_quantizer["block_sizes"][-1] = args.awq_block_size + awq_bs = getattr(args, "awq_block_size", None) + if awq_bs: + bs = weight_quantizer.get("block_sizes") + if isinstance(bs, dict) and (-1 in bs): + bs[-1] = awq_bs + else: + # Leave unchanged if structure is absent/mismatched + pass @@ - enable_quant_kv_cache = args.kv_cache_qformat != "none" - print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") + enable_quant_kv_cache = getattr(args, "kv_cache_qformat", "none") != "none" + print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") @@ - if enable_quant_kv_cache: - quant_cfg = apply_kv_cache_quant( - quant_cfg, - getattr(mtq, kv_quant_cfg_choices[args.kv_cache_qformat])["quant_cfg"], - ) + if enable_quant_kv_cache: + kv_key = args.kv_cache_qformat + assert kv_key in kv_quant_cfg_choices, f"Unsupported KV cache quant format: {kv_key}" + kv_attr = kv_quant_cfg_choices[kv_key] + kv_cfg_holder = getattr(mtq, kv_attr, None) + if kv_cfg_holder is None or "quant_cfg" not in kv_cfg_holder: + raise ValueError(f"KV cache quant cfg '{kv_attr}' not found in mtq or missing 'quant_cfg'") + quant_cfg = apply_kv_cache_quant(quant_cfg, kv_cfg_holder["quant_cfg"]) @@ - if model_type == "phi4mm": + if model_type == "phi4mm": # Only quantize the language model quant_cfg["quant_cfg"]["*speech*"] = {"enable": False} quant_cfg["quant_cfg"]["*audio*"] = {"enable": False} quant_cfg["quant_cfg"]["*image*"] = {"enable": False} quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} - - return quant_cfg + else: + # Auto-quant path: still allow KV-cache quantization and per-model toggles + enable_quant_kv_cache = getattr(args, "kv_cache_qformat", "none") != "none" + print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") + if enable_quant_kv_cache: + kv_key = args.kv_cache_qformat + assert kv_key in kv_quant_cfg_choices, f"Unsupported KV cache quant format: {kv_key}" + kv_attr = kv_quant_cfg_choices[kv_key] + kv_cfg_holder = getattr(mtq, kv_attr, None) + if kv_cfg_holder is None or "quant_cfg" not in kv_cfg_holder: + raise ValueError(f"KV cache quant cfg '{kv_attr}' not found in mtq or missing 'quant_cfg'") + quant_cfg = apply_kv_cache_quant(quant_cfg, kv_cfg_holder["quant_cfg"]) + if model_type == "phi4mm": + quant_cfg.setdefault("quant_cfg", {"default": {"enable": False}}) + quant_cfg["quant_cfg"]["*speech*"] = {"enable": False} + quant_cfg["quant_cfg"]["*audio*"] = {"enable": False} + 
quant_cfg["quant_cfg"]["*image*"] = {"enable": False} + quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} + return quant_cfgRun to validate assumptions and spot potential breakages:
#!/bin/bash set -euo pipefail echo "Call sites of build_quant_cfg:" rg -n -C2 '\bbuild_quant_cfg\s*\(' examples/llm_ptq || true echo echo "Where auto_quantize_bits is referenced (ensure CLI plumbs it):" rg -n -C2 '\bauto_quantize_bits\b' || true echo echo "KV cfg mapping definitions near callers:" rg -n -C3 'KV_QUANT_CFG_CHOICES\s*=' examples/llm_ptq || true echo echo "AWQ configs mentioning block_sizes (to sanity-check structure):" rg -n -C1 'block_sizes' examples/llm_ptq || true
🧹 Nitpick comments (1)
examples/llm_ptq/example_utils.py (1)
64-66: Prefer logging over print for user-facing toggles.
Switch to a module logger to control verbosity across multi-node runs.
+import logging
+logger = logging.getLogger(__name__)
@@
-    print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
+    logger.info("%s KV cache quantization", "Enable" if enable_quant_kv_cache else "Disable")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/llm_ptq/example_utils.py(2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/llm_ptq/example_utils.py (1)
modelopt/torch/export/model_config.py (1)
awq_block_size(289-294)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
🔇 Additional comments (2)
examples/llm_ptq/example_utils.py (2)
16-16: LGTM on import usage. copy is used for AWQ deep-copy (and should be used more broadly per fix below).
36-36: LGTM on mtq import. Explicit aliasing (mtq) reads clearly and matches usage.
examples/llm_ptq/example_utils.py
Outdated
SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"]

def build_quant_cfg(args, model_type, quant_cfg_choices, kv_quant_cfg_choices):
can we avoid a blank args? Please explicitly list the args out
    mtq.NVFP4_DEFAULT_CFG,
    mtq.NVFP4_AWQ_LITE_CFG,
    mtq.W4A8_AWQ_BETA_CFG,
    # mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
please add a TODO, same below
grouped_experts = get_experts_list(module, model_type)
for modules in grouped_experts:
    preprocess_linear_fusion(modules, resmooth_only=True)
    with fsdp2_aware_weight_update(model, modules):
Say if we are going to add a new processing. And user forget to add this fsdp2_aware_weight_update line. What will be the error that the user will see?
There wouldn't be an error. But the weight changes will be lost once the module is resharded.
FSDP2 stores module weights as FSDPParams, each of which keeps track of both the sharded DTensor and the full (unsharded) tensor. When we update module.weight, we are only touching the full tensor held by the FSDPParam, so the change never reaches the shards unless the FSDPParam is rebuilt.
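To make the discussion concrete, a minimal sketch of the intended pattern (illustrative variable names; the context manager is the one added in modelopt/torch/quantization/utils.py and behaves as described above):

```python
import torch
from modelopt.torch.quantization.utils import fsdp2_aware_weight_update

def update_linear_weight(root_model, linear, new_weight: torch.Tensor):
    # Without the context manager, this copy would only touch the temporary full
    # tensor and be silently dropped at the next reshard:
    # linear.weight.data.copy_(new_weight)

    # Inside the context, the module is unsharded, the update is applied, and the
    # FSDPParam mapping is rebuilt so the change survives resharding.
    with fsdp2_aware_weight_update(root_model, linear):
        linear.weight.data.copy_(new_weight)
```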
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/export/unified_export_hf.py (1)
526-556: Expose accelerator in public API and pass through
Allow callers to use the stable public API instead of importing the private helper.
-def export_hf_checkpoint(
-    model: nn.Module,
-    dtype: torch.dtype | None = None,
-    export_dir: Path | str = tempfile.gettempdir(),
-    save_modelopt_state: bool = False,
-):
+def export_hf_checkpoint(
+    model: nn.Module,
+    dtype: torch.dtype | None = None,
+    export_dir: Path | str = tempfile.gettempdir(),
+    save_modelopt_state: bool = False,
+    *,
+    accelerator: Any | None = None,
+):
@@
-    post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
+    post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype, accelerator=accelerator)
I can propagate this change to the multinode script to stop using the private function; say the word.
♻️ Duplicate comments (3)
tests/gpu/torch/export/test_fsdp2_export.py (1)
217-233: Auto‑skip on insufficient GPUs
Previous feedback suggested a fixture for skipping when <2 GPUs. If available, please adopt it for these multiprocess tests to avoid spurious CI failures.
If you want, I can wire a need_2_gpus fixture and parametrize over it; a minimal sketch follows.
Also applies to: 249-256, 272-278
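A minimal version of such a guard (sketch only; the repo's actual fixture name and location may differ):

```python
# conftest.py
import pytest
import torch

@pytest.fixture
def need_2_gpus():
    # Skip multiprocess FSDP2 tests on machines with fewer than 2 visible GPUs.
    if torch.cuda.device_count() < 2:
        pytest.skip("Requires at least 2 GPUs")

# Usage in a test module:
# def test_fsdp2_weight_update(need_2_gpus):
#     ...
```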
examples/llm_ptq/example_utils.py (1)
57-61: Avoid mutating shared quant configs; deep‑copy alwaysAssigning
quant_cfg = quant_cfg_choices[qformat]and then mutating it (KV cache, AWQ tweaks) will corrupt the shared config objects. Always deep‑copy regardless of format.- quant_cfg = quant_cfg_choices[qformat] + # Always work on a private copy to avoid mutating shared constants + quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])Also remove the redundant deep‑copy in the AWQ branch below (kept by the above change).
modelopt/torch/export/unified_export_hf.py (1)
509-514: Do not silently export incomplete state_dict when FSDP is detected
If any FSDP modules were seen, exporting without an Accelerator will miss sharded params. Raise a clear error instead of silently falling back.
Also reshard the last FSDP module before gathering.
-    if accelerator is not None:
-        # Gather state_dict from all ranks
-        quantized_state_dict = accelerator.get_state_dict(model)
-    else:
-        quantized_state_dict = model.state_dict()
+    if accelerator is not None:
+        quantized_state_dict = accelerator.get_state_dict(model)
+    else:
+        if fsdp_module_to_reshard is not None:
+            raise ImportError(
+                "accelerate must be installed and an Accelerator instance provided "
+                "to export from an FSDP2-wrapped model."
+            )
+        quantized_state_dict = model.state_dict()
🧹 Nitpick comments (8)
examples/llm_ptq/README.md (2)
244-247: Hyphenation nit: “user‑specific”
Use a hyphen: “can be customized for user‑specific requirements.”
248-264: Keep README options in sync with the script
Consider either:
- Listing all supported qformats from the script (int8, int4_awq, fp8, nvfp4, nvfp4_awq, w4a8_awq, fp8_pb_wo, w4a8_mxfp4_fp8, nvfp4_mlp_only), or
- Add “see script for full list of supported qformats.”
Also consider showing the optional flag present in the script:
- Append “--awq_block_size ” when demonstrating AWQ.
Would you like me to open a small PR update to expand the options and add --awq_block_size here?
tests/gpu/torch/export/test_fsdp2_export.py (2)
163-165: Remove redundant local import
copy is already imported at module scope; drop the inner import.
-    import copy
-    from torch.distributed._composable.fsdp import fully_shard
200-209: Speed up test by avoiding per‑layer reshard during export
Use reshard=False in the context to match the production path optimization and reduce test runtime.
-            with fsdp2_aware_weight_update(model, sub_module):
+            with fsdp2_aware_weight_update(model, sub_module, reshard=False):
                 _export_quantized_weight(sub_module, torch.float16)
examples/llm_ptq/example_utils.py (3)
59-67: Harden AWQ block_size edit
Guard for expected structure to avoid KeyError on configs without [-1] in block_sizes.
- if "awq" in qformat: - quant_cfg = copy.deepcopy(quant_cfg_choices[qformat]) + if "awq" in qformat: weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] if isinstance(weight_quantizer, list): weight_quantizer = weight_quantizer[0] # If awq_block_size argument is provided, update weight_quantizer if awq_block_size: - weight_quantizer["block_sizes"][-1] = awq_block_size + bs = weight_quantizer.get("block_sizes", {}) + if isinstance(bs, dict) and (-1 in bs or not bs): + bs[-1] = awq_block_size + weight_quantizer["block_sizes"] = bs
82-85: Condition will never match here; scope correctly
This branch checks "int8_sq" in qformat, but qformat is a single token and QUANT_CFG_CHOICES in the multinode flow doesn’t include int8_sq. Either:
- Move this tweak to the place where int8_sq is actually supported, or
- Guard by choices: if model_type == "gemma" and qformat == "int8_sq" and "int8_sq" in quant_cfg_choices: ...
-    if model_type == "gemma" and "int8_sq" in qformat:
+    if model_type == "gemma" and qformat == "int8_sq" and "int8_sq" in quant_cfg_choices:
         quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
51-93: Optional: handle auto‑quant path explicitly
When auto_quantize is True, this returns {}. If build_quant_cfg is reused by auto‑quant users, consider returning a minimal config and applying KV‑cache quant via apply_kv_cache_quant similarly to the non‑auto path.
I can draft the small else branch if you plan to reuse this for auto‑quant flows.
examples/llm_ptq/multinode_ptq.py (1)
319-326: Pass explicit boolean for auto_quantize flag
Use False instead of None for clarity.
-        None,
+        False,
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
examples/llm_ptq/README.md(1 hunks)examples/llm_ptq/example_utils.py(2 hunks)examples/llm_ptq/hf_ptq.py(2 hunks)examples/llm_ptq/multinode_ptq.py(1 hunks)modelopt/torch/export/unified_export_hf.py(10 hunks)tests/gpu/torch/export/test_fsdp2_export.py(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/llm_ptq/hf_ptq.py
🧰 Additional context used
🧬 Code graph analysis (4)
modelopt/torch/export/unified_export_hf.py (3)
modelopt/torch/quantization/utils.py (2)
fsdp2_aware_weight_update(603-699)quantizer_attr_names(231-242)modelopt/torch/export/quant_utils.py (3)
preprocess_linear_fusion(949-1019)fuse_prequant_layernorm(935-946)get_quantization_format(430-531)modelopt/torch/export/layer_utils.py (1)
is_quantlinear(346-348)
tests/gpu/torch/export/test_fsdp2_export.py (4)
tests/_test_utils/torch_export/export_utils.py (2)
SmallQKVModel(57-83)ToyModel(20-33)modelopt/torch/export/layer_utils.py (1)
is_quantlinear(346-348)modelopt/torch/export/unified_export_hf.py (2)
_export_quantized_weight(205-343)requantize_resmooth_fused_llm_layers(88-202)modelopt/torch/quantization/utils.py (2)
fsdp2_aware_weight_update(603-699)patch_fsdp_mp_dtypes(485-527)
examples/llm_ptq/example_utils.py (3)
modelopt/torch/export/model_config.py (1)
awq_block_size(289-294)examples/llm_ptq/hf_ptq.py (1)
auto_quantize(94-152)modelopt/torch/quantization/model_quant.py (1)
auto_quantize(234-450)
examples/llm_ptq/multinode_ptq.py (8)
examples/llm_ptq/example_utils.py (2)
build_quant_cfg(42-93)get_tokenizer(103-120)modelopt/torch/export/convert_hf_config.py (1)
convert_hf_quant_config_format(21-117)modelopt/torch/export/unified_export_hf.py (1)
_export_hf_checkpoint(346-523)modelopt/torch/quantization/config.py (1)
need_calibration(1164-1190)modelopt/torch/quantization/utils.py (1)
patch_fsdp_mp_dtypes(485-527)modelopt/torch/utils/dataset_utils.py (1)
get_dataset_dataloader(171-246)modelopt/torch/opt/plugins/huggingface.py (1)
enable_huggingface_checkpointing(127-162)modelopt/torch/quantization/model_quant.py (1)
print_quant_summary(463-470)
🪛 LanguageTool
examples/llm_ptq/README.md
[grammar] ~244-~244: Use a hyphen to join words.
Context: ... provided and can be customized for user specific requirements. On each node run...
(QB_NEW_EN_HYPHEN)
from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
from modelopt.torch.quantization.config import need_calibration
🛠️ Refactor suggestion | 🟠 Major
Prefer public export API over private helper
After exposing accelerator in export_hf_checkpoint, switch to the public API.
-from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
+from modelopt.torch.export.unified_export_hf import export_hf_checkpoint
And below (see next comment) call the public function.
🤖 Prompt for AI Agents
In examples/llm_ptq/multinode_ptq.py around lines 24-25, the code imports the
private helper _export_hf_checkpoint; change this to import the public
export_hf_checkpoint from modelopt.torch.export.unified_export_hf and update any
subsequent calls to use the public name. Ensure the import line uses the public
function and that later code calls export_hf_checkpoint with the same arguments
(including accelerator) instead of the underscored helper.
post_state_dict, hf_quant_config = _export_hf_checkpoint(
    model, torch.bfloat16, accelerator=accelerator
)
🛠️ Refactor suggestion | 🟠 Major
Call the public API with accelerator
Use export_hf_checkpoint(..., accelerator=accelerator) once the public function accepts it.
- post_state_dict, hf_quant_config = _export_hf_checkpoint(
- model, torch.bfloat16, accelerator=accelerator
- )
+ post_state_dict, hf_quant_config = export_hf_checkpoint(
+ model, torch.bfloat16, export_dir=export_dir, save_modelopt_state=False, accelerator=accelerator
+ )
+    # Note: if you let export_hf_checkpoint write files directly, you can simplify the save steps below.
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In examples/llm_ptq/multinode_ptq.py around lines 233 to 235, the code calls the
internal helper _export_hf_checkpoint(..., accelerator=accelerator); replace
this with the public API export_hf_checkpoint(..., accelerator=accelerator) once
that function accepts the accelerator parameter, ensure the public function is
imported at the top of the file (remove or stop using the underscored helper),
pass the same arguments including accelerator=accelerator, and update any
references or tests that expect the internal function if necessary.
from torch.distributed.fsdp import FSDPModule
Make FSDP import robust across torch versions
Guard the import and avoid hard failures on environments without composable FSDP. Also gate isinstance checks.
-from torch.distributed.fsdp import FSDPModule
+try:
+ # Prefer composable FSDP module type when available
+ from torch.distributed._composable.fsdp import FSDPModule as _FSDPModule # PyTorch ≥2.3
+except Exception: # pragma: no cover
+    _FSDPModule = None
And below, use _FSDPModule guarded checks (see further diff).
Committable suggestion skipped: line range outside the PR's diff.
for name, sub_module in layer_pool.items():
    # Optimization to perform resharding only once per decoder layer to avoid extra communication overhead
    if isinstance(sub_module, FSDPModule):
        # Every time we encounter a new FSDPModule, the previous decoder layer is fully processed.
        # We need to reshard the previous FSDPModule to prevent potential OOM.
        # This hack reduces the number of unshard reshard operations, to avoid unnecessary communication.
        if fsdp_module_to_reshard is not None:
            fsdp_module_to_reshard.reshard()

        fsdp_module_to_reshard = sub_module
Guard isinstance check to avoid NameError when FSDP is unavailable
Use the guarded _FSDPModule from the import refactor.
- if isinstance(sub_module, FSDPModule):
+ if _FSDPModule is not None and isinstance(sub_module, _FSDPModule):
# Every time we encounter a new FSDPModule, the previous decoder layer is fully processed.
# We need to reshard the previous FSDPModule to prevent potential OOM.
# This hack reduces the number of unshard reshard operations, to avoid unnecessary communication.
if fsdp_module_to_reshard is not None:
fsdp_module_to_reshard.reshard()
        fsdp_module_to_reshard = sub_module
🤖 Prompt for AI Agents
In modelopt/torch/export/unified_export_hf.py around lines 473 to 483, the
isinstance check uses FSDPModule which can raise NameError when FSDP isn't
available; replace that check to use the guarded import _FSDPModule (the
refactored alias) so the code safely skips FSDP-specific logic when absent, and
ensure the module imports the _FSDPModule symbol instead of FSDPModule or falls
back to None before the loop.
d51b652 to 3b80298 (Compare)
Actionable comments posted: 1
♻️ Duplicate comments (8)
tests/gpu/torch/export/test_export.py (1)
378-390: Remove debug print statement.
The debug print on line 383 was flagged in a previous review as unnecessary.
Apply this diff:
 for name, module in model.named_modules():
     if isinstance(module, TensorQuantizer) and module.is_enabled:
         scale = get_scaling_factor(module)
-        print(f"DEBUG LOG: Scale: {scale}, Expected: {expected_amax[0] / maxbound}")
         assert torch.allclose(
             scale,
             torch.tensor((expected_amax[0] / maxbound), dtype=scale.dtype),
             rtol=1e-3,
             atol=1e-3,
         )
         expected_amax.pop(0)
modelopt/torch/quantization/utils.py (1)
564-582: Document thread-safety concerns and the reason for this global patch.The
no_requires_gradcontext manager globally patchestorch.nn.Parameter.__new__, which is invasive and not thread-safe. This should be clearly documented in the docstring to warn users.Based on past review feedback, enhance the docstring to warn about the global nature and thread-safety:
@contextmanager def no_requires_grad(): - """Context manager to temporarily set requires_grad to False. + """Context manager to globally patch Parameter creation to set requires_grad=False. - This is used to allow us to call init_sharded_parameter() on the compressed weights. Currently FSDP2 creates - a new parameter with default requires_grad and then update the requires_grad attribute as needed. This - triggers an error when torch.nn.Parameter is called on compressed weights as requires_grad cannot be set to True - for integer tensors. + **WARNING**: This patches `torch.nn.Parameter.__new__` globally and is NOT thread-safe. + Do not use in multi-threaded contexts or when other code might be creating Parameters concurrently. + + This workaround allows calling init_sharded_parameter() on compressed integer weights. + FSDP2 creates Parameters with requires_grad=True by default, which errors for integer dtypes. + The patch ignores the requires_grad argument and forces it to False. The original __new__ is + restored in the finally block. """examples/llm_ptq/example_utils.py (1)
42-93: Critical: Fix implicit None return and prevent shared config mutation.The function has two critical issues flagged in past reviews:
Implicit None return: When
auto_quantize=True, the function never returns (no code path after line 52's if block), yielding None to callers who then invokeneed_calibration(None)causing crashes.Shared config mutation: Line 57 uses direct assignment instead of deep copy for non-AWQ formats. Subsequent modifications (lines 77-91) then mutate the shared
quant_cfg_choicesdictionary, polluting state across multiple calls.Apply this fix:
def build_quant_cfg( qformat, kv_cache_qformat, awq_block_size, auto_quantize, model_type, quant_cfg_choices, kv_quant_cfg_choices, ): quant_cfg = {} if not auto_quantize: assert qformat in quant_cfg_choices, ( f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache" ) - quant_cfg = quant_cfg_choices[qformat] + # Always work on a private copy to avoid mutating shared config + quant_cfg = copy.deepcopy(quant_cfg_choices[qformat]) if "awq" in qformat: - quant_cfg = copy.deepcopy(quant_cfg_choices[qformat]) weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] if isinstance(weight_quantizer, list): weight_quantizer = weight_quantizer[0] # If awq_block_size argument is provided, update weight_quantizer if awq_block_size: weight_quantizer["block_sizes"][-1] = awq_block_size # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]: quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1} enable_quant_kv_cache = kv_cache_qformat != "none" print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer. if enable_quant_kv_cache: quant_cfg = apply_kv_cache_quant( quant_cfg, getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"], ) # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead. if model_type == "gemma" and "int8_sq" in qformat: quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5} if model_type == "phi4mm": # Only quantize the language model quant_cfg["quant_cfg"]["*speech*"] = {"enable": False} quant_cfg["quant_cfg"]["*audio*"] = {"enable": False} quant_cfg["quant_cfg"]["*image*"] = {"enable": False} quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} - - return quant_cfg + + return quant_cfg + else: + # Auto-quantize path: return empty config or apply model-specific settings + # The actual format selection happens in auto_quantize() + if model_type == "phi4mm": + quant_cfg["quant_cfg"] = { + "default": {"enable": False}, + "*speech*": {"enable": False}, + "*audio*": {"enable": False}, + "*image*": {"enable": False}, + "*vision*": {"enable": False}, + } + return quant_cfgtests/gpu/torch/export/test_fsdp2_export.py (1)
217-233: Use a fixture to skip when <2 GPUs (already noted earlier)Adopt the repo’s multi‑GPU skip fixture to avoid spurious CI failures.
Would you like me to patch this to use a
need_2_gpusfixture and parametrize accordingly?examples/llm_ptq/multinode_ptq.py (2)
24-25: Avoid private export helper; prefer public API once accelerator is supportedUse export_hf_checkpoint(...) and pass accelerator when the public function accepts it; simplifies save flow.
-from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint +from modelopt.torch.export.unified_export_hf import export_hf_checkpoint @@ - post_state_dict, hf_quant_config = _export_hf_checkpoint( - model, torch.bfloat16, accelerator=accelerator - ) + post_state_dict, hf_quant_config = export_hf_checkpoint( + model, torch.bfloat16, export_dir=export_dir, save_modelopt_state=False, accelerator=accelerator + ) + # If export_hf_checkpoint writes files directly, you can drop the manual save steps below.Also applies to: 233-235
60-76: Verify README/options consistency for qformat/KV‑cache flagsEnsure README uses the same flag names and values (e.g., --kv_cache_qformat vs any aliases like --kv_cache_quant; presence of int8_sq if documented).
I can update either the README or CLI to add aliases/choices once you confirm intended names.
modelopt/torch/export/unified_export_hf.py (2)
509-514: Don’t silently fallback when exporting FSDP models without AcceleratorFalling back to model.state_dict() produces incomplete state for sharded models. Error out if FSDP is detected but no accelerator is provided.
Apply:
- if accelerator is not None: - # Gather state_dict from all ranks - quantized_state_dict = accelerator.get_state_dict(model) - else: - quantized_state_dict = model.state_dict() + if accelerator is not None: + quantized_state_dict = accelerator.get_state_dict(model) + else: + if fsdp_module_to_reshard is not None: + raise ImportError( + "accelerate must be installed and an Accelerator instance provided " + "to export from an FSDP2-wrapped model." + ) + quantized_state_dict = model.state_dict()
473-483: Reshard the last FSDP module after the loopThe final FSDP module remains unsharded and holds extra memory. Add a trailing reshard.
for name, sub_module in layer_pool.items(): ... if _FSDPModule is not None and isinstance(sub_module, _FSDPModule): ... fsdp_module_to_reshard = sub_module - if accelerator is not None: + # Reshard the last FSDP module, if any + if fsdp_module_to_reshard is not None: + fsdp_module_to_reshard.reshard() + + if accelerator is not None: ...Also applies to: 509-513
🧹 Nitpick comments (7)
modelopt/torch/quantization/utils.py (1)
484-527: Enhance documentation for the FSDP2 dtype patch.While the docstring mentions this is copied from the latest torch FSDP repository, it would benefit from additional context explaining:
- Why this patch is necessary (relaxing uniform dtype requirement during quantization)
- The scope/lifetime of the patch (applied during context, restored after)
- Any potential side effects or limitations
Based on learnings from past reviews.
Consider enhancing the docstring:
@contextmanager def patch_fsdp_mp_dtypes(): """Patch FSDP2 to handle mixed dtypes properly during quantization. This patch is used to relax the requirement of uniform original parameter dtype in FSDP2 and is copied from the latest torch FSDP repository `torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py <https://github.com/pytorch/pytorch/blob/c40048472cc4e28f44e8e835cae319add231bf5/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py#L227>`_. + + This is necessary because during quantization workflows with FSDP2, parameters may temporarily + have non-uniform dtypes (e.g., mixing float and quantized integer tensors). The original FSDP2 + implementation enforces uniform dtypes, which would cause assertion errors. This context manager + temporarily replaces the _init_mp_dtypes method to allow non-uniform dtypes, then restores the + original implementation when exiting the context. """examples/llm_ptq/README.md (1)
238-269: Comprehensive multi-node FSDP2 documentation with minor grammar fix.The documentation clearly explains the FSDP2 multi-node setup, command usage, and deployment options. One minor grammar improvement on line 244.
Apply this diff to fix the grammar:
-For distributed execution across multiple nodes, use the `accelerate` library. A template configuration file (`fsdp2.yaml`) is provided and can be customized for user specific requirements. +For distributed execution across multiple nodes, use the `accelerate` library. A template configuration file (`fsdp2.yaml`) is provided and can be customized for user-specific requirements.tests/_test_utils/torch_export/export_utils.py (1)
57-83: Document the TODO or file an issue for FSDP2 layernorm bias issue.The TODO comment on line 68 mentions that FSDP2 modifies the bias of layernorm for AWQ. This should either be:
- Documented with more context (is this expected behavior or a bug?)
- Tracked in a separate issue
Do you want me to help create an issue to track this FSDP2/AWQ layernorm bias behavior, or should the TODO be expanded with more context about whether this is expected?
modelopt/torch/quantization/qtensor/base_qtensor.py (2)
22-22: Make FSDP import/version handling robustComposable FSDP lives under torch.distributed._composable.fsdp in recent PyTorch. Guard imports and isinstance checks to avoid hard failures on older/newer torch.
Apply:
-from torch.distributed.fsdp import FSDPModule, fully_shard +try: + # Preferred: composable FSDP (PyTorch ≥2.3) + from torch.distributed._composable.fsdp import FSDPModule as _FSDPModule, fully_shard +except Exception: # pragma: no cover + _FSDPModule = NoneAnd below:
- if isinstance(m, FSDPModule): + if _FSDPModule is not None and isinstance(m, _FSDPModule): _compress_fsdp_module(m)Also applies to: 257-259
235-236: Local import OK; add brief comment for rationaleLazy-importing utils avoids import cycles; add a one‑liner comment for future readers.
- from modelopt.torch.quantization.utils import enable_fake_quant, fsdp2_aware_weight_update + # Lazy import to avoid circulars and keep FSDP2 helpers isolated + from modelopt.torch.quantization.utils import enable_fake_quant, fsdp2_aware_weight_updatemodelopt/torch/export/unified_export_hf.py (1)
30-31: Guard FSDP import across torch versionsAvoid hard dependency on composable FSDP symbol location.
-from torch.distributed.fsdp import FSDPModule +try: + from torch.distributed._composable.fsdp import FSDPModule as _FSDPModule +except Exception: # pragma: no cover + _FSDPModule = NoneAnd where used:
- if isinstance(sub_module, FSDPModule): + if _FSDPModule is not None and isinstance(sub_module, _FSDPModule):tests/gpu/torch/export/test_fsdp2_export.py (1)
163-168: Remove duplicate local importRedundant
import copyinside function; already imported at the top.- import copy - from torch.distributed._composable.fsdp import fully_shard
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
CHANGELOG.rst(1 hunks)examples/llm_ptq/README.md(1 hunks)examples/llm_ptq/example_utils.py(2 hunks)examples/llm_ptq/fsdp2.yaml(1 hunks)examples/llm_ptq/hf_ptq.py(2 hunks)examples/llm_ptq/multinode_ptq.py(1 hunks)modelopt/torch/export/unified_export_hf.py(10 hunks)modelopt/torch/quantization/qtensor/base_qtensor.py(3 hunks)modelopt/torch/quantization/qtensor/nvfp4_tensor.py(1 hunks)modelopt/torch/quantization/utils.py(3 hunks)tests/_test_utils/torch_export/export_utils.py(2 hunks)tests/gpu/torch/export/test_export.py(3 hunks)tests/gpu/torch/export/test_fsdp2_export.py(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/quantization/qtensor/nvfp4_tensor.py
🧰 Additional context used
🧬 Code graph analysis (9)
tests/gpu/torch/export/test_export.py (3)
tests/_test_utils/torch_export/export_utils.py (1)
SmallLinearModelwithCustomWeight(36-54)modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
TensorQuantizer(65-1111)is_enabled(395-397)maxbound(193-199)modelopt/torch/export/quant_utils.py (1)
get_scaling_factor(212-229)
examples/llm_ptq/hf_ptq.py (1)
examples/llm_ptq/example_utils.py (1)
build_quant_cfg(42-93)
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
modelopt/torch/quantization/utils.py (2)
enable_fake_quant(586-599)fsdp2_aware_weight_update(603-699)
modelopt/torch/quantization/utils.py (2)
modelopt/torch/utils/network.py (1)
get_unwrapped_name(599-612)modelopt/torch/quantization/qtensor/base_qtensor.py (2)
QFSDPParam(139-159)QTensorWrapper(87-136)
tests/_test_utils/torch_export/export_utils.py (2)
modelopt/torch/export/model_config.py (1)
bias(153-163)modelopt/torch/quantization/backends/nvfp4_gemm.py (1)
forward(135-150)
examples/llm_ptq/example_utils.py (3)
modelopt/torch/export/model_config.py (1)
awq_block_size(289-294)examples/llm_ptq/hf_ptq.py (1)
auto_quantize(94-152)modelopt/torch/quantization/model_quant.py (1)
auto_quantize(234-450)
modelopt/torch/export/unified_export_hf.py (3)
modelopt/torch/quantization/utils.py (2)
fsdp2_aware_weight_update(603-699)quantizer_attr_names(231-242)modelopt/torch/export/layer_utils.py (1)
is_quantlinear(346-348)modelopt/torch/export/unified_export_megatron.py (1)
state_dict(465-469)
examples/llm_ptq/multinode_ptq.py (8)
examples/llm_ptq/example_utils.py (2)
build_quant_cfg(42-93)get_tokenizer(103-120)modelopt/torch/export/convert_hf_config.py (1)
convert_hf_quant_config_format(21-117)modelopt/torch/export/unified_export_hf.py (1)
_export_hf_checkpoint(346-523)modelopt/torch/quantization/config.py (1)
need_calibration(1164-1190)modelopt/torch/quantization/utils.py (1)
patch_fsdp_mp_dtypes(485-527)modelopt/torch/utils/dataset_utils.py (1)
get_dataset_dataloader(171-246)modelopt/torch/opt/plugins/huggingface.py (1)
enable_huggingface_checkpointing(127-162)modelopt/torch/quantization/model_quant.py (1)
print_quant_summary(463-470)
tests/gpu/torch/export/test_fsdp2_export.py (4)
tests/_test_utils/torch_export/export_utils.py (2)
SmallQKVModel(57-83)ToyModel(20-33)modelopt/torch/export/layer_utils.py (1)
is_quantlinear(346-348)modelopt/torch/export/unified_export_hf.py (2)
_export_quantized_weight(205-343)requantize_resmooth_fused_llm_layers(88-202)modelopt/torch/quantization/utils.py (2)
fsdp2_aware_weight_update(603-699)patch_fsdp_mp_dtypes(485-527)
🪛 LanguageTool
examples/llm_ptq/README.md
[grammar] ~244-~244: Use a hyphen to join words.
Context: ... provided and can be customized for user specific requirements. On each node run...
(QB_NEW_EN_HYPHEN)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (20)
examples/llm_ptq/fsdp2.yaml (1)
1-30: LGTM!The FSDP2 configuration is well-structured and clearly documented. The comment on line 2 appropriately addresses the file's usage, and the settings align with the multi-node PTQ workflow described in the PR objectives.
modelopt/torch/quantization/utils.py (6)
18-35: LGTM!The addition of
TYPE_CHECKING,Generator, and__future__annotations improves type hinting and code organization.
366-369: LGTM!The
_get_module_namehelper function provides a clean way to locate a module's qualified name within the root model.
530-545: LGTM!The
get_prefixed_param_namesfunction correctly derives full parameter names by checking parameter identity against the parent model.
548-561: LGTM!The
create_fsdp_param_mappingfunction provides a clean mapping from module names to FSDPParam objects.
585-599: LGTM!The
enable_fake_quantcontext manager correctly toggles the_fake_quantattribute on weight quantizers.
602-699: LGTM with excellent documentation!The
fsdp2_aware_weight_updatecontext manager is well-documented with clear parameter descriptions and behavior explanation. The implementation correctly handles unsharding, FSDPParam updates, and optional resharding.CHANGELOG.rst (1)
15-15: LGTM!The changelog entry accurately documents the new FSDP2 multi-node PTQ support and provides a helpful link to the README for details.
examples/llm_ptq/hf_ptq.py (2)
27-27: LGTM!The import of
build_quant_cfgis appropriate for the centralized config building approach.
451-459: Excellent refactoring to centralize config building.Replacing the inline quant_cfg construction with the
build_quant_cfgfunction call improves maintainability and reduces code duplication across PTQ scripts.tests/gpu/torch/export/test_export.py (2)
19-19: LGTM!The import update reflects the model refactoring in export_utils.py.
309-309: LGTM!The model instantiation correctly uses the renamed
SmallLinearModelwithCustomWeightclass.tests/_test_utils/torch_export/export_utils.py (2)
21-30: LGTM!Adding the
biasparameter toToyModelprovides useful testing flexibility.
36-54: LGTM!The renaming to
SmallLinearModelwithCustomWeightmakes the purpose of this test model clearer.examples/llm_ptq/example_utils.py (2)
16-16: LGTM!The
copyimport is necessary for deep-copying configuration dictionaries.
36-36: LGTM!The import of
modelopt.torch.quantization as mtqis necessary for accessing quantization configurations.modelopt/torch/export/unified_export_hf.py (1)
118-201: Wrap fusions in minimal scopes; tiny doc tweakThe fsdp2_aware_weight_update usage here looks correct. Add a short comment why reshard=False is used to reduce comms.
If desired, I can add a brief code comment summarizing the unshard-once/reshard-once rationale across this block.
tests/gpu/torch/export/test_fsdp2_export.py (1)
195-215: Good parity checks between sharded and non‑sharded pathsState/buffer comparisons are robust (bf16). Nice coverage.
examples/llm_ptq/multinode_ptq.py (2)
182-215: Calibration loop uses outer FSDP model intentionallyUsing the wrapped model in the closure is correct for DTensor paths.
279-287: Top‑level Accelerator import is acceptable for this exampleGiven FSDP2 is required here, the hard dependency is fine.
for _, submodule in fsdp_module.named_modules():
    with fsdp2_aware_weight_update(fsdp_module, submodule):
        _compress_and_update_module_weight(submodule)
Don’t wrap every submodule; filter and reshard once
Calling fsdp2_aware_weight_update on all named_modules() will assert for non‑mapped modules and causes repeated un/reshard. Filter to eligible modules, batch them in one context with reshard=False, then reshard once.
Apply:
- for _, submodule in fsdp_module.named_modules():
- with fsdp2_aware_weight_update(fsdp_module, submodule):
- _compress_and_update_module_weight(submodule)
+ # Compress only eligible submodules and reshard once at the end
+ eligible: list[torch.nn.Module] = []
+ for _, submodule in fsdp_module.named_modules():
+ if (
+ hasattr(submodule, "weight")
+ and submodule.weight is not None
+ and not getattr(submodule.weight, "is_meta", False)
+ and hasattr(submodule, "weight_quantizer")
+ and getattr(submodule.weight_quantizer, "is_enabled", False)
+ and not getattr(submodule.weight_quantizer, "_fake_quant", False)
+ and submodule.weight.element_size() > 1
+ ):
+ eligible.append(submodule)
+
+ if not eligible:
+ fsdp_module.reshard()
+ return
+
+ with fsdp2_aware_weight_update(fsdp_module, eligible, reshard=False):
+ for submodule in eligible:
+ _compress_and_update_module_weight(submodule)
+
+    fsdp_module.reshard()
Committable suggestion skipped: line range outside the PR's diff.
Signed-off-by: Suguna Velury <[email protected]>
54fd2b5 to b82f95b (Compare)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/export/unified_export_hf.py (1)
473-508: Reshard the final FSDP module to avoid memory leak.The loop reshards
fsdp_module_to_reshardonly when a newFSDPModuleis encountered (lines 479-480), leaving the final module unresharded after the loop completes. This keeps the full unsharded weights in memory.Apply this diff after line 507:
with fsdp2_aware_weight_update(model, sub_module, reshard=False): _export_quantized_weight(sub_module, dtype, weight_name) + # Reshard the last FSDP module if any + if fsdp_module_to_reshard is not None: + fsdp_module_to_reshard.reshard() + if accelerator is not None:
♻️ Duplicate comments (3)
modelopt/torch/export/unified_export_hf.py (1)
509-513: Guard against FSDP export without accelerator.When FSDP modules are detected (
fsdp_module_to_reshard is not None) butacceleratorisNone, the code falls back tomodel.state_dict()which cannot properly gather sharded state across ranks, producing incomplete results.Apply this diff:
if accelerator is not None: # Gather state_dict from all ranks quantized_state_dict = accelerator.get_state_dict(model) else: + if fsdp_module_to_reshard is not None: + raise RuntimeError( + "FSDP2-wrapped model detected but no accelerator provided. " + "Pass accelerator=<Accelerator instance> to export_hf_checkpoint for FSDP2 export." + ) quantized_state_dict = model.state_dict()tests/gpu/torch/export/test_export.py (1)
383-383: Remove debug print statement.The debug print on line 383 appears to be leftover from development and should be removed before merging.
Apply this diff:
scale = get_scaling_factor(module) - print(f"DEBUG LOG: Scale: {scale}, Expected: {expected_amax[0] / maxbound}") assert torch.allclose(tests/gpu/torch/export/test_fsdp2_export.py (1)
37-37: Remove redundant imports.
fully_shardis already imported at line 24, andcopyat line 17. Remove the duplicate imports inside these functions.Apply this diff:
def _update_weight_test(rank, size): """Test fsdp2 weight update context for weight update -> only value changed""" - from torch.distributed._composable.fsdp import fully_shard - with patch_fsdp_mp_dtypes():def _compress_weight_test(rank, size): """Test fsdp2 weight update context for weight compression -> only value,shape and dtype changed""" - from torch.distributed._composable.fsdp import fully_shard - with patch_fsdp_mp_dtypes():def _export_quantized_weight_test(rank, size, quant_config): - import copy - - from torch.distributed._composable.fsdp import fully_shard - with patch_fsdp_mp_dtypes():Also applies to: 75-75, 163-165
🧹 Nitpick comments (6)
tests/_test_utils/torch_export/export_utils.py (1)
36-36: Fix capitalization in class name.The class name
SmallLinearModelwithCustomWeighthas inconsistent capitalization. The "with" should be capitalized to follow PEP 8 naming conventions for classes.Apply this diff:
-class SmallLinearModelwithCustomWeight(torch.nn.Module): +class SmallLinearModelWithCustomWeight(torch.nn.Module):Also update the import in
tests/gpu/torch/export/test_export.pyline 19 to match.modelopt/torch/quantization/qtensor/base_qtensor.py (1)
221-229: Consider filtering submodules before wrapping to reduce overhead.The current implementation calls
fsdp2_aware_weight_updatefor every non-root submodule. While the context manager handles non-FSDP cases, filtering to only eligible modules (those with weights and enabled quantizers) before entering the context would reduce overhead and avoid unnecessary context manager invocations.Consider filtering like:
with ( SequentialQuantizer.convert_to_single_quantizer(module), torch.no_grad(), patch_fsdp_mp_dtypes(), ): eligible_modules = [ m for name, m in module.named_modules() if name != "" and hasattr(m, "weight") and m.weight is not None and not getattr(m.weight, "is_meta", False) ] for m in eligible_modules: with fsdp2_aware_weight_update(module, m): _compress_and_update_module_weight(m)tests/gpu/torch/export/test_fsdp2_export.py (1)
205-208: Remove unnecessary FSDP2 context manager for non-FSDP model.
non_fsdp_modelis not FSDP-wrapped, so wrapping the export call infsdp2_aware_weight_updateis unnecessary and misleading. The context manager will just yield without performing any FSDP-specific operations.Apply this diff:
for name, sub_module in non_fsdp_model.named_modules(): if is_quantlinear(sub_module): - with fsdp2_aware_weight_update(non_fsdp_model, sub_module): - _export_quantized_weight(sub_module, torch.float16) + _export_quantized_weight(sub_module, torch.float16)examples/llm_ptq/multinode_ptq.py (3)
199-210: Clarify parameter naming and comments in calibration loop.The
unwrapped_modelparameter is not used, and the comment "We should forward pass using the unwrapped model" contradicts the actual code which uses the outer closure variablemodel. While this appears intentional for FSDP2 compatibility, the parameter name and comments are confusing.Consider either:
- Renaming the parameter to
_unwrapped_modelto indicate it's intentionally unused- Updating the comment to clarify why the outer
modelis used instead- def calibrate(unwrapped_model): - """Calibration loop that uses the FSDP-wrapped model.""" + def calibrate(_unwrapped_model): + """Calibration loop that uses the outer FSDP-wrapped model for DTensor compatibility.""" for batch in tqdm(dataloader, desc="Calibrating"): if isinstance(batch, dict): batch = { k: v.to(accelerator.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() } - # Use outer model (FSDP-wrapped), not the parameter - # Important: We should forward pass using the unwrapped model - # mtq.quantize will unwrap the model & pass to the forward_loop + # Use outer FSDP-wrapped model (from closure) for proper DTensor handling + # mtq.quantize passes unwrapped model to forward_loop, but FSDP2 requires the wrapped version model(**batch)
288-297: Simplify and safeguard calib_size adjustment logic.The one-liner for adjusting
calib_sizeis complex and could fail with anIndexErrorifargs.calib_sizeis somehow empty. Consider making it more explicit and adding a safeguard.Apply this diff:
# Set default dataset if not provided if args.dataset is None: args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"] warnings.warn( "No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2." ) - # Adjust calib_size to match dataset length by extending or truncating as needed - args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[ - : len(args.dataset) - ] + # Extend calib_size to match dataset count by repeating the last value + if len(args.calib_size) < len(args.dataset): + last_size = args.calib_size[-1] if args.calib_size else 512 + args.calib_size.extend([last_size] * (len(args.dataset) - len(args.calib_size))) + elif len(args.calib_size) > len(args.dataset): + args.calib_size = args.calib_size[: len(args.dataset)]
360-360: Guard print statement with main process check.This print will execute on all processes in a multi-node/multi-GPU setup, causing duplicate output. Add an
is_main_processguard for cleaner logging.Apply this diff:
- print("Unpatching FSDP2 MP dtypes") + if accelerator.is_main_process: + print("Unpatching FSDP2 MP dtypes")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
CHANGELOG.rst(1 hunks)examples/llm_ptq/README.md(1 hunks)examples/llm_ptq/example_utils.py(2 hunks)examples/llm_ptq/fsdp2.yaml(1 hunks)examples/llm_ptq/hf_ptq.py(2 hunks)examples/llm_ptq/multinode_ptq.py(1 hunks)modelopt/torch/export/unified_export_hf.py(10 hunks)modelopt/torch/quantization/qtensor/base_qtensor.py(2 hunks)modelopt/torch/quantization/qtensor/nvfp4_tensor.py(1 hunks)modelopt/torch/quantization/utils.py(3 hunks)tests/_test_utils/torch_export/export_utils.py(2 hunks)tests/gpu/torch/export/test_export.py(3 hunks)tests/gpu/torch/export/test_fsdp2_export.py(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- modelopt/torch/quantization/qtensor/nvfp4_tensor.py
- examples/llm_ptq/hf_ptq.py
- CHANGELOG.rst
🧰 Additional context used
🧬 Code graph analysis (8)
examples/llm_ptq/multinode_ptq.py (7)
examples/llm_ptq/example_utils.py (2)
build_quant_cfg(42-93)get_tokenizer(103-120)modelopt/torch/export/convert_hf_config.py (1)
convert_hf_quant_config_format(21-117)modelopt/torch/export/unified_export_hf.py (1)
_export_hf_checkpoint(346-523)modelopt/torch/quantization/config.py (1)
need_calibration(1164-1190)modelopt/torch/quantization/utils.py (1)
patch_fsdp_mp_dtypes(485-527)modelopt/torch/utils/dataset_utils.py (1)
get_dataset_dataloader(171-246)modelopt/torch/opt/plugins/huggingface.py (1)
enable_huggingface_checkpointing(127-162)
modelopt/torch/quantization/qtensor/base_qtensor.py (2)
modelopt/torch/quantization/utils.py (2)
fsdp2_aware_weight_update(603-705)patch_fsdp_mp_dtypes(485-527)modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
SequentialQuantizer(1114-1222)convert_to_single_quantizer(1197-1222)
tests/gpu/torch/export/test_export.py (3)
tests/_test_utils/torch_export/export_utils.py (1)
SmallLinearModelwithCustomWeight(36-54)modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
TensorQuantizer(65-1111)is_enabled(395-397)maxbound(193-199)modelopt/torch/export/quant_utils.py (1)
get_scaling_factor(212-229)
modelopt/torch/export/unified_export_hf.py (2)
modelopt/torch/quantization/utils.py (2)
fsdp2_aware_weight_update(603-705)quantizer_attr_names(231-242)modelopt/torch/export/layer_utils.py (1)
is_quantlinear(346-348)
modelopt/torch/quantization/utils.py (2)
modelopt/torch/utils/network.py (1)
get_unwrapped_name(599-612)modelopt/torch/quantization/qtensor/base_qtensor.py (2)
QFSDPParam(139-159)QTensorWrapper(87-136)
examples/llm_ptq/example_utils.py (2)
examples/llm_ptq/hf_ptq.py (1)
auto_quantize(94-152)modelopt/torch/quantization/model_quant.py (1)
auto_quantize(234-450)
tests/_test_utils/torch_export/export_utils.py (3)
modelopt/torch/export/model_config.py (1)
bias(153-163)modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (1)
forward(141-163)modelopt/torch/quantization/backends/nvfp4_gemm.py (1)
forward(135-150)
tests/gpu/torch/export/test_fsdp2_export.py (4)
tests/_test_utils/torch_export/export_utils.py (2)
SmallQKVModel(57-83)ToyModel(20-33)modelopt/torch/export/layer_utils.py (1)
is_quantlinear(346-348)modelopt/torch/export/unified_export_hf.py (2)
_export_quantized_weight(205-343)requantize_resmooth_fused_llm_layers(88-202)modelopt/torch/quantization/utils.py (2)
fsdp2_aware_weight_update(603-705)patch_fsdp_mp_dtypes(485-527)
🪛 LanguageTool
examples/llm_ptq/README.md
[grammar] ~244-~244: Use a hyphen to join words.
Context: ... provided and can be customized for user specific requirements. On each node run...
(QB_NEW_EN_HYPHEN)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (3)
examples/llm_ptq/fsdp2.yaml (1)
1-30: LGTM! Configuration is well-documented.The FSDP2 configuration is properly documented (line 2) and includes all necessary settings for multinode PTQ workflows.
modelopt/torch/quantization/utils.py (2)
484-527: LGTM! FSDP2 dtype patching is well-implemented.
The patch_fsdp_mp_dtypes context manager correctly patches and restores FSDP's _init_mp_dtypes to handle mixed-precision dtypes during quantization. The implementation follows the upstream PyTorch pattern and properly cleans up in the finally block.
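The save/patch/restore shape being praised here, as a generic sketch (the real code targets FSDP2's _init_mp_dtypes; only the structure is shown):

```python
import contextlib

@contextlib.contextmanager
def patched_attr(owner, name, replacement):
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield
    finally:
        setattr(owner, name, original)  # restored even if the body raises
```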
602-705: LGTM! FSDP2 weight update coordination is correctly implemented.
The fsdp2_aware_weight_update context manager properly coordinates weight updates across FSDP2 shards by:
- Unsharding when needed
- Tracking FSDPParam mappings
- Creating new FSDPParam/QFSDPParam instances after updates
- Optionally resharding to release memory
The implementation correctly handles edge cases and provides the necessary synchronization for distributed PTQ workflows.
quant_cfg = quant_cfg_choices[qformat]

if "awq" in qformat:
    quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
    weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
    if isinstance(weight_quantizer, list):
        weight_quantizer = weight_quantizer[0]
    # If awq_block_size argument is provided, update weight_quantizer
    if awq_block_size:
        weight_quantizer["block_sizes"][-1] = awq_block_size

# Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
    quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}

enable_quant_kv_cache = kv_cache_qformat != "none"
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")

# Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
if enable_quant_kv_cache:
    quant_cfg = apply_kv_cache_quant(
        quant_cfg,
        getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"],
    )
Always deep-copy quant_cfg to prevent shared state mutation.
Non-AWQ formats use a shallow reference (line 57), but apply_kv_cache_quant (lines 77-80) mutates quant_cfg["quant_cfg"]. This pollutes the shared quant_cfg_choices module object across calls, similar to the issue flagged in past reviews.
Apply this diff:
- quant_cfg = quant_cfg_choices[qformat]
+ quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
if "awq" in qformat:
- quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
🤖 Prompt for AI Agents
In examples/llm_ptq/example_utils.py around lines 57 to 80, the code assigns
quant_cfg = quant_cfg_choices[qformat] by reference for non-AWQ formats but
later mutate quant_cfg via apply_kv_cache_quant, which leaks into the shared
quant_cfg_choices; change the initial assignment to always make a deep copy
(e.g., quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])) so both AWQ and
non-AWQ branches operate on an isolated copy before any in-place updates
(preserve the existing extra deepcopy in the AWQ branch if present).
### Usage

For distributed execution across multiple nodes, use the `accelerate` library. A template configuration file (`fsdp2.yaml`) is provided and can be customized for user specific requirements.
Fix hyphenation for compound modifier.
The phrase "user specific requirements" should use a hyphen when the compound modifier precedes the noun.
Apply this diff:
-For distributed execution across multiple nodes, use the `accelerate` library. A template configuration file (`fsdp2.yaml`) is provided and can be customized for user specific requirements.
+For distributed execution across multiple nodes, use the `accelerate` library. A template configuration file (`fsdp2.yaml`) is provided and can be customized for user-specific requirements.🧰 Tools
🪛 LanguageTool
[grammar] ~244-~244: Use a hyphen to join words.
Context: ... provided and can be customized for user specific requirements. On each node run...
(QB_NEW_EN_HYPHEN)
🤖 Prompt for AI Agents
In examples/llm_ptq/README.md around line 244, the phrase "user specific
requirements" should be hyphenated as "user-specific requirements" when used as
a compound modifier before the noun; update the sentence to use the hyphenated
form to fix the compound modifier hyphenation.
What does this PR do?
Type of change: New example
Overview: This PR adds an e2e example for calibrating a model using a multinode setup and exporting a checkpoint that's ready for deployment using TensorRT-LLM/ vLLM/ SGLang. To enable this, this PR
Usage
On each node run the following command
Testing
cpu offloading -> one A6000 -> calibration (batch size = 28) 17 mins + export ~127 secs
multinode -> 4 A6000s (2 on each node) -> calibration (batch size = 68) 6 mins + export ~340 secs
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Documentation
Refactor
Bug Fixes
Tests